From b1643414d27b637dea837926ec81f768abb164b3 Mon Sep 17 00:00:00 2001 From: Boris Date: Wed, 10 Sep 2025 16:33:08 +0200 Subject: [PATCH] feat: implement combined context search (#1341) ## Description ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --- .../v1/search/routers/get_search_router.py | 3 + cognee/api/v1/search/search.py | 8 +- .../modules/graph/cognee_graph/CogneeGraph.py | 2 +- .../graph/utils/resolve_edges_to_text.py | 14 +- .../modules/retrieval/base_graph_retriever.py | 18 + cognee/modules/retrieval/base_retriever.py | 2 +- .../retrieval/coding_rules_retriever.py | 19 +- .../modules/retrieval/completion_retriever.py | 9 +- .../TripletSearchContextProvider.py | 1 + ..._completion_context_extension_retriever.py | 36 +- .../graph_completion_cot_retriever.py | 33 +- .../retrieval/graph_completion_retriever.py | 47 +-- .../modules/retrieval/insights_retriever.py | 17 +- .../modules/retrieval/temporal_retriever.py | 33 +- .../utils/brute_force_triplet_search.py | 36 +- cognee/modules/retrieval/utils/completion.py | 14 +- .../search/methods/get_search_type_tools.py | 168 +++++++++ .../methods/no_access_control_search.py | 47 +++ cognee/modules/search/methods/search.py | 354 ++++++++---------- cognee/modules/search/types/SearchResult.py | 21 ++ cognee/modules/search/types/__init__.py | 1 + cognee/modules/search/utils/__init__.py | 2 + .../search/utils/prepare_search_result.py | 41 ++ .../utils/transform_context_to_graph.py | 38 ++ .../codingagents/coding_rule_associations.py | 6 +- cognee/tests/test_kuzu.py | 8 +- cognee/tests/test_neo4j.py | 8 +- cognee/tests/test_permissions.py | 6 +- cognee/tests/test_relational_db_migration.py | 12 +- cognee/tests/test_search_db.py | 42 +-- ...letion_retriever_context_extension_test.py | 29 +- .../graph_completion_retriever_cot_test.py | 27 +- .../graph_completion_retriever_test.py | 9 +- 
.../retrieval/insights_retriever_test.py | 6 +- .../modules/search/search_methods_test.py | 230 ------------ examples/python/graphiti_example.py | 1 + 36 files changed, 706 insertions(+), 642 deletions(-) create mode 100644 cognee/modules/retrieval/base_graph_retriever.py create mode 100644 cognee/modules/search/methods/get_search_type_tools.py create mode 100644 cognee/modules/search/methods/no_access_control_search.py create mode 100644 cognee/modules/search/types/SearchResult.py create mode 100644 cognee/modules/search/utils/__init__.py create mode 100644 cognee/modules/search/utils/prepare_search_result.py create mode 100644 cognee/modules/search/utils/transform_context_to_graph.py delete mode 100644 cognee/tests/unit/modules/search/search_methods_test.py diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py index b7a8df9d3..b158002a7 100644 --- a/cognee/api/v1/search/routers/get_search_router.py +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -27,6 +27,7 @@ class SearchPayloadDTO(InDTO): node_name: Optional[list[str]] = Field(default=None, example=[]) top_k: Optional[int] = Field(default=10) only_context: bool = Field(default=False) + use_combined_context: bool = Field(default=False) def get_search_router() -> APIRouter: @@ -115,6 +116,7 @@ def get_search_router() -> APIRouter: "node_name": payload.node_name, "top_k": payload.top_k, "only_context": payload.only_context, + "use_combined_context": payload.use_combined_context, }, ) @@ -131,6 +133,7 @@ def get_search_router() -> APIRouter: node_name=payload.node_name, top_k=payload.top_k, only_context=payload.only_context, + use_combined_context=payload.use_combined_context, ) return JSONResponse(content=results) diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index 49f7aee51..0e7cb6d85 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -3,7 +3,7 @@ from typing import Union, 
Optional, List, Type from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.users.models import User -from cognee.modules.search.types import SearchType +from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult from cognee.modules.users.methods import get_default_user from cognee.modules.search.methods import search as search_function from cognee.modules.data.methods import get_authorized_existing_datasets @@ -13,7 +13,7 @@ from cognee.modules.data.exceptions import DatasetNotFoundError async def search( query_text: str, query_type: SearchType = SearchType.GRAPH_COMPLETION, - user: User = None, + user: Optional[User] = None, datasets: Optional[Union[list[str], str]] = None, dataset_ids: Optional[Union[list[UUID], UUID]] = None, system_prompt_path: str = "answer_simple_question.txt", @@ -24,7 +24,8 @@ async def search( save_interaction: bool = False, last_k: Optional[int] = None, only_context: bool = False, -) -> list: + use_combined_context: bool = False, +) -> Union[List[SearchResult], CombinedSearchResult]: """ Search and query the knowledge graph for insights, information, and connections. 
@@ -193,6 +194,7 @@ async def search( save_interaction=save_interaction, last_k=last_k, only_context=only_context, + use_combined_context=use_combined_context, ) return filtered_search_results diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index acfe04de7..28e04cce4 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -180,7 +180,7 @@ class CogneeGraph(CogneeAbstractGraph): logger.error(f"Error mapping vector distances to edges: {str(ex)}") raise ex - async def calculate_top_triplet_importances(self, k: int) -> List: + async def calculate_top_triplet_importances(self, k: int) -> List[Edge]: def score(edge): n1 = edge.node1.attributes.get("vector_distance", 1) n2 = edge.node2.attributes.get("vector_distance", 1) diff --git a/cognee/modules/graph/utils/resolve_edges_to_text.py b/cognee/modules/graph/utils/resolve_edges_to_text.py index 56c303abc..eb5bedd2c 100644 --- a/cognee/modules/graph/utils/resolve_edges_to_text.py +++ b/cognee/modules/graph/utils/resolve_edges_to_text.py @@ -1,4 +1,8 @@ -async def resolve_edges_to_text(retrieved_edges: list) -> str: +from typing import List +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str: """ Converts retrieved graph edges into a human-readable string format. @@ -13,7 +17,7 @@ async def resolve_edges_to_text(retrieved_edges: list) -> str: - str: A formatted string representation of the nodes and their connections. 
""" - def _get_nodes(retrieved_edges: list) -> dict: + def _get_nodes(retrieved_edges: List[Edge]) -> dict: def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str: def _top_n_words(text, stop_words=None, top_n=3, separator=", "): """Concatenates the top N frequent words in text.""" @@ -36,9 +40,9 @@ async def resolve_edges_to_text(retrieved_edges: list) -> str: return separator.join(top_words) """Creates a title, by combining first words with most frequent words from the text.""" - first_n_words = text.split()[:first_n_words] - top_n_words = _top_n_words(text, top_n=top_n_words) - return f"{' '.join(first_n_words)}... [{top_n_words}]" + first_words = text.split()[:first_n_words] + top_words = _top_n_words(text, top_n=first_n_words) + return f"{' '.join(first_words)}... [{top_words}]" """Creates a dictionary of nodes with their names and content.""" nodes = {} diff --git a/cognee/modules/retrieval/base_graph_retriever.py b/cognee/modules/retrieval/base_graph_retriever.py new file mode 100644 index 000000000..2aaf3468f --- /dev/null +++ b/cognee/modules/retrieval/base_graph_retriever.py @@ -0,0 +1,18 @@ +from typing import List, Optional +from abc import ABC, abstractmethod + +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +class BaseGraphRetriever(ABC): + """Base class for all graph based retrievers.""" + + @abstractmethod + async def get_context(self, query: str) -> List[Edge]: + """Retrieves triplets based on the query.""" + pass + + @abstractmethod + async def get_completion(self, query: str, context: Optional[List[Edge]] = None) -> str: + """Generates a response using the query and optional context (triplets).""" + pass diff --git a/cognee/modules/retrieval/base_retriever.py b/cognee/modules/retrieval/base_retriever.py index 2df1c5f63..88313b253 100644 --- a/cognee/modules/retrieval/base_retriever.py +++ b/cognee/modules/retrieval/base_retriever.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from 
typing import Any, Optional, Callable +from typing import Any, Optional class BaseRetriever(ABC): diff --git a/cognee/modules/retrieval/coding_rules_retriever.py b/cognee/modules/retrieval/coding_rules_retriever.py index 364ff3236..606cd79da 100644 --- a/cognee/modules/retrieval/coding_rules_retriever.py +++ b/cognee/modules/retrieval/coding_rules_retriever.py @@ -1,3 +1,6 @@ +import asyncio +from functools import reduce +from typing import List, Optional from cognee.shared.logging_utils import get_logger from cognee.tasks.codingagents.coding_rule_associations import get_existing_rules @@ -7,16 +10,22 @@ logger = get_logger("CodingRulesRetriever") class CodingRulesRetriever: """Retriever for handling codeing rule based searches.""" - def __init__(self, rules_nodeset_name="coding_agent_rules"): + def __init__(self, rules_nodeset_name: Optional[List[str]] = None): if isinstance(rules_nodeset_name, list): if not rules_nodeset_name: # If there is no provided nodeset set to coding_agent_rules rules_nodeset_name = ["coding_agent_rules"] - rules_nodeset_name = rules_nodeset_name[0] + self.rules_nodeset_name = rules_nodeset_name """Initialize retriever with search parameters.""" async def get_existing_rules(self, query_text): - return await get_existing_rules( - rules_nodeset_name=self.rules_nodeset_name, return_list=True - ) + if self.rules_nodeset_name: + rules_list = await asyncio.gather( + *[ + get_existing_rules(rules_nodeset_name=nodeset) + for nodeset in self.rules_nodeset_name + ] + ) + + return reduce(lambda x, y: x + y, rules_list, []) diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index 4d34dfdbe..44f27bece 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -23,16 +23,14 @@ class CompletionRetriever(BaseRetriever): self, user_prompt_path: str = "context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", - 
system_prompt: str = None, + system_prompt: Optional[str] = None, top_k: Optional[int] = 1, - only_context: bool = False, ): """Initialize retriever with optional custom prompt paths.""" self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path self.top_k = top_k if top_k is not None else 1 self.system_prompt = system_prompt - self.only_context = only_context async def get_context(self, query: str) -> str: """ @@ -69,7 +67,7 @@ class CompletionRetriever(BaseRetriever): logger.error("DocumentChunk_text collection not found") raise NoDataError("No data found in the system, please add data first.") from error - async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: + async def get_completion(self, query: str, context: Optional[Any] = None) -> str: """ Generates an LLM completion using the context. @@ -97,6 +95,5 @@ class CompletionRetriever(BaseRetriever): user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, - only_context=self.only_context, ) - return [completion] + return completion diff --git a/cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py b/cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py index ac29231ee..b539055fa 100644 --- a/cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +++ b/cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py @@ -49,6 +49,7 @@ class TripletSearchContextProvider(BaseContextProvider): tasks = [ brute_force_triplet_search( query=f"{entity_text} {query}", + user=user, top_k=self.top_k, collections=self.collections, properties_to_project=self.properties_to_project, diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 8bdf5f1a0..4f4af1f06 100644 --- 
a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -1,4 +1,5 @@ -from typing import Any, Optional, List, Type +from typing import Optional, List, Type +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.utils.completion import generate_completion @@ -31,7 +32,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, - only_context: bool = False, ): super().__init__( user_prompt_path=user_prompt_path, @@ -41,15 +41,14 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): node_name=node_name, save_interaction=save_interaction, system_prompt=system_prompt, - only_context=only_context, ) async def get_completion( self, query: str, - context: Optional[Any] = None, + context: Optional[List[Edge]] = None, context_extension_rounds=4, - ) -> List[str]: + ) -> str: """ Extends the context for a given query by retrieving related triplets and generating new completions based on them. @@ -74,11 +73,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): - List[str]: A list containing the generated answer based on the query and the extended context. 
""" - triplets = [] + triplets = context - if context is None: - triplets += await self.get_triplets(query) - context = await self.resolve_edges_to_text(triplets) + if triplets is None: + triplets = await self.get_context(query) + + context_text = await self.resolve_edges_to_text(triplets) round_idx = 1 @@ -90,15 +90,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): ) completion = await generate_completion( query=query, - context=context, + context=context_text, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, ) - triplets += await self.get_triplets(completion) + triplets += await self.get_context(completion) triplets = list(set(triplets)) - context = await self.resolve_edges_to_text(triplets) + context_text = await self.resolve_edges_to_text(triplets) num_triplets = len(triplets) @@ -117,19 +117,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): completion = await generate_completion( query=query, - context=context, + context=context_text, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, - only_context=self.only_context, ) - if self.save_interaction and context and triplets and completion: + if self.save_interaction and context_text and triplets and completion: await self.save_qa( - question=query, answer=completion, context=context, triplets=triplets + question=query, answer=completion, context=context_text, triplets=triplets ) - if self.only_context: - return [context] - else: - return [completion] + return completion diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 7e14078e4..282c6147e 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -1,4 +1,5 @@ -from typing import Any, Optional, List, 
Tuple, Type +from typing import Optional, List, Type +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever @@ -32,18 +33,16 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): validation_system_prompt_path: str = "cot_validation_system_prompt.txt", followup_system_prompt_path: str = "cot_followup_system_prompt.txt", followup_user_prompt_path: str = "cot_followup_user_prompt.txt", - system_prompt: str = None, + system_prompt: Optional[str] = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, - only_context: bool = False, ): super().__init__( user_prompt_path=user_prompt_path, system_prompt_path=system_prompt_path, system_prompt=system_prompt, - only_context=only_context, top_k=top_k, node_type=node_type, node_name=node_name, @@ -57,9 +56,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): async def get_completion( self, query: str, - context: Optional[Any] = None, + context: Optional[List[Edge]] = None, max_iter=4, - ) -> List[str]: + ) -> str: """ Generate completion responses based on a user query and contextual information. 
@@ -84,26 +83,29 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): """ followup_question = "" triplets = [] - completion = [""] + completion = "" for round_idx in range(max_iter + 1): if round_idx == 0: if context is None: - context = await self.get_context(query) + triplets = await self.get_context(query) + context_text = await self.resolve_edges_to_text(triplets) + else: + context_text = await self.resolve_edges_to_text(context) else: - triplets += await self.get_triplets(followup_question) - context = await self.resolve_edges_to_text(list(set(triplets))) + triplets += await self.get_context(followup_question) + context_text = await self.resolve_edges_to_text(list(set(triplets))) completion = await generate_completion( query=query, - context=context, + context=context_text, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, ) logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}") if round_idx < max_iter: - valid_args = {"query": query, "answer": completion, "context": context} + valid_args = {"query": query, "answer": completion, "context": context_text} valid_user_prompt = LLMGateway.render_prompt( filename=self.validation_user_prompt_path, context=valid_args ) @@ -133,10 +135,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): if self.save_interaction and context and triplets and completion: await self.save_qa( - question=query, answer=completion, context=context, triplets=triplets + question=query, answer=completion, context=context_text, triplets=triplets ) - if self.only_context: - return [context] - else: - return [completion] + return completion diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 478fe4300..45e7f85ff 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -1,16 +1,15 @@ from 
typing import Any, Optional, Type, List -from collections import Counter from uuid import NAMESPACE_OID, uuid5 -import string from cognee.infrastructure.engine import DataPoint +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.users.methods import get_default_user from cognee.tasks.storage import add_data_points from cognee.modules.graph.utils import resolve_edges_to_text from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses -from cognee.modules.retrieval.base_retriever import BaseRetriever +from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search from cognee.modules.retrieval.utils.completion import generate_completion -from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node from cognee.modules.retrieval.utils.models import CogneeUserInteraction @@ -20,7 +19,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine logger = get_logger("GraphCompletionRetriever") -class GraphCompletionRetriever(BaseRetriever): +class GraphCompletionRetriever(BaseGraphRetriever): """ Retriever for handling graph-based completion searches. 
@@ -37,19 +36,17 @@ class GraphCompletionRetriever(BaseRetriever): self, user_prompt_path: str = "graph_context_for_question.txt", system_prompt_path: str = "answer_simple_question.txt", - system_prompt: str = None, + system_prompt: Optional[str] = None, top_k: Optional[int] = 5, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, save_interaction: bool = False, - only_context: bool = False, ): """Initialize retriever with prompt paths and search parameters.""" self.save_interaction = save_interaction self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path self.system_prompt = system_prompt - self.only_context = only_context self.top_k = top_k if top_k is not None else 5 self.node_type = node_type self.node_name = node_name @@ -70,7 +67,7 @@ class GraphCompletionRetriever(BaseRetriever): """ return await resolve_edges_to_text(retrieved_edges) - async def get_triplets(self, query: str) -> list: + async def get_triplets(self, query: str) -> List[Edge]: """ Retrieves relevant graph triplets based on a query string. @@ -85,7 +82,7 @@ class GraphCompletionRetriever(BaseRetriever): - list: A list of found triplets that match the query. 
""" subclasses = get_all_subclasses(DataPoint) - vector_index_collections = [] + vector_index_collections: List[str] = [] for subclass in subclasses: if "metadata" in subclass.model_fields: @@ -96,8 +93,11 @@ class GraphCompletionRetriever(BaseRetriever): for field_name in index_fields: vector_index_collections.append(f"{subclass.__name__}_{field_name}") + user = await get_default_user() + found_triplets = await brute_force_triplet_search( query, + user=user, top_k=self.top_k, collections=vector_index_collections or None, node_type=self.node_type, @@ -106,7 +106,7 @@ class GraphCompletionRetriever(BaseRetriever): return found_triplets - async def get_context(self, query: str) -> tuple[str, list]: + async def get_context(self, query: str) -> List[Edge]: """ Retrieves and resolves graph triplets into context based on a query. @@ -125,17 +125,17 @@ class GraphCompletionRetriever(BaseRetriever): if len(triplets) == 0: logger.warning("Empty context was provided to the completion") - return "", triplets + return [] - context = await self.resolve_edges_to_text(triplets) + # context = await self.resolve_edges_to_text(triplets) - return context, triplets + return triplets async def get_completion( self, query: str, - context: Optional[Any] = None, - ) -> List[str]: + context: Optional[List[Edge]] = None, + ) -> Any: """ Generates a completion using graph connections context based on a query. @@ -151,26 +151,27 @@ class GraphCompletionRetriever(BaseRetriever): - Any: A generated completion based on the query and context provided. 
""" - triplets = None + triplets = context - if context is None: - context, triplets = await self.get_context(query) + if triplets is None: + triplets = await self.get_context(query) + + context_text = await resolve_edges_to_text(triplets) completion = await generate_completion( query=query, - context=context, + context=context_text, user_prompt_path=self.user_prompt_path, system_prompt_path=self.system_prompt_path, system_prompt=self.system_prompt, - only_context=self.only_context, ) if self.save_interaction and context and triplets and completion: await self.save_qa( - question=query, answer=completion, context=context, triplets=triplets + question=query, answer=completion, context=context_text, triplets=triplets ) - return [completion] + return completion async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: """ diff --git a/cognee/modules/retrieval/insights_retriever.py b/cognee/modules/retrieval/insights_retriever.py index 49acbe6f3..43b77e951 100644 --- a/cognee/modules/retrieval/insights_retriever.py +++ b/cognee/modules/retrieval/insights_retriever.py @@ -1,17 +1,18 @@ import asyncio from typing import Any, Optional +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node +from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError logger = get_logger("InsightsRetriever") -class InsightsRetriever(BaseRetriever): +class InsightsRetriever(BaseGraphRetriever): """ Retriever for handling graph connection-based insights. 
@@ -95,7 +96,17 @@ class InsightsRetriever(BaseRetriever): unique_node_connections_map[unique_id] = True unique_node_connections.append(node_connection) - return unique_node_connections + return [ + Edge( + node1=Node(node_id=connection[0]["id"], attributes=connection[0]), + node2=Node(node_id=connection[2]["id"], attributes=connection[2]), + attributes={ + **connection[1], + "relationship_type": connection[1]["relationship_name"], + }, + ) + for connection in unique_node_connections + ] async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: """ diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index 0ab8e2ecf..09f2980dd 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -1,5 +1,5 @@ import os -from typing import Any, Optional, List, Tuple, Type +from typing import Any, Optional, List, Type from operator import itemgetter @@ -113,8 +113,8 @@ class TemporalRetriever(GraphCompletionRetriever): logger.info( "No timestamps identified based on the query, performing retrieval using triplet search on events and entities." ) - triplets = await self.get_triplets(query) - return await self.resolve_edges_to_text(triplets), triplets + triplets = await self.get_context(query) + return await self.resolve_edges_to_text(triplets) if ids: relevant_events = await graph_engine.collect_events(ids=ids) @@ -122,8 +122,8 @@ class TemporalRetriever(GraphCompletionRetriever): logger.info( "No events identified based on timestamp filtering, performing retrieval using triplet search on events and entities." 
) - triplets = await self.get_triplets(query) - return await self.resolve_edges_to_text(triplets), triplets + triplets = await self.get_context(query) + return await self.resolve_edges_to_text(triplets) vector_engine = get_vector_engine() query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0] @@ -134,18 +134,19 @@ class TemporalRetriever(GraphCompletionRetriever): top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results) - return self.descriptions_to_string(top_k_events), triplets + return self.descriptions_to_string(top_k_events) - async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]: + async def get_completion(self, query: str, context: Optional[str] = None) -> str: """Generates a response using the query and optional context.""" + if not context: + context = await self.get_context(query=query) - context, triplets = await self.get_context(query=query) + if context: + completion = await generate_completion( + query=query, + context=context, + user_prompt_path=self.user_prompt_path, + system_prompt_path=self.system_prompt_path, + ) - completion = await generate_completion( - query=query, - context=context, - user_prompt_path=self.user_prompt_path, - system_prompt_path=self.system_prompt_path, - ) - - return [completion] + return completion if context else "" diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index 44bb10dcb..9e57ddc67 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -8,7 +8,7 @@ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFound from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph -from cognee.modules.users.methods import
get_default_user +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.users.models import User from cognee.shared.utils import send_telemetry @@ -87,41 +87,15 @@ async def get_memory_fragment( async def brute_force_triplet_search( - query: str, - user: User = None, - top_k: int = 5, - collections: List[str] = None, - properties_to_project: List[str] = None, - memory_fragment: Optional[CogneeGraph] = None, - node_type: Optional[Type] = None, - node_name: Optional[List[str]] = None, -) -> list: - if user is None: - user = await get_default_user() - - retrieved_results = await brute_force_search( - query, - user, - top_k, - collections=collections, - properties_to_project=properties_to_project, - memory_fragment=memory_fragment, - node_type=node_type, - node_name=node_name, - ) - return retrieved_results - - -async def brute_force_search( query: str, user: User, - top_k: int, - collections: List[str] = None, - properties_to_project: List[str] = None, + top_k: int = 5, + collections: Optional[List[str]] = None, + properties_to_project: Optional[List[str]] = None, memory_fragment: Optional[CogneeGraph] = None, node_type: Optional[Type] = None, node_name: Optional[List[str]] = None, -) -> list: +) -> List[Edge]: """ Performs a brute force search to retrieve the top triplets from the graph. 
diff --git a/cognee/modules/retrieval/utils/completion.py b/cognee/modules/retrieval/utils/completion.py index 81e636aad..09c2dce5c 100644 --- a/cognee/modules/retrieval/utils/completion.py +++ b/cognee/modules/retrieval/utils/completion.py @@ -8,7 +8,6 @@ async def generate_completion( user_prompt_path: str, system_prompt_path: str, system_prompt: Optional[str] = None, - only_context: bool = False, ) -> str: """Generates a completion using LLM with given context and prompts.""" args = {"question": query, "context": context} @@ -17,14 +16,11 @@ async def generate_completion( system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path) ) - if only_context: - return context - else: - return await LLMGateway.acreate_structured_output( - text_input=user_prompt, - system_prompt=system_prompt, - response_model=str, - ) + return await LLMGateway.acreate_structured_output( + text_input=user_prompt, + system_prompt=system_prompt, + response_model=str, + ) async def summarize_text( diff --git a/cognee/modules/search/methods/get_search_type_tools.py b/cognee/modules/search/methods/get_search_type_tools.py new file mode 100644 index 000000000..e671a7db3 --- /dev/null +++ b/cognee/modules/search/methods/get_search_type_tools.py @@ -0,0 +1,168 @@ +from typing import Callable, List, Optional, Type + +from cognee.modules.engine.models.node_set import NodeSet +from cognee.modules.search.types import SearchType +from cognee.modules.search.operations import select_search_type +from cognee.modules.search.exceptions import UnsupportedSearchTypeError + +# Retrievers +from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback +from cognee.modules.retrieval.chunks_retriever import ChunksRetriever +from cognee.modules.retrieval.insights_retriever import InsightsRetriever +from cognee.modules.retrieval.summaries_retriever import SummariesRetriever +from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from 
cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever +from cognee.modules.retrieval.graph_summary_completion_retriever import ( + GraphSummaryCompletionRetriever, +) +from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) +from cognee.modules.retrieval.code_retriever import CodeRetriever +from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever +from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever + + +async def get_search_type_tools( + query_type: SearchType, + query_text: str, + system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, + top_k: int = 10, + node_type: Optional[Type] = NodeSet, + node_name: Optional[List[str]] = None, + save_interaction: bool = False, + last_k: Optional[int] = None, +) -> list: + search_tasks: dict[SearchType, List[Callable]] = { + SearchType.SUMMARIES: [ + SummariesRetriever(top_k=top_k).get_completion, + SummariesRetriever(top_k=top_k).get_context, + ], + SearchType.INSIGHTS: [ + InsightsRetriever(top_k=top_k).get_completion, + InsightsRetriever(top_k=top_k).get_context, + ], + SearchType.CHUNKS: [ + ChunksRetriever(top_k=top_k).get_completion, + ChunksRetriever(top_k=top_k).get_context, + ], + SearchType.RAG_COMPLETION: [ + CompletionRetriever( + system_prompt_path=system_prompt_path, + top_k=top_k, + system_prompt=system_prompt, + ).get_completion, + CompletionRetriever( + system_prompt_path=system_prompt_path, + top_k=top_k, + system_prompt=system_prompt, + ).get_context, + ], + SearchType.GRAPH_COMPLETION: [ + GraphCompletionRetriever( + 
system_prompt_path=system_prompt_path, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + system_prompt=system_prompt, + ).get_completion, + GraphCompletionRetriever( + system_prompt_path=system_prompt_path, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + system_prompt=system_prompt, + ).get_context, + ], + SearchType.GRAPH_COMPLETION_COT: [ + GraphCompletionCotRetriever( + system_prompt_path=system_prompt_path, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + system_prompt=system_prompt, + ).get_completion, + GraphCompletionCotRetriever( + system_prompt_path=system_prompt_path, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + system_prompt=system_prompt, + ).get_context, + ], + SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: [ + GraphCompletionContextExtensionRetriever( + system_prompt_path=system_prompt_path, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + system_prompt=system_prompt, + ).get_completion, + GraphCompletionContextExtensionRetriever( + system_prompt_path=system_prompt_path, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + system_prompt=system_prompt, + ).get_context, + ], + SearchType.GRAPH_SUMMARY_COMPLETION: [ + GraphSummaryCompletionRetriever( + system_prompt_path=system_prompt_path, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + system_prompt=system_prompt, + ).get_completion, + GraphSummaryCompletionRetriever( + system_prompt_path=system_prompt_path, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + system_prompt=system_prompt, + ).get_context, + ], + SearchType.CODE: [ + CodeRetriever(top_k=top_k).get_completion, + 
CodeRetriever(top_k=top_k).get_context, + ], + SearchType.CYPHER: [ + CypherSearchRetriever().get_completion, + CypherSearchRetriever().get_context, + ], + SearchType.NATURAL_LANGUAGE: [ + NaturalLanguageRetriever().get_completion, + NaturalLanguageRetriever().get_context, + ], + SearchType.FEEDBACK: [UserQAFeedback(last_k=last_k).add_feedback], + SearchType.TEMPORAL: [ + TemporalRetriever(top_k=top_k).get_completion, + TemporalRetriever(top_k=top_k).get_context, + ], + SearchType.CODING_RULES: [ + CodingRulesRetriever(rules_nodeset_name=node_name).get_existing_rules, + ], + } + + # If the query type is FEELING_LUCKY, select the search type intelligently + if query_type is SearchType.FEELING_LUCKY: + query_type = await select_search_type(query_text) + + search_type_tools = search_tasks.get(query_type) + + if not search_type_tools: + raise UnsupportedSearchTypeError(str(query_type)) + + return search_type_tools diff --git a/cognee/modules/search/methods/no_access_control_search.py b/cognee/modules/search/methods/no_access_control_search.py new file mode 100644 index 000000000..bb3eaba42 --- /dev/null +++ b/cognee/modules/search/methods/no_access_control_search.py @@ -0,0 +1,47 @@ +from typing import Any, List, Optional, Tuple, Type, Union + +from cognee.modules.data.models.Dataset import Dataset +from cognee.modules.engine.models.node_set import NodeSet +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.search.types import SearchType + +from .get_search_type_tools import get_search_type_tools + + +async def no_access_control_search( + query_type: SearchType, + query_text: str, + system_prompt_path: str = "answer_simple_question.txt", + system_prompt: Optional[str] = None, + top_k: int = 10, + node_type: Optional[Type] = NodeSet, + node_name: Optional[List[str]] = None, + save_interaction: bool = False, + last_k: Optional[int] = None, + only_context: bool = False, +) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]: + 
search_tools = await get_search_type_tools( + query_type=query_type, + query_text=query_text, + system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + last_k=last_k, + ) + if len(search_tools) == 2: + [get_completion, get_context] = search_tools + + if only_context: + return await get_context(query_text) + + context = await get_context(query_text) + result = await get_completion(query_text, context) + else: + unknown_tool = search_tools[0] + result = await unknown_tool(query_text) + context = "" + + return result, context, [] diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index cb2cc2d20..8e4d41509 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -3,37 +3,27 @@ import json import asyncio from uuid import UUID from fastapi.encoders import jsonable_encoder -from typing import Callable, List, Optional, Type, Union +from typing import Any, List, Optional, Tuple, Type, Union +from cognee.shared.utils import send_telemetry +from cognee.context_global_variables import set_database_global_context_variables from cognee.modules.engine.models.node_set import NodeSet -from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback -from cognee.modules.search.exceptions import UnsupportedSearchTypeError -from cognee.context_global_variables import set_database_global_context_variables -from cognee.modules.retrieval.chunks_retriever import ChunksRetriever -from cognee.modules.retrieval.insights_retriever import InsightsRetriever -from cognee.modules.retrieval.summaries_retriever import SummariesRetriever -from cognee.modules.retrieval.completion_retriever import CompletionRetriever -from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever -from cognee.modules.retrieval.temporal_retriever import TemporalRetriever -from 
cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever -from cognee.modules.retrieval.graph_summary_completion_retriever import ( - GraphSummaryCompletionRetriever, +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.search.types import ( + SearchResult, + CombinedSearchResult, + SearchResultDataset, + SearchType, ) -from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever -from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( - GraphCompletionContextExtensionRetriever, -) -from cognee.modules.retrieval.code_retriever import CodeRetriever -from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever -from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever -from cognee.modules.search.types import SearchType -from cognee.modules.storage.utils import JSONEncoder +from cognee.modules.search.operations import log_query, log_result from cognee.modules.users.models import User from cognee.modules.data.models import Dataset -from cognee.shared.utils import send_telemetry from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets -from cognee.modules.search.operations import log_query, log_result, select_search_type + +from .get_search_type_tools import get_search_type_tools +from .no_access_control_search import no_access_control_search +from ..utils.prepare_search_result import prepare_search_result async def search( @@ -46,10 +36,11 @@ async def search( top_k: int = 10, node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, - save_interaction: Optional[bool] = False, + save_interaction: bool = False, last_k: Optional[int] = None, only_context: bool = False, -): + use_combined_context: bool = False, +) -> Union[CombinedSearchResult, List[SearchResult]]: """ Args: @@ -65,9 +56,12 @@ async def search( Notes: Searching by dataset is only 
available in ENABLE_BACKEND_ACCESS_CONTROL mode """ + query = await log_query(query_text, query_type.value, user.id) + send_telemetry("cognee.search EXECUTION STARTED", user.id) + # Use search function filtered by permissions if access control is enabled if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": - return await authorized_search( + search_results = await authorized_search( query_type=query_type, query_text=query_text, user=user, @@ -80,119 +74,68 @@ async def search( save_interaction=save_interaction, last_k=last_k, only_context=only_context, + use_combined_context=use_combined_context, ) + else: + search_results = [ + await no_access_control_search( + query_type=query_type, + query_text=query_text, + system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + last_k=last_k, + only_context=only_context, + ) + ] - query = await log_query(query_text, query_type.value, user.id) - - search_results = await specific_search( - query_type=query_type, - query_text=query_text, - user=user, - system_prompt_path=system_prompt_path, - system_prompt=system_prompt, - top_k=top_k, - node_type=node_type, - node_name=node_name, - save_interaction=save_interaction, - last_k=last_k, - only_context=only_context, - ) + send_telemetry("cognee.search EXECUTION COMPLETED", user.id) await log_result( query.id, json.dumps( - search_results if len(search_results) > 1 else search_results[0], cls=JSONEncoder + jsonable_encoder( + await prepare_search_result(search_results) + if use_combined_context + else [ + await prepare_search_result(search_result) for search_result in search_results + ] + ) ), user.id, ) - return search_results + if use_combined_context: + prepared_search_results = await prepare_search_result(search_results) + result = prepared_search_results["result"] + graphs = prepared_search_results["graphs"] + context = 
prepared_search_results["context"] + datasets = prepared_search_results["datasets"] - -async def specific_search( - query_type: SearchType, - query_text: str, - user: User, - system_prompt_path: str = "answer_simple_question.txt", - system_prompt: Optional[str] = None, - top_k: int = 10, - node_type: Optional[Type] = NodeSet, - node_name: Optional[List[str]] = None, - save_interaction: Optional[bool] = False, - last_k: Optional[int] = None, - only_context: bool = None, -) -> list: - search_tasks: dict[SearchType, Callable] = { - SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion, - SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion, - SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion, - SearchType.RAG_COMPLETION: CompletionRetriever( - system_prompt_path=system_prompt_path, - top_k=top_k, - system_prompt=system_prompt, - only_context=only_context, - ).get_completion, - SearchType.GRAPH_COMPLETION: GraphCompletionRetriever( - system_prompt_path=system_prompt_path, - top_k=top_k, - node_type=node_type, - node_name=node_name, - save_interaction=save_interaction, - system_prompt=system_prompt, - only_context=only_context, - ).get_completion, - SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever( - system_prompt_path=system_prompt_path, - top_k=top_k, - node_type=node_type, - node_name=node_name, - save_interaction=save_interaction, - system_prompt=system_prompt, - only_context=only_context, - ).get_completion, - SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever( - system_prompt_path=system_prompt_path, - top_k=top_k, - node_type=node_type, - node_name=node_name, - save_interaction=save_interaction, - system_prompt=system_prompt, - only_context=only_context, - ).get_completion, - SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever( - system_prompt_path=system_prompt_path, - top_k=top_k, - node_type=node_type, - node_name=node_name, - 
save_interaction=save_interaction, - system_prompt=system_prompt, - ).get_completion, - SearchType.CODE: CodeRetriever(top_k=top_k).get_completion, - SearchType.CYPHER: CypherSearchRetriever().get_completion, - SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion, - SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback, - SearchType.TEMPORAL: TemporalRetriever(top_k=top_k).get_completion, - SearchType.CODING_RULES: CodingRulesRetriever( - rules_nodeset_name=node_name - ).get_existing_rules, - } - - # If the query type is FEELING_LUCKY, select the search type intelligently - if query_type is SearchType.FEELING_LUCKY: - query_type = await select_search_type(query_text) - - search_task = search_tasks.get(query_type) - - if search_task is None: - raise UnsupportedSearchTypeError(str(query_type)) - - send_telemetry("cognee.search EXECUTION STARTED", user.id) - - results = await search_task(query_text) - - send_telemetry("cognee.search EXECUTION COMPLETED", user.id) - - return results + return CombinedSearchResult( + result=result, + graphs=graphs, + context=context, + datasets=[ + SearchResultDataset( + id=dataset.id, + name=dataset.name, + ) + for dataset in datasets + ], + ) + else: + return [ + SearchResult( + search_result=result, + dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None, + dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None, + ) + for index, (result, _, datasets) in enumerate(search_results) + ] async def authorized_search( @@ -205,26 +148,85 @@ async def authorized_search( top_k: int = 10, node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, - save_interaction: Optional[bool] = False, + save_interaction: bool = False, last_k: Optional[int] = None, - only_context: bool = None, -) -> list: + only_context: bool = False, + use_combined_context: bool = False, +) -> Union[ + Tuple[Any, Union[List[Edge], str], List[Dataset]], + List[Tuple[Any, 
Union[List[Edge], str], List[Dataset]]], +]: """ Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. Not to be used outside of active access control mode. """ - - query = await log_query(query_text, query_type.value, user.id) - # Find datasets user has read access for (if datasets are provided only return them. Provided user has read access) search_datasets = await get_specific_user_permission_datasets(user.id, "read", dataset_ids) + if use_combined_context: + search_responses = await search_in_datasets_context( + search_datasets=search_datasets, + query_type=query_type, + query_text=query_text, + system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + last_k=last_k, + only_context=True, + ) + + context = {} + datasets: List[Dataset] = [] + + for _, search_context, response_datasets in search_responses: + for dataset in response_datasets: + context[str(dataset.id)] = search_context + + datasets.extend(response_datasets) + + specific_search_tools = await get_search_type_tools( + query_type=query_type, + query_text=query_text, + system_prompt_path=system_prompt_path, + system_prompt=system_prompt, + top_k=top_k, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, + last_k=last_k, + ) + search_tools = specific_search_tools + if len(search_tools) == 2: + [get_completion, _] = search_tools + else: + get_completion = search_tools[0] + + def prepare_combined_context( + context, + ) -> Union[List[Edge], str]: + combined_context = [] + + for dataset_context in context.values(): + combined_context += [dataset_context] if isinstance(dataset_context, str) else dataset_context + + if combined_context and isinstance(combined_context[0], str): + return "\n".join(combined_context) + + return combined_context + + combined_context = prepare_combined_context(context) + completion = await get_completion(query_text, combined_context) + + return completion, 
combined_context, datasets + # Searches all provided datasets and handles setting up of appropriate database context based on permissions - search_results = await specific_search_by_context( + search_results = await search_in_datasets_context( search_datasets=search_datasets, query_type=query_type, query_text=query_text, - user=user, system_prompt_path=system_prompt_path, system_prompt=system_prompt, top_k=top_k, @@ -235,51 +237,48 @@ async def authorized_search( only_context=only_context, ) - await log_result(query.id, json.dumps(jsonable_encoder(search_results)), user.id) - return search_results -async def specific_search_by_context( +async def search_in_datasets_context( search_datasets: list[Dataset], query_type: SearchType, query_text: str, - user: User, system_prompt_path: str = "answer_simple_question.txt", system_prompt: Optional[str] = None, top_k: int = 10, node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, - save_interaction: Optional[bool] = False, + save_interaction: bool = False, last_k: Optional[int] = None, - only_context: bool = None, -): + only_context: bool = False, + context: Optional[Any] = None, +) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]: """ Searches all provided datasets and handles setting up of appropriate database context based on permissions. Not to be used outside of active access control mode. 
""" - async def _search_by_context( + async def _search_in_dataset_context( dataset: Dataset, query_type: SearchType, query_text: str, - user: User, system_prompt_path: str = "answer_simple_question.txt", system_prompt: Optional[str] = None, top_k: int = 10, node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, - save_interaction: Optional[bool] = False, + save_interaction: bool = False, last_k: Optional[int] = None, - only_context: bool = None, - ): + only_context: bool = False, + context: Optional[Any] = None, + ) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]: # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) - result = await specific_search( + specific_search_tools = await get_search_type_tools( query_type=query_type, query_text=query_text, - user=user, system_prompt_path=system_prompt_path, system_prompt=system_prompt, top_k=top_k, @@ -287,57 +286,31 @@ async def specific_search_by_context( node_name=node_name, save_interaction=save_interaction, last_k=last_k, - only_context=only_context, ) + search_tools = specific_search_tools + if len(search_tools) == 2: + [get_completion, get_context] = search_tools - if isinstance(result, tuple): - search_results = result[0] - triplets = result[1] + if only_context: + return None, await get_context(query_text), [dataset] + + search_context = context or await get_context(query_text) + search_result = await get_completion(query_text, search_context) + + return search_result, search_context, [dataset] else: - search_results = result - triplets = [] + unknown_tool = search_tools[0] - return { - "search_result": search_results, - "graph": [ - { - "source": { - "id": triplet.node1.id, - "attributes": { - "name": triplet.node1.attributes["name"], - "type": triplet.node1.attributes["type"], - "description": triplet.node1.attributes["description"], - "vector_distance": 
triplet.node1.attributes["vector_distance"], - }, - }, - "destination": { - "id": triplet.node2.id, - "attributes": { - "name": triplet.node2.attributes["name"], - "type": triplet.node2.attributes["type"], - "description": triplet.node2.attributes["description"], - "vector_distance": triplet.node2.attributes["vector_distance"], - }, - }, - "attributes": { - "relationship_name": triplet.attributes["relationship_name"], - }, - } - for triplet in triplets - ], - "dataset_id": dataset.id, - "dataset_name": dataset.name, - } + return await unknown_tool(query_text), "", [dataset] # Search every dataset async based on query and appropriate database configuration tasks = [] for dataset in search_datasets: tasks.append( - _search_by_context( + _search_in_dataset_context( dataset=dataset, query_type=query_type, query_text=query_text, - user=user, system_prompt_path=system_prompt_path, system_prompt=system_prompt, top_k=top_k, @@ -346,6 +319,7 @@ async def specific_search_by_context( save_interaction=save_interaction, last_k=last_k, only_context=only_context, + context=context, ) ) diff --git a/cognee/modules/search/types/SearchResult.py b/cognee/modules/search/types/SearchResult.py new file mode 100644 index 000000000..8ea5d3990 --- /dev/null +++ b/cognee/modules/search/types/SearchResult.py @@ -0,0 +1,21 @@ +from uuid import UUID +from pydantic import BaseModel +from typing import Any, Dict, List, Optional + + +class SearchResultDataset(BaseModel): + id: UUID + name: str + + +class CombinedSearchResult(BaseModel): + result: Optional[Any] + context: Dict[str, Any] + graphs: Optional[Dict[str, Any]] = {} + datasets: Optional[List[SearchResultDataset]] = None + + +class SearchResult(BaseModel): + search_result: Any + dataset_id: Optional[UUID] + dataset_name: Optional[str] diff --git a/cognee/modules/search/types/__init__.py b/cognee/modules/search/types/__init__.py index 62b49fa74..06e267f95 100644 --- a/cognee/modules/search/types/__init__.py +++ 
b/cognee/modules/search/types/__init__.py @@ -1 +1,2 @@ from .SearchType import SearchType +from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult diff --git a/cognee/modules/search/utils/__init__.py b/cognee/modules/search/utils/__init__.py new file mode 100644 index 000000000..bad84a672 --- /dev/null +++ b/cognee/modules/search/utils/__init__.py @@ -0,0 +1,2 @@ +from .prepare_search_result import prepare_search_result +from .transform_context_to_graph import transform_context_to_graph diff --git a/cognee/modules/search/utils/prepare_search_result.py b/cognee/modules/search/utils/prepare_search_result.py new file mode 100644 index 000000000..bdcfa9928 --- /dev/null +++ b/cognee/modules/search/utils/prepare_search_result.py @@ -0,0 +1,41 @@ +from typing import List, cast + +from cognee.modules.graph.utils import resolve_edges_to_text +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.search.utils.transform_context_to_graph import transform_context_to_graph + + +async def prepare_search_result(search_result): + result, context, datasets = search_result + + graphs = None + result_graph = None + context_texts = {} + + if isinstance(context, List) and len(context) > 0 and isinstance(context[0], Edge): + result_graph = transform_context_to_graph(context) + + graphs = { + "*": result_graph, + } + context_texts = { + "*": await resolve_edges_to_text(context), + } + elif isinstance(context, str): + context_texts = { + "*": context, + } + elif isinstance(context, List) and len(context) > 0 and isinstance(context[0], str): + context_texts = { + "*": "\n".join(cast(List[str], context)), + } + + if isinstance(result, List) and len(result) > 0 and isinstance(result[0], Edge): + result_graph = transform_context_to_graph(result) + + return { + "result": result_graph or result, + "graphs": graphs, + "context": context_texts, + "datasets": datasets, + } diff --git 
a/cognee/modules/search/utils/transform_context_to_graph.py b/cognee/modules/search/utils/transform_context_to_graph.py new file mode 100644 index 000000000..0bc889575 --- /dev/null +++ b/cognee/modules/search/utils/transform_context_to_graph.py @@ -0,0 +1,38 @@ +from typing import List + +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +def transform_context_to_graph(context: List[Edge]): + nodes = {} + edges = {} + + for triplet in context: + nodes[triplet.node1.id] = { + "id": triplet.node1.id, + "label": triplet.node1.attributes["name"] + if "name" in triplet.node1.attributes + else triplet.node1.id, + "type": triplet.node1.attributes["type"], + "attributes": triplet.node1.attributes, + } + nodes[triplet.node2.id] = { + "id": triplet.node2.id, + "label": triplet.node2.attributes["name"] + if "name" in triplet.node2.attributes + else triplet.node2.id, + "type": triplet.node2.attributes["type"], + "attributes": triplet.node2.attributes, + } + edges[ + f"{triplet.node1.id}_{triplet.attributes['relationship_name']}_{triplet.node2.id}" + ] = { + "source": triplet.node1.id, + "target": triplet.node2.id, + "label": triplet.attributes["relationship_name"], + } + + return { + "nodes": list(nodes.values()), + "edges": list(edges.values()), + } diff --git a/cognee/tasks/codingagents/coding_rule_associations.py b/cognee/tasks/codingagents/coding_rule_associations.py index c809bc68f..807148275 100644 --- a/cognee/tasks/codingagents/coding_rule_associations.py +++ b/cognee/tasks/codingagents/coding_rule_associations.py @@ -31,7 +31,7 @@ class RuleSet(DataPoint): ) -async def get_existing_rules(rules_nodeset_name: str, return_list: bool = False) -> str: +async def get_existing_rules(rules_nodeset_name: str) -> List[str]: graph_engine = await get_graph_engine() nodes_data, _ = await graph_engine.get_nodeset_subgraph( node_type=NodeSet, node_name=[rules_nodeset_name] @@ -46,9 +46,6 @@ async def get_existing_rules(rules_nodeset_name: str, return_list: 
bool = False) and "text" in item[1] ] - if not return_list: - existing_rules = "\n".join(f"- {rule}" for rule in existing_rules) - return existing_rules @@ -103,6 +100,7 @@ async def add_rule_associations( graph_engine = await get_graph_engine() existing_rules = await get_existing_rules(rules_nodeset_name=rules_nodeset_name) + existing_rules = "\n".join(f"- {rule}" for rule in existing_rules) user_context = {"chat": data, "rules": existing_rules} diff --git a/cognee/tests/test_kuzu.py b/cognee/tests/test_kuzu.py index 16c7b9cf6..6afd4540a 100644 --- a/cognee/tests/test_kuzu.py +++ b/cognee/tests/test_kuzu.py @@ -94,21 +94,21 @@ async def main(): await cognee.cognify([dataset_name]) - context_nonempty, _ = await GraphCompletionRetriever( + context_nonempty = await GraphCompletionRetriever( node_type=NodeSet, node_name=["first"], ).get_context("What is in the context?") - context_empty, _ = await GraphCompletionRetriever( + context_empty = await GraphCompletionRetriever( node_type=NodeSet, node_name=["nonexistent"], ).get_context("What is in the context?") - assert isinstance(context_nonempty, str) and context_nonempty != "", ( + assert isinstance(context_nonempty, list) and context_nonempty != [], ( f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}" ) - assert context_empty == "", ( + assert context_empty == [], ( f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}" ) diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index d5ccbc19e..7f24e8418 100644 --- a/cognee/tests/test_neo4j.py +++ b/cognee/tests/test_neo4j.py @@ -98,21 +98,21 @@ async def main(): await cognee.cognify([dataset_name]) - context_nonempty, _ = await GraphCompletionRetriever( + context_nonempty = await GraphCompletionRetriever( node_type=NodeSet, node_name=["first"], ).get_context("What is in the context?") - context_empty, _ = await GraphCompletionRetriever( + context_empty = await 
GraphCompletionRetriever( node_type=NodeSet, node_name=["nonexistent"], ).get_context("What is in the context?") - assert isinstance(context_nonempty, str) and context_nonempty != "", ( + assert isinstance(context_nonempty, list) and context_nonempty != [], ( f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}" ) - assert context_empty == "", ( + assert context_empty == [], ( f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}" ) diff --git a/cognee/tests/test_permissions.py b/cognee/tests/test_permissions.py index 95f769263..cfa3aade2 100644 --- a/cognee/tests/test_permissions.py +++ b/cognee/tests/test_permissions.py @@ -79,7 +79,7 @@ async def main(): print("\n\nExtracted sentences are:\n") for result in search_results: print(f"{result}\n") - assert search_results[0]["dataset_name"] == "NLP", ( + assert search_results[0].dataset_name == "NLP", ( f"Dict must contain dataset name 'NLP': {search_results[0]}" ) @@ -93,7 +93,7 @@ async def main(): print("\n\nExtracted sentences are:\n") for result in search_results: print(f"{result}\n") - assert search_results[0]["dataset_name"] == "QUANTUM", ( + assert search_results[0].dataset_name == "QUANTUM", ( f"Dict must contain dataset name 'QUANTUM': {search_results[0]}" ) @@ -170,7 +170,7 @@ async def main(): for result in search_results: print(f"{result}\n") - assert search_results[0]["dataset_name"] == "QUANTUM", ( + assert search_results[0].dataset_name == "QUANTUM", ( f"Dict must contain dataset name 'QUANTUM': {search_results[0]}" ) diff --git a/cognee/tests/test_relational_db_migration.py b/cognee/tests/test_relational_db_migration.py index 8a9670a7c..49508144f 100644 --- a/cognee/tests/test_relational_db_migration.py +++ b/cognee/tests/test_relational_db_migration.py @@ -1,6 +1,6 @@ -import json import pathlib import os +from typing import List from cognee.infrastructure.databases.graph import get_graph_engine from 
cognee.infrastructure.databases.relational import ( get_migration_relational_engine, @@ -10,7 +10,7 @@ from cognee.infrastructure.databases.vector.pgvector import ( create_db_and_tables as create_pgvector_db_and_tables, ) from cognee.tasks.ingestion import migrate_relational_database -from cognee.modules.search.types import SearchType +from cognee.modules.search.types import SearchResult, SearchType import cognee @@ -45,13 +45,15 @@ async def relational_db_migration(): await migrate_relational_database(graph_engine, schema=schema) # 1. Search the graph - search_results = await cognee.search( + search_results: List[SearchResult] = await cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC" - ) + ) # type: ignore print("Search results:", search_results) # 2. Assert that the search results contain "AC/DC" - assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!" + assert any("AC/DC" in r.search_result for r in search_results), ( + "AC/DC not found in search results!" 
+ ) migration_db_provider = migration_engine.engine.dialect.name if migration_db_provider == "postgresql": diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index ce0d8d473..62b07f31a 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -1,11 +1,7 @@ -import os -import pathlib - -from dns.e164 import query - import cognee from cognee.infrastructure.databases.graph import get_graph_engine from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.modules.graph.utils import resolve_edges_to_text from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( GraphCompletionContextExtensionRetriever, @@ -14,11 +10,8 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet from cognee.modules.retrieval.graph_summary_completion_retriever import ( GraphSummaryCompletionRetriever, ) -from cognee.modules.search.operations import get_history -from cognee.modules.users.methods import get_default_user from cognee.shared.logging_utils import get_logger from cognee.modules.search.types import SearchType -from cognee.modules.engine.models import NodeSet from collections import Counter logger = get_logger() @@ -46,16 +39,16 @@ async def main(): await cognee.cognify([dataset_name]) - context_gk, _ = await GraphCompletionRetriever().get_context( + context_gk = await GraphCompletionRetriever().get_context( query="Next to which country is Germany located?" ) - context_gk_cot, _ = await GraphCompletionCotRetriever().get_context( + context_gk_cot = await GraphCompletionCotRetriever().get_context( query="Next to which country is Germany located?" ) - context_gk_ext, _ = await GraphCompletionContextExtensionRetriever().get_context( + context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context( query="Next to which country is Germany located?" 
) - context_gk_sum, _ = await GraphSummaryCompletionRetriever().get_context( + context_gk_sum = await GraphSummaryCompletionRetriever().get_context( query="Next to which country is Germany located?" ) @@ -65,9 +58,11 @@ async def main(): ("GraphCompletionContextExtensionRetriever", context_gk_ext), ("GraphSummaryCompletionRetriever", context_gk_sum), ]: - assert isinstance(context, str), f"{name}: Context should be a string" - assert context.strip(), f"{name}: Context should not be empty" - lower = context.lower() + assert isinstance(context, list), f"{name}: Context should be a list" + assert len(context) > 0, f"{name}: Context should not be empty" + + context_text = await resolve_edges_to_text(context) + lower = context_text.lower() assert "germany" in lower or "netherlands" in lower, ( f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}" ) @@ -143,20 +138,19 @@ async def main(): last_k=1, ) - for name, completion in [ + for name, search_results in [ ("GRAPH_COMPLETION", completion_gk), ("GRAPH_COMPLETION_COT", completion_cot), ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext), ("GRAPH_SUMMARY_COMPLETION", completion_sum), ]: - assert isinstance(completion, list), f"{name}: should return a list" - assert len(completion) == 1, f"{name}: expected single-element list, got {len(completion)}" - text = completion[0] - assert isinstance(text, str), f"{name}: element should be a string" - assert text.strip(), f"{name}: string should not be empty" - assert "netherlands" in text.lower(), ( - f"{name}: expected 'netherlands' in result, got: {text!r}" - ) + for search_result in search_results: + completion = search_result.search_result + assert isinstance(completion, str), f"{name}: should return a string" + assert completion.strip(), f"{name}: string should not be empty" + assert "netherlands" in completion.lower(), ( + f"{name}: expected 'netherlands' in result, got: {completion!r}" + ) graph_engine = await get_graph_engine() graph = await 
graph_engine.get_graph_data() diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 26ae2f883..02e3f73e2 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -6,6 +6,7 @@ from typing import Optional, Union import cognee from cognee.low_level import setup, DataPoint from cognee.tasks.storage import add_data_points +from cognee.modules.graph.utils import resolve_edges_to_text from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( GraphCompletionContextExtensionRetriever, @@ -51,17 +52,15 @@ class TestGraphCompletionWithContextExtensionRetriever: retriever = GraphCompletionContextExtensionRetriever() - context, _ = await retriever.get_context("Who works at Canva?") + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" answer = await retriever.get_completion("Who works at Canva?") - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) + assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" + assert answer.strip(), "Answer must contain only non-empty strings" @pytest.mark.asyncio async def test_graph_completion_extension_context_complex(self): @@ -129,7 +128,9 @@ class TestGraphCompletionWithContextExtensionRetriever: retriever = GraphCompletionContextExtensionRetriever(top_k=20) - context, _ 
= await retriever.get_context("Who works at Figma?") + context = await resolve_edges_to_text( + await retriever.get_context("Who works at Figma and drives Tesla?") + ) print(context) @@ -139,10 +140,8 @@ class TestGraphCompletionWithContextExtensionRetriever: answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) + assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" + assert answer.strip(), "Answer must contain only non-empty strings" @pytest.mark.asyncio async def test_get_graph_completion_extension_context_on_empty_graph(self): @@ -167,12 +166,10 @@ class TestGraphCompletionWithContextExtensionRetriever: await setup() - context, _ = await retriever.get_context("Who works at Figma?") - assert context == "", "Context should be empty on an empty graph" + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) + assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" + assert answer.strip(), "Answer must contain only non-empty strings" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index be25299aa..54fa12f41 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -5,6 +5,7 @@ from typing import Optional, Union import cognee from 
cognee.low_level import setup, DataPoint +from cognee.modules.graph.utils import resolve_edges_to_text from cognee.tasks.storage import add_data_points from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever @@ -47,17 +48,15 @@ class TestGraphCompletionCoTRetriever: retriever = GraphCompletionCotRetriever() - context, _ = await retriever.get_context("Who works at Canva?") + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" answer = await retriever.get_completion("Who works at Canva?") - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) + assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" + assert answer.strip(), "Answer must contain only non-empty strings" @pytest.mark.asyncio async def test_graph_completion_cot_context_complex(self): @@ -124,7 +123,7 @@ class TestGraphCompletionCoTRetriever: retriever = GraphCompletionCotRetriever(top_k=20) - context, _ = await retriever.get_context("Who works at Figma?") + context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) print(context) @@ -134,10 +133,8 @@ class TestGraphCompletionCoTRetriever: answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) + assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" + assert answer.strip(), "Answer must 
contain only non-empty strings" @pytest.mark.asyncio async def test_get_graph_completion_cot_context_on_empty_graph(self): @@ -162,12 +159,10 @@ class TestGraphCompletionCoTRetriever: await setup() - context, _ = await retriever.get_context("Who works at Figma?") - assert context == "", "Context should be empty on an empty graph" + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) + assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" + assert answer.strip(), "Answer must contain only non-empty strings" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index 50784f94a..18fd94114 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -5,6 +5,7 @@ from typing import Optional, Union import cognee from cognee.low_level import setup, DataPoint +from cognee.modules.graph.utils import resolve_edges_to_text from cognee.tasks.storage import add_data_points from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever @@ -67,7 +68,7 @@ class TestGraphCompletionRetriever: retriever = GraphCompletionRetriever() - context, _ = await retriever.get_context("Who works at Canva?") + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) # Ensure the top-level sections are present assert "Nodes:" in context, "Missing 'Nodes:' section in context" @@ -191,7 +192,7 
@@ class TestGraphCompletionRetriever: retriever = GraphCompletionRetriever(top_k=20) - context, _ = await retriever.get_context("Who works at Figma?") + context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) print(context) @@ -222,5 +223,5 @@ class TestGraphCompletionRetriever: await setup() - context, _ = await retriever.get_context("Who works at Figma?") - assert context == "", "Context should be empty on an empty graph" + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" diff --git a/cognee/tests/unit/modules/retrieval/insights_retriever_test.py b/cognee/tests/unit/modules/retrieval/insights_retriever_test.py index 21dbc98dd..a3d9da63a 100644 --- a/cognee/tests/unit/modules/retrieval/insights_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/insights_retriever_test.py @@ -82,7 +82,7 @@ class TestInsightsRetriever: context = await retriever.get_context("Mike") - assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski" + assert context[0].node1.attributes["name"] == "Mike Broski", "Failed to get Mike Broski" @pytest.mark.asyncio async def test_insights_context_complex(self): @@ -222,7 +222,9 @@ class TestInsightsRetriever: context = await retriever.get_context("Christina") - assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer" + assert context[0].node1.attributes["name"] == "Christina Mayer", ( + "Failed to get Christina Mayer" + ) @pytest.mark.asyncio async def test_insights_context_on_empty_graph(self): diff --git a/cognee/tests/unit/modules/search/search_methods_test.py b/cognee/tests/unit/modules/search/search_methods_test.py deleted file mode 100644 index 3a6bdc51e..000000000 --- a/cognee/tests/unit/modules/search/search_methods_test.py +++ /dev/null @@ -1,230 +0,0 @@ -import json -import uuid -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from 
cognee.modules.engine.models.node_set import NodeSet -from cognee.modules.search.exceptions import UnsupportedSearchTypeError -from cognee.modules.search.methods.search import search, specific_search -from cognee.modules.search.types import SearchType -from cognee.modules.users.models import User -import sys - -search_module = sys.modules.get("cognee.modules.search.methods.search") - - -@pytest.fixture -def mock_user(): - user = MagicMock(spec=User) - user.id = uuid.uuid4() - return user - - -@pytest.mark.asyncio -@patch.object(search_module, "log_query") -@patch.object(search_module, "log_result") -@patch.object(search_module, "specific_search") -async def test_search( - mock_specific_search, - mock_log_result, - mock_log_query, - mock_user, -): - # Setup - query_text = "test query" - query_type = SearchType.CHUNKS - datasets = ["dataset1", "dataset2"] - - # Mock the query logging - mock_query = MagicMock() - mock_query.id = uuid.uuid4() - mock_log_query.return_value = mock_query - - # Mock document IDs - doc_id1 = uuid.uuid4() - doc_id2 = uuid.uuid4() - - # Mock search results - search_results = [ - {"document_id": str(doc_id1), "content": "Result 1"}, - {"document_id": str(doc_id2), "content": "Result 2"}, - ] - mock_specific_search.return_value = search_results - - # Execute - await search(query_text, query_type, datasets, mock_user) - - # Verify - mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id) - mock_specific_search.assert_called_once_with( - query_type=query_type, - query_text=query_text, - user=mock_user, - system_prompt_path="answer_simple_question.txt", - system_prompt=None, - top_k=10, - node_type=NodeSet, - node_name=None, - save_interaction=False, - last_k=None, - only_context=False, - ) - - # Verify result logging - mock_log_result.assert_called_once() - # Check that the first argument is the query ID - assert mock_log_result.call_args[0][0] == mock_query.id - # The second argument should be the JSON string of the 
filtered results - # We can't directly compare the JSON strings due to potential ordering differences - # So we parse the JSON and compare the objects - logged_results = json.loads(mock_log_result.call_args[0][1]) - assert len(logged_results) == 2 - assert logged_results[0]["document_id"] == str(doc_id1) - assert logged_results[1]["document_id"] == str(doc_id2) - - -@pytest.mark.asyncio -@patch.object(search_module, "SummariesRetriever") -@patch.object(search_module, "send_telemetry") -async def test_specific_search_summaries(mock_send_telemetry, mock_summaries_retriever, mock_user): - # Setup - query = "test query" - query_type = SearchType.SUMMARIES - - # Mock the retriever - mock_retriever = MagicMock() - mock_retriever.get_completion = AsyncMock() - mock_retriever.get_completion.return_value = [{"content": "Summary result"}] - mock_summaries_retriever.return_value = mock_retriever - - # Execute - results = await specific_search(query_type, query, mock_user) - - # Verify - mock_summaries_retriever.assert_called_once() - mock_retriever.get_completion.assert_called_once_with(query) - mock_send_telemetry.assert_called() - assert len(results) == 1 - assert results[0]["content"] == "Summary result" - - -@pytest.mark.asyncio -@patch.object(search_module, "InsightsRetriever") -@patch.object(search_module, "send_telemetry") -async def test_specific_search_insights(mock_send_telemetry, mock_insights_retriever, mock_user): - # Setup - query = "test query" - query_type = SearchType.INSIGHTS - - # Mock the retriever - mock_retriever = MagicMock() - mock_retriever.get_completion = AsyncMock() - mock_retriever.get_completion.return_value = [{"content": "Insight result"}] - mock_insights_retriever.return_value = mock_retriever - - # Execute - results = await specific_search(query_type, query, mock_user) - - # Verify - mock_insights_retriever.assert_called_once() - mock_retriever.get_completion.assert_called_once_with(query) - mock_send_telemetry.assert_called() - assert 
len(results) == 1 - assert results[0]["content"] == "Insight result" - - -@pytest.mark.asyncio -@patch.object(search_module, "ChunksRetriever") -@patch.object(search_module, "send_telemetry") -async def test_specific_search_chunks(mock_send_telemetry, mock_chunks_retriever, mock_user): - # Setup - query = "test query" - query_type = SearchType.CHUNKS - - # Mock the retriever - mock_retriever = MagicMock() - mock_retriever.get_completion = AsyncMock() - mock_retriever.get_completion.return_value = [{"content": "Chunk result"}] - mock_chunks_retriever.return_value = mock_retriever - - # Execute - results = await specific_search(query_type, query, mock_user) - - # Verify - mock_chunks_retriever.assert_called_once() - mock_retriever.get_completion.assert_called_once_with(query) - mock_send_telemetry.assert_called() - assert len(results) == 1 - assert results[0]["content"] == "Chunk result" - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "selected_type, retriever_name, expected_content, top_k", - [ - (SearchType.RAG_COMPLETION, "CompletionRetriever", "RAG result from lucky search", 10), - (SearchType.CHUNKS, "ChunksRetriever", "Chunk result from lucky search", 5), - (SearchType.SUMMARIES, "SummariesRetriever", "Summary from lucky search", 15), - (SearchType.INSIGHTS, "InsightsRetriever", "Insight result from lucky search", 20), - ], -) -@patch.object(search_module, "select_search_type") -@patch.object(search_module, "send_telemetry") -async def test_specific_search_feeling_lucky( - mock_send_telemetry, - mock_select_search_type, - selected_type, - retriever_name, - expected_content, - top_k, - mock_user, -): - with patch.object(search_module, retriever_name) as mock_retriever_class: - # Setup - query = f"test query for {retriever_name}" - query_type = SearchType.FEELING_LUCKY - - # Mock the intelligent search type selection - mock_select_search_type.return_value = selected_type - - # Mock the retriever - mock_retriever_instance = MagicMock() - 
mock_retriever_instance.get_completion = AsyncMock( - return_value=[{"content": expected_content}] - ) - mock_retriever_class.return_value = mock_retriever_instance - - # Execute - results = await specific_search(query_type, query, mock_user, top_k=top_k) - - # Verify - mock_select_search_type.assert_called_once_with(query) - - if retriever_name == "CompletionRetriever": - mock_retriever_class.assert_called_once_with( - system_prompt_path="answer_simple_question.txt", - top_k=top_k, - system_prompt=None, - only_context=None, - ) - else: - mock_retriever_class.assert_called_once_with(top_k=top_k) - - mock_retriever_instance.get_completion.assert_called_once_with(query) - mock_send_telemetry.assert_called() - assert len(results) == 1 - assert results[0]["content"] == expected_content - - -@pytest.mark.asyncio -async def test_specific_search_invalid_type(mock_user): - # Setup - query = "test query" - query_type = "INVALID_TYPE" # Not a valid SearchType - - # Execute and verify - with pytest.raises(UnsupportedSearchTypeError) as excinfo: - await specific_search(query_type, query, mock_user) - - assert "Unsupported search type" in str(excinfo.value) diff --git a/examples/python/graphiti_example.py b/examples/python/graphiti_example.py index 471fd1bb7..facce4684 100644 --- a/examples/python/graphiti_example.py +++ b/examples/python/graphiti_example.py @@ -47,6 +47,7 @@ async def main(): query = "When was Kamala Harris in office?" triplets = await brute_force_triplet_search( query=query, + user=user, top_k=3, collections=["graphitinode_content", "graphitinode_name", "graphitinode_summary"], )