diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index 8afd8545c..c02b938e8 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -1,5 +1,6 @@ -from typing import Union +from typing import Union, Optional, Type, List +from cognee.infrastructure.engine.models.DataPoint import DataPoint from cognee.modules.search.types import SearchType from cognee.modules.users.exceptions import UserNotFoundError from cognee.modules.users.models import User @@ -13,6 +14,9 @@ async def search( user: User = None, datasets: Union[list[str], str, None] = None, system_prompt_path: str = "answer_simple_question.txt", + top_k: int = 10, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ) -> list: # We use lists from now on for datasets if isinstance(datasets, str): @@ -25,7 +29,14 @@ async def search( raise UserNotFoundError filtered_search_results = await search_function( - query_text, query_type, datasets, user, system_prompt_path=system_prompt_path + query_text, + query_type, + datasets, + user, + system_prompt_path=system_prompt_path, + top_k=top_k, + node_type=node_type, + node_name=node_name, ) return filtered_search_results diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py index 348bcdf91..85b4f0f8f 100644 --- a/cognee/infrastructure/databases/graph/graph_db_interface.py +++ b/cognee/infrastructure/databases/graph/graph_db_interface.py @@ -1,4 +1,4 @@ -from typing import Protocol, Optional, Dict, Any +from typing import Protocol, Optional, Dict, Any, Type, List from abc import abstractmethod @@ -51,6 +51,10 @@ class GraphDBInterface(Protocol): ): raise NotImplementedError + @abstractmethod + async def get_subgraph(self, node_type: Type[Any], node_name: List[str]): + raise NotImplementedError + @abstractmethod async def get_graph_data(self): raise NotImplementedError diff --git 
a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 1f2e3b65b..cd1c4f88b 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -4,7 +4,7 @@ import json from cognee.shared.logging_utils import get_logger, ERROR import asyncio from textwrap import dedent -from typing import Optional, Any, List, Dict +from typing import Optional, Any, List, Dict, Type, Tuple from contextlib import asynccontextmanager from uuid import UUID from neo4j import AsyncSession @@ -496,6 +496,58 @@ class Neo4jAdapter(GraphDBInterface): return (nodes, edges) + async def get_subgraph( + self, node_type: Type[Any], node_name: List[str] + ) -> Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: + label = node_type.__name__ + + query = f""" + UNWIND $names AS wantedName + MATCH (n:`{label}`) + WHERE n.name = wantedName + WITH collect(DISTINCT n) AS primary + + UNWIND primary AS p + OPTIONAL MATCH (p)--(nbr) + WITH primary, collect(DISTINCT nbr) AS nbrs + WITH primary + nbrs AS nodelist + + UNWIND nodelist AS node + WITH collect(DISTINCT node) AS nodes + + MATCH (a)-[r]-(b) + WHERE a IN nodes AND b IN nodes + WITH nodes, collect(DISTINCT r) AS rels + + RETURN + [n IN nodes | + {{ id: n.id, + properties: properties(n) }}] AS rawNodes, + [r IN rels | + {{ type: type(r), + properties: properties(r) }}] AS rawRels + """ + + result = await self.query(query, {"names": node_name}) + if not result: + return [], [] + + raw_nodes = result[0]["rawNodes"] + raw_rels = result[0]["rawRels"] + + nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes] + edges = [ + ( + r["properties"]["source_node_id"], + r["properties"]["target_node_id"], + r["type"], + r["properties"], + ) + for r in raw_rels + ] + + return nodes, edges + async def get_filtered_graph_data(self, attribute_filters): """ Fetches nodes and relationships 
filtered by specified attribute values. diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index ce227f296..a14cd68a9 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -1,5 +1,5 @@ from cognee.shared.logging_utils import get_logger -from typing import List, Dict, Union +from typing import List, Dict, Union, Optional, Type from cognee.exceptions import InvalidValueError from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError @@ -61,12 +61,18 @@ class CogneeGraph(CogneeAbstractGraph): node_dimension=1, edge_dimension=1, memory_fragment_filter=[], + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ) -> None: if node_dimension < 1 or edge_dimension < 1: raise InvalidValueError(message="Dimensions must be positive integers") try: - if len(memory_fragment_filter) == 0: + if node_type is not None and node_name is not None: + nodes_data, edges_data = await adapter.get_subgraph( + node_type=node_type, node_name=node_name + ) + elif len(memory_fragment_filter) == 0: nodes_data, edges_data = await adapter.get_graph_data() else: nodes_data, edges_data = await adapter.get_filtered_graph_data( @@ -74,9 +80,11 @@ ) if not nodes_data: - raise EntityNotFoundError(message="No node data retrieved from the database.") + #:TODO: quick and dirty solution for sf demo, as the list of nodes can be empty + return None if not edges_data: - raise EntityNotFoundError(message="No edge data retrieved from the database.") + #:TODO: quick and dirty solution for sf demo, as the list of edges can be empty + return None for node_id, properties in nodes_data: node_attributes = {key: properties.get(key) for key in node_properties_to_project} diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 
16b22eab7..3994d9da2 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Type, List from collections import Counter import string @@ -19,11 +19,15 @@ class GraphCompletionRetriever(BaseRetriever): user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", top_k: Optional[int] = 5, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ): """Initialize retriever with prompt paths and search parameters.""" self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path self.top_k = top_k if top_k is not None else 5 + self.node_type = node_type + self.node_name = node_name def _get_nodes(self, retrieved_edges: list) -> dict: """Creates a dictionary of nodes with their names and content.""" @@ -69,11 +73,16 @@ vector_index_collections.append(f"{subclass.__name__}_{field_name}") found_triplets = await brute_force_triplet_search( - query, top_k=self.top_k, collections=vector_index_collections or None + query, + top_k=self.top_k, + collections=vector_index_collections or None, + node_type=self.node_type, + node_name=self.node_name, ) if len(found_triplets) == 0: - raise NoRelevantDataFound + #:TODO: quick and dirty solution for sf demo, as the triplets can be empty + return [] return found_triplets diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index bef4493b4..10f465474 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -1,6 +1,6 @@ import asyncio from cognee.shared.logging_utils import get_logger, ERROR -from typing import List, Optional +from typing import List, Optional, Type from 
cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine @@ -54,6 +54,8 @@ def format_triplets(edges): async def get_memory_fragment( properties_to_project: Optional[List[str]] = None, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ) -> CogneeGraph: """Creates and initializes a CogneeGraph memory fragment with optional property projections.""" graph_engine = await get_graph_engine() @@ -66,6 +68,8 @@ async def get_memory_fragment( graph_engine, node_properties_to_project=properties_to_project, edge_properties_to_project=["relationship_name"], + node_type=node_type, + node_name=node_name, ) return memory_fragment @@ -78,6 +82,8 @@ async def brute_force_triplet_search( collections: List[str] = None, properties_to_project: List[str] = None, memory_fragment: Optional[CogneeGraph] = None, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ) -> list: if user is None: user = await get_default_user() @@ -92,6 +98,8 @@ async def brute_force_triplet_search( collections=collections, properties_to_project=properties_to_project, memory_fragment=memory_fragment, + node_type=node_type, + node_name=node_name, ) return retrieved_results @@ -103,6 +111,8 @@ async def brute_force_search( collections: List[str] = None, properties_to_project: List[str] = None, memory_fragment: Optional[CogneeGraph] = None, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ) -> list: """ Performs a brute force search to retrieve the top triplets from the graph. @@ -114,6 +124,8 @@ collections (Optional[List[str]]): List of collections to query. properties_to_project (Optional[List[str]]): List of properties to project. memory_fragment (Optional[CogneeGraph]): Existing memory fragment to reuse. + node_type: node type to filter + node_name: node name to filter Returns: list: The top triplet results. 
@@ -124,7 +136,9 @@ async def brute_force_search( raise ValueError("top_k must be a positive integer.") if memory_fragment is None: - memory_fragment = await get_memory_fragment(properties_to_project) + memory_fragment = await get_memory_fragment( + properties_to_project=properties_to_project, node_type=node_type, node_name=node_name + ) if collections is None: collections = [ diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 4a79a29a8..2a0866f05 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -1,5 +1,5 @@ import json -from typing import Callable +from typing import Callable, Optional, Type, List from cognee.exceptions import InvalidValueError from cognee.infrastructure.engine.utils import parse_id @@ -11,6 +11,7 @@ from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionR from cognee.modules.retrieval.graph_summary_completion_retriever import ( GraphSummaryCompletionRetriever, ) +from cognee.infrastructure.engine.models.DataPoint import DataPoint from cognee.modules.retrieval.code_retriever import CodeRetriever from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever @@ -28,12 +29,21 @@ async def search( datasets: list[str], user: User, system_prompt_path="answer_simple_question.txt", + top_k: int = 10, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ): query = await log_query(query_text, query_type.value, user.id) own_document_ids = await get_document_ids_for_user(user.id, datasets) search_results = await specific_search( - query_type, query_text, user, system_prompt_path=system_prompt_path + query_type, + query_text, + user, + system_prompt_path=system_prompt_path, + top_k=top_k, + node_type=node_type, + node_name=node_name, ) filtered_search_results = [] @@ -51,7 +61,13 @@ 
async def specific_search( - query_type: SearchType, query: str, user: User, system_prompt_path="answer_simple_question.txt" + query_type: SearchType, + query: str, + user: User, + system_prompt_path="answer_simple_question.txt", + top_k: int = 10, + node_type: Optional[Type] = None, + node_name: Optional[List[str]] = None, ) -> list: search_tasks: dict[SearchType, Callable] = { SearchType.SUMMARIES: SummariesRetriever().get_completion, @@ -61,7 +77,10 @@ system_prompt_path=system_prompt_path ).get_completion, SearchType.GRAPH_COMPLETION: GraphCompletionRetriever( - system_prompt_path=system_prompt_path + system_prompt_path=system_prompt_path, + top_k=top_k, + node_type=node_type, + node_name=node_name, ).get_completion, SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever( system_prompt_path=system_prompt_path diff --git a/examples/python/dynamic_steps_example.py b/examples/python/dynamic_steps_example.py index 8a39ce72c..2226d260a 100644 --- a/examples/python/dynamic_steps_example.py +++ b/examples/python/dynamic_steps_example.py @@ -1,163 +1,15 @@ import cognee import asyncio + + from cognee.shared.logging_utils import get_logger, ERROR from cognee.modules.metrics.operations import get_pipeline_run_metrics - +from cognee.modules.engine.models.Entity import Entity from cognee.api.v1.search import SearchType job_1 = """ -CV 1: Relevant -Name: Dr. Emily Carter -Contact Information: - -Email: emily.carter@example.com -Phone: (555) 123-4567 -Summary: - -Senior Data Scientist with over 8 years of experience in machine learning and predictive analytics. Expertise in developing advanced algorithms and deploying scalable models in production environments. - -Education: - -Ph.D. in Computer Science, Stanford University (2014) -B.S. 
in Mathematics, University of California, Berkeley (2010) -Experience: - -Senior Data Scientist, InnovateAI Labs (2016 – Present) -Led a team in developing machine learning models for natural language processing applications. -Implemented deep learning algorithms that improved prediction accuracy by 25%. -Collaborated with cross-functional teams to integrate models into cloud-based platforms. -Data Scientist, DataWave Analytics (2014 – 2016) -Developed predictive models for customer segmentation and churn analysis. -Analyzed large datasets using Hadoop and Spark frameworks. -Skills: - -Programming Languages: Python, R, SQL -Machine Learning: TensorFlow, Keras, Scikit-Learn -Big Data Technologies: Hadoop, Spark -Data Visualization: Tableau, Matplotlib -""" - -job_2 = """ -CV 2: Relevant -Name: Michael Rodriguez -Contact Information: - -Email: michael.rodriguez@example.com -Phone: (555) 234-5678 -Summary: - -Data Scientist with a strong background in machine learning and statistical modeling. Skilled in handling large datasets and translating data into actionable business insights. - -Education: - -M.S. in Data Science, Carnegie Mellon University (2013) -B.S. in Computer Science, University of Michigan (2011) -Experience: - -Senior Data Scientist, Alpha Analytics (2017 – Present) -Developed machine learning models to optimize marketing strategies. -Reduced customer acquisition cost by 15% through predictive modeling. -Data Scientist, TechInsights (2013 – 2017) -Analyzed user behavior data to improve product features. -Implemented A/B testing frameworks to evaluate product changes. -Skills: - -Programming Languages: Python, Java, SQL -Machine Learning: Scikit-Learn, XGBoost -Data Visualization: Seaborn, Plotly -Databases: MySQL, MongoDB -""" - - -job_3 = """ -CV 3: Relevant -Name: Sarah Nguyen -Contact Information: - -Email: sarah.nguyen@example.com -Phone: (555) 345-6789 -Summary: - -Data Scientist specializing in machine learning with 6 years of experience. 
Passionate about leveraging data to drive business solutions and improve product performance. - -Education: - -M.S. in Statistics, University of Washington (2014) -B.S. in Applied Mathematics, University of Texas at Austin (2012) -Experience: - -Data Scientist, QuantumTech (2016 – Present) -Designed and implemented machine learning algorithms for financial forecasting. -Improved model efficiency by 20% through algorithm optimization. -Junior Data Scientist, DataCore Solutions (2014 – 2016) -Assisted in developing predictive models for supply chain optimization. -Conducted data cleaning and preprocessing on large datasets. -Skills: - -Programming Languages: Python, R -Machine Learning Frameworks: PyTorch, Scikit-Learn -Statistical Analysis: SAS, SPSS -Cloud Platforms: AWS, Azure -""" - - -job_4 = """ -CV 4: Not Relevant -Name: David Thompson -Contact Information: - -Email: david.thompson@example.com -Phone: (555) 456-7890 -Summary: - -Creative Graphic Designer with over 8 years of experience in visual design and branding. Proficient in Adobe Creative Suite and passionate about creating compelling visuals. - -Education: - -B.F.A. in Graphic Design, Rhode Island School of Design (2012) -Experience: - -Senior Graphic Designer, CreativeWorks Agency (2015 – Present) -Led design projects for clients in various industries. -Created branding materials that increased client engagement by 30%. -Graphic Designer, Visual Innovations (2012 – 2015) -Designed marketing collateral, including brochures, logos, and websites. -Collaborated with the marketing team to develop cohesive brand strategies. 
-Skills: - -Design Software: Adobe Photoshop, Illustrator, InDesign -Web Design: HTML, CSS -Specialties: Branding and Identity, Typography -""" - - -job_5 = """ -CV 5: Not Relevant -Name: Jessica Miller -Contact Information: - -Email: jessica.miller@example.com -Phone: (555) 567-8901 -Summary: - -Experienced Sales Manager with a strong track record in driving sales growth and building high-performing teams. Excellent communication and leadership skills. - -Education: - -B.A. in Business Administration, University of Southern California (2010) -Experience: - -Sales Manager, Global Enterprises (2015 – Present) -Managed a sales team of 15 members, achieving a 20% increase in annual revenue. -Developed sales strategies that expanded customer base by 25%. -Sales Representative, Market Leaders Inc. (2010 – 2015) -Consistently exceeded sales targets and received the 'Top Salesperson' award in 2013. -Skills: - -Sales Strategy and Planning -Team Leadership and Development -CRM Software: Salesforce, Zoho -Negotiation and Relationship Building + Natural language processing (NLP) is an interdisciplinary + subfield of computer science and information retrieval. """ @@ -173,7 +25,7 @@ async def main(enable_steps): # Step 2: Add text if enable_steps.get("add_text"): - text_list = [job_1, job_2, job_3, job_4, job_5] + text_list = [job_1] for text in text_list: await cognee.add(text) print(f"Added text: {text[:35]}...") @@ -191,7 +43,10 @@ async def main(enable_steps): # Step 5: Query insights if enable_steps.get("retriever"): search_results = await cognee.search( - query_type=SearchType.GRAPH_COMPLETION, query_text="Who has experience in design tools?" + query_type=SearchType.GRAPH_COMPLETION, + query_text="What is computer science?", + node_type=Entity, + node_name=["computer science"], ) print(search_results)