feat: implement combined context search (#1341)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Boris 2025-09-10 16:33:08 +02:00 committed by GitHub
parent ba33dca592
commit b1643414d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 706 additions and 642 deletions

View file

@ -27,6 +27,7 @@ class SearchPayloadDTO(InDTO):
node_name: Optional[list[str]] = Field(default=None, example=[])
top_k: Optional[int] = Field(default=10)
only_context: bool = Field(default=False)
use_combined_context: bool = Field(default=False)
def get_search_router() -> APIRouter:
@ -115,6 +116,7 @@ def get_search_router() -> APIRouter:
"node_name": payload.node_name,
"top_k": payload.top_k,
"only_context": payload.only_context,
"use_combined_context": payload.use_combined_context,
},
)
@ -131,6 +133,7 @@ def get_search_router() -> APIRouter:
node_name=payload.node_name,
top_k=payload.top_k,
only_context=payload.only_context,
use_combined_context=payload.use_combined_context,
)
return JSONResponse(content=results)

View file

@ -3,7 +3,7 @@ from typing import Union, Optional, List, Type
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.users.models import User
from cognee.modules.search.types import SearchType
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult
from cognee.modules.users.methods import get_default_user
from cognee.modules.search.methods import search as search_function
from cognee.modules.data.methods import get_authorized_existing_datasets
@ -13,7 +13,7 @@ from cognee.modules.data.exceptions import DatasetNotFoundError
async def search(
query_text: str,
query_type: SearchType = SearchType.GRAPH_COMPLETION,
user: User = None,
user: Optional[User] = None,
datasets: Optional[Union[list[str], str]] = None,
dataset_ids: Optional[Union[list[UUID], UUID]] = None,
system_prompt_path: str = "answer_simple_question.txt",
@ -24,7 +24,8 @@ async def search(
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
) -> list:
use_combined_context: bool = False,
) -> Union[List[SearchResult], CombinedSearchResult]:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -193,6 +194,7 @@ async def search(
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
use_combined_context=use_combined_context,
)
return filtered_search_results

View file

@ -180,7 +180,7 @@ class CogneeGraph(CogneeAbstractGraph):
logger.error(f"Error mapping vector distances to edges: {str(ex)}")
raise ex
async def calculate_top_triplet_importances(self, k: int) -> List:
async def calculate_top_triplet_importances(self, k: int) -> List[Edge]:
def score(edge):
n1 = edge.node1.attributes.get("vector_distance", 1)
n2 = edge.node2.attributes.get("vector_distance", 1)

View file

@ -1,4 +1,8 @@
async def resolve_edges_to_text(retrieved_edges: list) -> str:
from typing import List
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
"""
Converts retrieved graph edges into a human-readable string format.
@ -13,7 +17,7 @@ async def resolve_edges_to_text(retrieved_edges: list) -> str:
- str: A formatted string representation of the nodes and their connections.
"""
def _get_nodes(retrieved_edges: list) -> dict:
def _get_nodes(retrieved_edges: List[Edge]) -> dict:
def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
"""Concatenates the top N frequent words in text."""
@ -36,9 +40,9 @@ async def resolve_edges_to_text(retrieved_edges: list) -> str:
return separator.join(top_words)
"""Creates a title, by combining first words with most frequent words from the text."""
first_n_words = text.split()[:first_n_words]
top_n_words = _top_n_words(text, top_n=top_n_words)
return f"{' '.join(first_n_words)}... [{top_n_words}]"
first_words = text.split()[:first_n_words]
top_words = _top_n_words(text, top_n=first_n_words)
return f"{' '.join(first_words)}... [{top_words}]"
"""Creates a dictionary of nodes with their names and content."""
nodes = {}

View file

@ -0,0 +1,18 @@
from typing import List, Optional
from abc import ABC, abstractmethod
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
class BaseGraphRetriever(ABC):
    """Abstract contract shared by every graph-based retriever.

    Concrete subclasses fetch graph triplets (edges) that are relevant to a
    query, and turn those triplets into an answer.
    """

    @abstractmethod
    async def get_context(self, query: str) -> List[Edge]:
        """Return the graph triplets (edges) relevant to ``query``."""
        ...

    @abstractmethod
    async def get_completion(self, query: str, context: Optional[List[Edge]] = None) -> str:
        """Produce an answer for ``query``, optionally reusing pre-fetched triplets."""
        ...

View file

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Callable
from typing import Any, Optional
class BaseRetriever(ABC):

View file

@ -1,3 +1,6 @@
import asyncio
from functools import reduce
from typing import List, Optional
from cognee.shared.logging_utils import get_logger
from cognee.tasks.codingagents.coding_rule_associations import get_existing_rules
@ -7,16 +10,22 @@ logger = get_logger("CodingRulesRetriever")
class CodingRulesRetriever:
    """Retriever for handling coding rule based searches."""
def __init__(self, rules_nodeset_name="coding_agent_rules"):
def __init__(self, rules_nodeset_name: Optional[List[str]] = None):
if isinstance(rules_nodeset_name, list):
if not rules_nodeset_name:
# If there is no provided nodeset set to coding_agent_rules
rules_nodeset_name = ["coding_agent_rules"]
rules_nodeset_name = rules_nodeset_name[0]
self.rules_nodeset_name = rules_nodeset_name
"""Initialize retriever with search parameters."""
async def get_existing_rules(self, query_text):
return await get_existing_rules(
rules_nodeset_name=self.rules_nodeset_name, return_list=True
)
if self.rules_nodeset_name:
rules_list = await asyncio.gather(
*[
get_existing_rules(rules_nodeset_name=nodeset)
for nodeset in self.rules_nodeset_name
]
)
return reduce(lambda x, y: x + y, rules_list, [])

View file

@ -23,16 +23,14 @@ class CompletionRetriever(BaseRetriever):
self,
user_prompt_path: str = "context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: str = None,
system_prompt: Optional[str] = None,
top_k: Optional[int] = 1,
only_context: bool = False,
):
"""Initialize retriever with optional custom prompt paths."""
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 1
self.system_prompt = system_prompt
self.only_context = only_context
async def get_context(self, query: str) -> str:
"""
@ -69,7 +67,7 @@ class CompletionRetriever(BaseRetriever):
logger.error("DocumentChunk_text collection not found")
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
async def get_completion(self, query: str, context: Optional[Any] = None) -> str:
"""
Generates an LLM completion using the context.
@ -97,6 +95,5 @@ class CompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
only_context=self.only_context,
)
return [completion]
return completion

View file

@ -49,6 +49,7 @@ class TripletSearchContextProvider(BaseContextProvider):
tasks = [
brute_force_triplet_search(
query=f"{entity_text} {query}",
user=user,
top_k=self.top_k,
collections=self.collections,
properties_to_project=self.properties_to_project,

View file

@ -1,4 +1,5 @@
from typing import Any, Optional, List, Type
from typing import Optional, List, Type
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.utils.completion import generate_completion
@ -31,7 +32,6 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
only_context: bool = False,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -41,15 +41,14 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
)
async def get_completion(
self,
query: str,
context: Optional[Any] = None,
context: Optional[List[Edge]] = None,
context_extension_rounds=4,
) -> List[str]:
) -> str:
"""
Extends the context for a given query by retrieving related triplets and generating new
completions based on them.
@ -74,11 +73,12 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
- List[str]: A list containing the generated answer based on the query and the
extended context.
"""
triplets = []
triplets = context
if context is None:
triplets += await self.get_triplets(query)
context = await self.resolve_edges_to_text(triplets)
if triplets is None:
triplets = await self.get_context(query)
context_text = await self.resolve_edges_to_text(triplets)
round_idx = 1
@ -90,15 +90,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
)
completion = await generate_completion(
query=query,
context=context,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
triplets += await self.get_triplets(completion)
triplets += await self.get_context(completion)
triplets = list(set(triplets))
context = await self.resolve_edges_to_text(triplets)
context_text = await self.resolve_edges_to_text(triplets)
num_triplets = len(triplets)
@ -117,19 +117,15 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
completion = await generate_completion(
query=query,
context=context,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
only_context=self.only_context,
)
if self.save_interaction and context and triplets and completion:
if self.save_interaction and context_text and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context, triplets=triplets
question=query, answer=completion, context=context_text, triplets=triplets
)
if self.only_context:
return [context]
else:
return [completion]
return completion

View file

@ -1,4 +1,5 @@
from typing import Any, Optional, List, Tuple, Type
from typing import Optional, List, Type
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -32,18 +33,16 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
validation_system_prompt_path: str = "cot_validation_system_prompt.txt",
followup_system_prompt_path: str = "cot_followup_system_prompt.txt",
followup_user_prompt_path: str = "cot_followup_user_prompt.txt",
system_prompt: str = None,
system_prompt: Optional[str] = None,
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
only_context: bool = False,
):
super().__init__(
user_prompt_path=user_prompt_path,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
only_context=only_context,
top_k=top_k,
node_type=node_type,
node_name=node_name,
@ -57,9 +56,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
async def get_completion(
self,
query: str,
context: Optional[Any] = None,
context: Optional[List[Edge]] = None,
max_iter=4,
) -> List[str]:
) -> str:
"""
Generate completion responses based on a user query and contextual information.
@ -84,26 +83,29 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
"""
followup_question = ""
triplets = []
completion = [""]
completion = ""
for round_idx in range(max_iter + 1):
if round_idx == 0:
if context is None:
context = await self.get_context(query)
triplets = await self.get_context(query)
context_text = await self.resolve_edges_to_text(triplets)
else:
context_text = await self.resolve_edges_to_text(context)
else:
triplets += await self.get_triplets(followup_question)
context = await self.resolve_edges_to_text(list(set(triplets)))
triplets += await self.get_context(followup_question)
context_text = await self.resolve_edges_to_text(list(set(triplets)))
completion = await generate_completion(
query=query,
context=context,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter:
valid_args = {"query": query, "answer": completion, "context": context}
valid_args = {"query": query, "answer": completion, "context": context_text}
valid_user_prompt = LLMGateway.render_prompt(
filename=self.validation_user_prompt_path, context=valid_args
)
@ -133,10 +135,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context, triplets=triplets
question=query, answer=completion, context=context_text, triplets=triplets
)
if self.only_context:
return [context]
else:
return [completion]
return completion

View file

@ -1,16 +1,15 @@
from typing import Any, Optional, Type, List
from collections import Counter
from uuid import NAMESPACE_OID, uuid5
import string
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.users.methods import get_default_user
from cognee.tasks.storage import add_data_points
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
@ -20,7 +19,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
logger = get_logger("GraphCompletionRetriever")
class GraphCompletionRetriever(BaseRetriever):
class GraphCompletionRetriever(BaseGraphRetriever):
"""
Retriever for handling graph-based completion searches.
@ -37,19 +36,17 @@ class GraphCompletionRetriever(BaseRetriever):
self,
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: str = None,
system_prompt: Optional[str] = None,
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
only_context: bool = False,
):
"""Initialize retriever with prompt paths and search parameters."""
self.save_interaction = save_interaction
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.system_prompt = system_prompt
self.only_context = only_context
self.top_k = top_k if top_k is not None else 5
self.node_type = node_type
self.node_name = node_name
@ -70,7 +67,7 @@ class GraphCompletionRetriever(BaseRetriever):
"""
return await resolve_edges_to_text(retrieved_edges)
async def get_triplets(self, query: str) -> list:
async def get_triplets(self, query: str) -> List[Edge]:
"""
Retrieves relevant graph triplets based on a query string.
@ -85,7 +82,7 @@ class GraphCompletionRetriever(BaseRetriever):
- list: A list of found triplets that match the query.
"""
subclasses = get_all_subclasses(DataPoint)
vector_index_collections = []
vector_index_collections: List[str] = []
for subclass in subclasses:
if "metadata" in subclass.model_fields:
@ -96,8 +93,11 @@ class GraphCompletionRetriever(BaseRetriever):
for field_name in index_fields:
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
user = await get_default_user()
found_triplets = await brute_force_triplet_search(
query,
user=user,
top_k=self.top_k,
collections=vector_index_collections or None,
node_type=self.node_type,
@ -106,7 +106,7 @@ class GraphCompletionRetriever(BaseRetriever):
return found_triplets
async def get_context(self, query: str) -> tuple[str, list]:
async def get_context(self, query: str) -> List[Edge]:
"""
Retrieves and resolves graph triplets into context based on a query.
@ -125,17 +125,17 @@ class GraphCompletionRetriever(BaseRetriever):
if len(triplets) == 0:
logger.warning("Empty context was provided to the completion")
return "", triplets
return []
context = await self.resolve_edges_to_text(triplets)
# context = await self.resolve_edges_to_text(triplets)
return context, triplets
return triplets
async def get_completion(
self,
query: str,
context: Optional[Any] = None,
) -> List[str]:
context: Optional[List[Edge]] = None,
) -> Any:
"""
Generates a completion using graph connections context based on a query.
@ -151,26 +151,27 @@ class GraphCompletionRetriever(BaseRetriever):
- Any: A generated completion based on the query and context provided.
"""
triplets = None
triplets = context
if context is None:
context, triplets = await self.get_context(query)
if triplets is None:
triplets = await self.get_context(query)
context_text = await resolve_edges_to_text(triplets)
completion = await generate_completion(
query=query,
context=context,
context=context_text,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
system_prompt=self.system_prompt,
only_context=self.only_context,
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context, triplets=triplets
question=query, answer=completion, context=context_text, triplets=triplets
)
return [completion]
return completion
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
"""

View file

@ -1,17 +1,18 @@
import asyncio
from typing import Any, Optional
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
from cognee.modules.retrieval.base_graph_retriever import BaseGraphRetriever
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
logger = get_logger("InsightsRetriever")
class InsightsRetriever(BaseRetriever):
class InsightsRetriever(BaseGraphRetriever):
"""
Retriever for handling graph connection-based insights.
@ -95,7 +96,17 @@ class InsightsRetriever(BaseRetriever):
unique_node_connections_map[unique_id] = True
unique_node_connections.append(node_connection)
return unique_node_connections
return [
Edge(
node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
attributes={
**connection[1],
"relationship_type": connection[1]["relationship_name"],
},
)
for connection in unique_node_connections
]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""

View file

@ -1,5 +1,5 @@
import os
from typing import Any, Optional, List, Tuple, Type
from typing import Any, Optional, List, Type
from operator import itemgetter
@ -113,8 +113,8 @@ class TemporalRetriever(GraphCompletionRetriever):
logger.info(
"No timestamps identified based on the query, performing retrieval using triplet search on events and entities."
)
triplets = await self.get_triplets(query)
return await self.resolve_edges_to_text(triplets), triplets
triplets = await self.get_context(query)
return await self.resolve_edges_to_text(triplets)
if ids:
relevant_events = await graph_engine.collect_events(ids=ids)
@ -122,8 +122,8 @@ class TemporalRetriever(GraphCompletionRetriever):
logger.info(
"No events identified based on timestamp filtering, performing retrieval using triplet search on events and entities."
)
triplets = await self.get_triplets(query)
return await self.resolve_edges_to_text(triplets), triplets
triplets = await self.get_context(query)
return await self.resolve_edges_to_text(triplets)
vector_engine = get_vector_engine()
query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
@ -134,18 +134,19 @@ class TemporalRetriever(GraphCompletionRetriever):
top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
return self.descriptions_to_string(top_k_events), triplets
return self.descriptions_to_string(top_k_events)
async def get_completion(self, query: str, context: Optional[Any] = None) -> List[str]:
async def get_completion(self, query: str, context: Optional[str] = None) -> str:
"""Generates a response using the query and optional context."""
if not context:
context = await self.get_context(query=query)
context, triplets = await self.get_context(query=query)
if context:
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
return [completion]
return completion

View file

@ -8,7 +8,7 @@ from cognee.infrastructure.databases.vector.exceptions import CollectionNotFound
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.users.methods import get_default_user
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
@ -87,41 +87,15 @@ async def get_memory_fragment(
async def brute_force_triplet_search(
query: str,
user: User = None,
top_k: int = 5,
collections: List[str] = None,
properties_to_project: List[str] = None,
memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list:
if user is None:
user = await get_default_user()
retrieved_results = await brute_force_search(
query,
user,
top_k,
collections=collections,
properties_to_project=properties_to_project,
memory_fragment=memory_fragment,
node_type=node_type,
node_name=node_name,
)
return retrieved_results
async def brute_force_search(
query: str,
user: User,
top_k: int,
collections: List[str] = None,
properties_to_project: List[str] = None,
top_k: int = 5,
collections: Optional[List[str]] = None,
properties_to_project: Optional[List[str]] = None,
memory_fragment: Optional[CogneeGraph] = None,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
) -> list:
) -> List[Edge]:
"""
Performs a brute force search to retrieve the top triplets from the graph.

View file

@ -8,7 +8,6 @@ async def generate_completion(
user_prompt_path: str,
system_prompt_path: str,
system_prompt: Optional[str] = None,
only_context: bool = False,
) -> str:
"""Generates a completion using LLM with given context and prompts."""
args = {"question": query, "context": context}
@ -17,14 +16,11 @@ async def generate_completion(
system_prompt if system_prompt else LLMGateway.read_query_prompt(system_prompt_path)
)
if only_context:
return context
else:
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
)
return await LLMGateway.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
)
async def summarize_text(

View file

@ -0,0 +1,168 @@
from typing import Callable, List, Optional, Type
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.search.types import SearchType
from cognee.modules.search.operations import select_search_type
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
# Retrievers
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
async def get_search_type_tools(
    query_type: SearchType,
    query_text: str,
    system_prompt_path: str = "answer_simple_question.txt",
    system_prompt: Optional[str] = None,
    top_k: int = 10,
    node_type: Optional[Type] = NodeSet,
    node_name: Optional[List[str]] = None,
    save_interaction: bool = False,
    last_k: Optional[int] = None,
) -> list:
    """Resolve a ``SearchType`` to its callable search tools.

    Parameters mirror the retriever constructors; ``query_text`` is only used
    when ``query_type`` is ``FEELING_LUCKY`` (the concrete type is then chosen
    by ``select_search_type``).

    Returns:
        - list: For most search types, ``[get_completion, get_context]`` bound
          to ONE shared retriever instance; for single-tool types (FEEDBACK,
          CODING_RULES) a one-element list.

    Raises:
        UnsupportedSearchTypeError: If the (resolved) search type has no tools.
    """

    def _tools(retriever) -> List[Callable]:
        # Bind both callables to the same instance so completion and context
        # retrieval share any per-instance state.
        return [retriever.get_completion, retriever.get_context]

    def _graph_retriever(retriever_cls):
        # All graph-completion flavors take the same constructor arguments.
        return retriever_cls(
            system_prompt_path=system_prompt_path,
            top_k=top_k,
            node_type=node_type,
            node_name=node_name,
            save_interaction=save_interaction,
            system_prompt=system_prompt,
        )

    # Factories are lazy: only the selected search type's retriever is
    # instantiated, instead of eagerly constructing every retriever per call.
    search_tasks: dict[SearchType, Callable[[], List[Callable]]] = {
        SearchType.SUMMARIES: lambda: _tools(SummariesRetriever(top_k=top_k)),
        SearchType.INSIGHTS: lambda: _tools(InsightsRetriever(top_k=top_k)),
        SearchType.CHUNKS: lambda: _tools(ChunksRetriever(top_k=top_k)),
        SearchType.RAG_COMPLETION: lambda: _tools(
            CompletionRetriever(
                system_prompt_path=system_prompt_path,
                top_k=top_k,
                system_prompt=system_prompt,
            )
        ),
        SearchType.GRAPH_COMPLETION: lambda: _tools(
            _graph_retriever(GraphCompletionRetriever)
        ),
        SearchType.GRAPH_COMPLETION_COT: lambda: _tools(
            _graph_retriever(GraphCompletionCotRetriever)
        ),
        SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: lambda: _tools(
            _graph_retriever(GraphCompletionContextExtensionRetriever)
        ),
        SearchType.GRAPH_SUMMARY_COMPLETION: lambda: _tools(
            _graph_retriever(GraphSummaryCompletionRetriever)
        ),
        SearchType.CODE: lambda: _tools(CodeRetriever(top_k=top_k)),
        SearchType.CYPHER: lambda: _tools(CypherSearchRetriever()),
        SearchType.NATURAL_LANGUAGE: lambda: _tools(NaturalLanguageRetriever()),
        SearchType.FEEDBACK: lambda: [UserQAFeedback(last_k=last_k).add_feedback],
        SearchType.TEMPORAL: lambda: _tools(TemporalRetriever(top_k=top_k)),
        SearchType.CODING_RULES: lambda: [
            CodingRulesRetriever(rules_nodeset_name=node_name).get_existing_rules,
        ],
    }

    # If the query type is FEELING_LUCKY, select the search type intelligently.
    if query_type is SearchType.FEELING_LUCKY:
        query_type = await select_search_type(query_text)

    tools_factory = search_tasks.get(query_type)

    if tools_factory is None:
        raise UnsupportedSearchTypeError(str(query_type))

    return tools_factory()

View file

@ -0,0 +1,47 @@
from typing import Any, List, Optional, Tuple, Type, Union
from cognee.modules.data.models.Dataset import Dataset
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.search.types import SearchType
from .get_search_type_tools import get_search_type_tools
async def no_access_control_search(
    query_type: SearchType,
    query_text: str,
    system_prompt_path: str = "answer_simple_question.txt",
    system_prompt: Optional[str] = None,
    top_k: int = 10,
    node_type: Optional[Type] = NodeSet,
    node_name: Optional[List[str]] = None,
    save_interaction: bool = False,
    last_k: Optional[int] = None,
    only_context: bool = False,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
    """Run a search without dataset-level access control.

    Resolves the tools for ``query_type`` and executes them against
    ``query_text``. Returns ``(result, context, datasets)`` where ``datasets``
    is always empty since no permission filtering is applied.
    """
    tools = await get_search_type_tools(
        query_type=query_type,
        query_text=query_text,
        system_prompt_path=system_prompt_path,
        system_prompt=system_prompt,
        top_k=top_k,
        node_type=node_type,
        node_name=node_name,
        save_interaction=save_interaction,
        last_k=last_k,
    )

    if len(tools) != 2:
        # Single-tool search types (e.g. feedback, coding rules) only take the query.
        single_tool = tools[0]
        answer = await single_tool(query_text)
        return answer, "", []

    get_completion, get_context = tools

    if only_context:
        # NOTE(review): this returns the bare context, not the 3-tuple declared
        # in the return annotation — confirm callers handle this shape.
        return await get_context(query_text)

    retrieved_context = await get_context(query_text)
    answer = await get_completion(query_text, retrieved_context)
    return answer, retrieved_context, []

View file

@ -3,37 +3,27 @@ import json
import asyncio
from uuid import UUID
from fastapi.encoders import jsonable_encoder
from typing import Callable, List, Optional, Type, Union
from typing import Any, List, Optional, Tuple, Type, Union
from cognee.shared.utils import send_telemetry
from cognee.context_global_variables import set_database_global_context_variables
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
from cognee.context_global_variables import set_database_global_context_variables
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.temporal_retriever import TemporalRetriever
from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.search.types import (
SearchResult,
CombinedSearchResult,
SearchResultDataset,
SearchType,
)
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
from cognee.modules.search.types import SearchType
from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.search.operations import log_query, log_result
from cognee.modules.users.models import User
from cognee.modules.data.models import Dataset
from cognee.shared.utils import send_telemetry
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
from cognee.modules.search.operations import log_query, log_result, select_search_type
from .get_search_type_tools import get_search_type_tools
from .no_access_control_search import no_access_control_search
from ..utils.prepare_search_result import prepare_search_result
async def search(
@ -46,10 +36,11 @@ async def search(
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
):
use_combined_context: bool = False,
) -> Union[CombinedSearchResult, List[SearchResult]]:
"""
Args:
@ -65,9 +56,12 @@ async def search(
Notes:
Searching by dataset is only available in ENABLE_BACKEND_ACCESS_CONTROL mode
"""
query = await log_query(query_text, query_type.value, user.id)
send_telemetry("cognee.search EXECUTION STARTED", user.id)
# Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return await authorized_search(
search_results = await authorized_search(
query_type=query_type,
query_text=query_text,
user=user,
@ -80,119 +74,68 @@ async def search(
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
use_combined_context=use_combined_context,
)
else:
search_results = [
await no_access_control_search(
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
]
query = await log_query(query_text, query_type.value, user.id)
search_results = await specific_search(
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
await log_result(
query.id,
json.dumps(
search_results if len(search_results) > 1 else search_results[0], cls=JSONEncoder
jsonable_encoder(
await prepare_search_result(search_results)
if use_combined_context
else [
await prepare_search_result(search_result) for search_result in search_results
]
)
),
user.id,
)
return search_results
if use_combined_context:
prepared_search_results = await prepare_search_result(search_results)
result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]
async def specific_search(
query_type: SearchType,
query_text: str,
user: User,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
last_k: Optional[int] = None,
only_context: bool = None,
) -> list:
search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
SearchType.RAG_COMPLETION: CompletionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
system_prompt=system_prompt,
only_context=only_context,
).get_completion,
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion,
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion,
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
only_context=only_context,
).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
system_prompt=system_prompt,
).get_completion,
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion,
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback,
SearchType.TEMPORAL: TemporalRetriever(top_k=top_k).get_completion,
SearchType.CODING_RULES: CodingRulesRetriever(
rules_nodeset_name=node_name
).get_existing_rules,
}
# If the query type is FEELING_LUCKY, select the search type intelligently
if query_type is SearchType.FEELING_LUCKY:
query_type = await select_search_type(query_text)
search_task = search_tasks.get(query_type)
if search_task is None:
raise UnsupportedSearchTypeError(str(query_type))
send_telemetry("cognee.search EXECUTION STARTED", user.id)
results = await search_task(query_text)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
return results
return CombinedSearchResult(
result=result,
graphs=graphs,
context=context,
datasets=[
SearchResultDataset(
id=dataset.id,
name=dataset.name,
)
for dataset in datasets
],
)
else:
return [
SearchResult(
search_result=result,
dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
)
for index, (result, _, datasets) in enumerate(search_results)
]
async def authorized_search(
@ -205,26 +148,85 @@ async def authorized_search(
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = None,
) -> list:
only_context: bool = False,
use_combined_context: bool = False,
) -> Union[
Tuple[Any, Union[List[Edge], str], List[Dataset]],
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
]:
"""
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
Not to be used outside of active access control mode.
"""
query = await log_query(query_text, query_type.value, user.id)
# Find datasets user has read access for (if datasets are provided only return them. Provided user has read access)
search_datasets = await get_specific_user_permission_datasets(user.id, "read", dataset_ids)
if use_combined_context:
search_responses = await search_in_datasets_context(
search_datasets=search_datasets,
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=True,
)
context = {}
datasets: List[Dataset] = []
for _, search_context, datasets in search_responses:
for dataset in datasets:
context[str(dataset.id)] = search_context
datasets.extend(datasets)
specific_search_tools = await get_search_type_tools(
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
[get_completion, _] = search_tools
else:
get_completion = search_tools[0]
def prepare_combined_context(
context,
) -> Union[List[Edge], str]:
combined_context = []
for dataset_context in context.values():
combined_context += dataset_context
if combined_context and isinstance(combined_context[0], str):
return "\n".join(combined_context)
return combined_context
combined_context = prepare_combined_context(context)
completion = await get_completion(query_text, combined_context)
return completion, combined_context, datasets
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await specific_search_by_context(
search_results = await search_in_datasets_context(
search_datasets=search_datasets,
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
@ -235,51 +237,48 @@ async def authorized_search(
only_context=only_context,
)
await log_result(query.id, json.dumps(jsonable_encoder(search_results)), user.id)
return search_results
async def specific_search_by_context(
async def search_in_datasets_context(
search_datasets: list[Dataset],
query_type: SearchType,
query_text: str,
user: User,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = None,
):
only_context: bool = False,
context: Optional[Any] = None,
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
Not to be used outside of active access control mode.
"""
async def _search_by_context(
async def _search_in_dataset_context(
dataset: Dataset,
query_type: SearchType,
query_text: str,
user: User,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: Optional[bool] = False,
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = None,
):
only_context: bool = False,
context: Optional[Any] = None,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
result = await specific_search(
specific_search_tools = await get_search_type_tools(
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
@ -287,57 +286,31 @@ async def specific_search_by_context(
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
[get_completion, get_context] = search_tools
if isinstance(result, tuple):
search_results = result[0]
triplets = result[1]
if only_context:
return None, await get_context(query_text), [dataset]
search_context = context or await get_context(query_text)
search_result = await get_completion(query_text, search_context)
return search_result, search_context, [dataset]
else:
search_results = result
triplets = []
unknown_tool = search_tools[0]
return {
"search_result": search_results,
"graph": [
{
"source": {
"id": triplet.node1.id,
"attributes": {
"name": triplet.node1.attributes["name"],
"type": triplet.node1.attributes["type"],
"description": triplet.node1.attributes["description"],
"vector_distance": triplet.node1.attributes["vector_distance"],
},
},
"destination": {
"id": triplet.node2.id,
"attributes": {
"name": triplet.node2.attributes["name"],
"type": triplet.node2.attributes["type"],
"description": triplet.node2.attributes["description"],
"vector_distance": triplet.node2.attributes["vector_distance"],
},
},
"attributes": {
"relationship_name": triplet.attributes["relationship_name"],
},
}
for triplet in triplets
],
"dataset_id": dataset.id,
"dataset_name": dataset.name,
}
return await unknown_tool(query_text), "", [dataset]
# Search every dataset async based on query and appropriate database configuration
tasks = []
for dataset in search_datasets:
tasks.append(
_search_by_context(
_search_in_dataset_context(
dataset=dataset,
query_type=query_type,
query_text=query_text,
user=user,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
@ -346,6 +319,7 @@ async def specific_search_by_context(
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
context=context,
)
)

View file

@ -0,0 +1,21 @@
from uuid import UUID
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
class SearchResultDataset(BaseModel):
    """Minimal reference to a dataset that a search result came from."""

    # Unique identifier of the dataset.
    id: UUID
    # Human-readable dataset name.
    name: str
class CombinedSearchResult(BaseModel):
    """Search response for combined-context mode (use_combined_context=True),
    where context from multiple datasets is merged into one completion."""

    # The completion/answer; None when only context was requested.
    result: Optional[Any]
    # Context keyed by stringified dataset UUID, or "*" when the context is
    # not attributed to a specific dataset.
    context: Dict[str, Any]
    # Optional renderable graph(s), keyed like `context`.
    # NOTE(review): mutable {} default is safe under pydantic (defaults are
    # copied per instance), but `Optional` with a non-None default looks
    # inconsistent — confirm whether None is a meaningful value here.
    graphs: Optional[Dict[str, Any]] = {}
    # Datasets the combined result was produced from, if known.
    datasets: Optional[List[SearchResultDataset]] = None
class SearchResult(BaseModel):
    """Per-dataset search response (one entry per searched dataset)."""

    # The raw search result/completion for this dataset.
    search_result: Any
    # Originating dataset; both fields may be None when no dataset was
    # resolved (e.g. access control disabled).
    dataset_id: Optional[UUID]
    dataset_name: Optional[str]

View file

@ -1 +1,2 @@
from .SearchType import SearchType
from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult

View file

@ -0,0 +1,2 @@
from .prepare_search_result import prepare_search_result
from .transform_context_to_graph import transform_context_to_graph

View file

@ -0,0 +1,41 @@
from typing import List, cast
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.search.utils.transform_context_to_graph import transform_context_to_graph
async def prepare_search_result(search_result):
    """
    Normalize a raw (result, context, datasets) search tuple into a
    JSON-friendly dict.

    Args:
        search_result: Tuple of (result, context, datasets), where context is
            either a string, a list of strings, or a list of Edge triplets.

    Returns:
        Dict with keys:
            result: the completion, or a serialized graph when the result
                itself is a list of edges;
            graphs: {"*": graph} when the context is a list of edges, else None;
            context: text rendering of the context keyed by "*" (the "*" key
                means the context is not attributed to a specific dataset);
            datasets: passed through unchanged.
    """
    result, context, datasets = search_result

    graphs = None
    context_texts = {}

    if isinstance(context, list) and context and isinstance(context[0], Edge):
        # Graph context: expose both a renderable graph and a text rendering.
        graphs = {
            "*": transform_context_to_graph(context),
        }
        context_texts = {
            "*": await resolve_edges_to_text(context),
        }
    elif isinstance(context, str):
        context_texts = {
            "*": context,
        }
    elif isinstance(context, list) and context and isinstance(context[0], str):
        context_texts = {
            "*": "\n".join(cast(List[str], context)),
        }

    # Only replace the result with a graph when the result itself is a list
    # of edges. Previously a single shared variable let the *context* graph
    # shadow a non-graph result (e.g. a completion string), so the answer was
    # silently replaced by the context graph.
    if isinstance(result, list) and result and isinstance(result[0], Edge):
        result = transform_context_to_graph(result)

    return {
        "result": result,
        "graphs": graphs,
        "context": context_texts,
        "datasets": datasets,
    }

View file

@ -0,0 +1,38 @@
from typing import List
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
def transform_context_to_graph(context: List[Edge]):
    """
    Convert a list of graph edges (triplets) into a serializable
    {"nodes": [...], "edges": [...]} structure.

    Nodes are de-duplicated by node id and edges by the composite
    "source_relationship_target" key, so repeated triplets collapse into a
    single entry.

    Args:
        context: Edge triplets, each exposing node1, node2 and attributes.

    Returns:
        Dict with "nodes" and "edges" lists suitable for JSON serialization.
    """
    nodes = {}
    edges = {}

    for triplet in context:
        nodes[triplet.node1.id] = {
            "id": triplet.node1.id,
            "label": triplet.node1.attributes["name"]
            if "name" in triplet.node1.attributes
            else triplet.node1.id,
            "type": triplet.node1.attributes["type"],
            # Fixed: node1's entry previously carried node2's attributes.
            "attributes": triplet.node1.attributes,
        }
        nodes[triplet.node2.id] = {
            "id": triplet.node2.id,
            "label": triplet.node2.attributes["name"]
            if "name" in triplet.node2.attributes
            else triplet.node2.id,
            "type": triplet.node2.attributes["type"],
            "attributes": triplet.node2.attributes,
        }
        edges[
            f"{triplet.node1.id}_{triplet.attributes['relationship_name']}_{triplet.node2.id}"
        ] = {
            "source": triplet.node1.id,
            "target": triplet.node2.id,
            "label": triplet.attributes["relationship_name"],
        }

    return {
        "nodes": list(nodes.values()),
        "edges": list(edges.values()),
    }

View file

@ -31,7 +31,7 @@ class RuleSet(DataPoint):
)
async def get_existing_rules(rules_nodeset_name: str, return_list: bool = False) -> str:
async def get_existing_rules(rules_nodeset_name: str) -> List[str]:
graph_engine = await get_graph_engine()
nodes_data, _ = await graph_engine.get_nodeset_subgraph(
node_type=NodeSet, node_name=[rules_nodeset_name]
@ -46,9 +46,6 @@ async def get_existing_rules(rules_nodeset_name: str, return_list: bool = False)
and "text" in item[1]
]
if not return_list:
existing_rules = "\n".join(f"- {rule}" for rule in existing_rules)
return existing_rules
@ -103,6 +100,7 @@ async def add_rule_associations(
graph_engine = await get_graph_engine()
existing_rules = await get_existing_rules(rules_nodeset_name=rules_nodeset_name)
existing_rules = "\n".join(f"- {rule}" for rule in existing_rules)
user_context = {"chat": data, "rules": existing_rules}

View file

@ -94,21 +94,21 @@ async def main():
await cognee.cognify([dataset_name])
context_nonempty, _ = await GraphCompletionRetriever(
context_nonempty = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["first"],
).get_context("What is in the context?")
context_empty, _ = await GraphCompletionRetriever(
context_empty = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["nonexistent"],
).get_context("What is in the context?")
assert isinstance(context_nonempty, str) and context_nonempty != "", (
assert isinstance(context_nonempty, list) and context_nonempty != [], (
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
)
assert context_empty == "", (
assert context_empty == [], (
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
)

View file

@ -98,21 +98,21 @@ async def main():
await cognee.cognify([dataset_name])
context_nonempty, _ = await GraphCompletionRetriever(
context_nonempty = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["first"],
).get_context("What is in the context?")
context_empty, _ = await GraphCompletionRetriever(
context_empty = await GraphCompletionRetriever(
node_type=NodeSet,
node_name=["nonexistent"],
).get_context("What is in the context?")
assert isinstance(context_nonempty, str) and context_nonempty != "", (
assert isinstance(context_nonempty, list) and context_nonempty != [], (
f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
)
assert context_empty == "", (
assert context_empty == [], (
f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
)

View file

@ -79,7 +79,7 @@ async def main():
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
assert search_results[0]["dataset_name"] == "NLP", (
assert search_results[0].dataset_name == "NLP", (
f"Dict must contain dataset name 'NLP': {search_results[0]}"
)
@ -93,7 +93,7 @@ async def main():
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
assert search_results[0]["dataset_name"] == "QUANTUM", (
assert search_results[0].dataset_name == "QUANTUM", (
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
)
@ -170,7 +170,7 @@ async def main():
for result in search_results:
print(f"{result}\n")
assert search_results[0]["dataset_name"] == "QUANTUM", (
assert search_results[0].dataset_name == "QUANTUM", (
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
)

View file

@ -1,6 +1,6 @@
import json
import pathlib
import os
from typing import List
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.relational import (
get_migration_relational_engine,
@ -10,7 +10,7 @@ from cognee.infrastructure.databases.vector.pgvector import (
create_db_and_tables as create_pgvector_db_and_tables,
)
from cognee.tasks.ingestion import migrate_relational_database
from cognee.modules.search.types import SearchType
from cognee.modules.search.types import SearchResult, SearchType
import cognee
@ -45,13 +45,15 @@ async def relational_db_migration():
await migrate_relational_database(graph_engine, schema=schema)
# 1. Search the graph
search_results = await cognee.search(
search_results: List[SearchResult] = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC"
)
) # type: ignore
print("Search results:", search_results)
# 2. Assert that the search results contain "AC/DC"
assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!"
assert any("AC/DC" in r.search_result for r in search_results), (
"AC/DC not found in search results!"
)
migration_db_provider = migration_engine.engine.dialect.name
if migration_db_provider == "postgresql":

View file

@ -1,11 +1,7 @@
import os
import pathlib
from dns.e164 import query
import cognee
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
@ -14,11 +10,8 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.engine.models import NodeSet
from collections import Counter
logger = get_logger()
@ -46,16 +39,16 @@ async def main():
await cognee.cognify([dataset_name])
context_gk, _ = await GraphCompletionRetriever().get_context(
context_gk = await GraphCompletionRetriever().get_context(
query="Next to which country is Germany located?"
)
context_gk_cot, _ = await GraphCompletionCotRetriever().get_context(
context_gk_cot = await GraphCompletionCotRetriever().get_context(
query="Next to which country is Germany located?"
)
context_gk_ext, _ = await GraphCompletionContextExtensionRetriever().get_context(
context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(
query="Next to which country is Germany located?"
)
context_gk_sum, _ = await GraphSummaryCompletionRetriever().get_context(
context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
query="Next to which country is Germany located?"
)
@ -65,9 +58,11 @@ async def main():
("GraphCompletionContextExtensionRetriever", context_gk_ext),
("GraphSummaryCompletionRetriever", context_gk_sum),
]:
assert isinstance(context, str), f"{name}: Context should be a string"
assert context.strip(), f"{name}: Context should not be empty"
lower = context.lower()
assert isinstance(context, list), f"{name}: Context should be a list"
assert len(context) > 0, f"{name}: Context should not be empty"
context_text = await resolve_edges_to_text(context)
lower = context_text.lower()
assert "germany" in lower or "netherlands" in lower, (
f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
)
@ -143,20 +138,19 @@ async def main():
last_k=1,
)
for name, completion in [
for name, search_results in [
("GRAPH_COMPLETION", completion_gk),
("GRAPH_COMPLETION_COT", completion_cot),
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
("GRAPH_SUMMARY_COMPLETION", completion_sum),
]:
assert isinstance(completion, list), f"{name}: should return a list"
assert len(completion) == 1, f"{name}: expected single-element list, got {len(completion)}"
text = completion[0]
assert isinstance(text, str), f"{name}: element should be a string"
assert text.strip(), f"{name}: string should not be empty"
assert "netherlands" in text.lower(), (
f"{name}: expected 'netherlands' in result, got: {text!r}"
)
for search_result in search_results:
completion = search_result.search_result
assert isinstance(completion, str), f"{name}: should return a string"
assert completion.strip(), f"{name}: string should not be empty"
assert "netherlands" in completion.lower(), (
f"{name}: expected 'netherlands' in result, got: {completion!r}"
)
graph_engine = await get_graph_engine()
graph = await graph_engine.get_graph_data()

View file

@ -6,6 +6,7 @@ from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
@ -51,17 +52,15 @@ class TestGraphCompletionWithContextExtensionRetriever:
retriever = GraphCompletionContextExtensionRetriever()
context, _ = await retriever.get_context("Who works at Canva?")
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
answer = await retriever.get_completion("Who works at Canva?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
assert answer.strip(), "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_graph_completion_extension_context_complex(self):
@ -129,7 +128,9 @@ class TestGraphCompletionWithContextExtensionRetriever:
retriever = GraphCompletionContextExtensionRetriever(top_k=20)
context, _ = await retriever.get_context("Who works at Figma?")
context = await resolve_edges_to_text(
await retriever.get_context("Who works at Figma and drives Tesla?")
)
print(context)
@ -139,10 +140,8 @@ class TestGraphCompletionWithContextExtensionRetriever:
answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
assert answer.strip(), "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_get_graph_completion_extension_context_on_empty_graph(self):
@ -167,12 +166,10 @@ class TestGraphCompletionWithContextExtensionRetriever:
await setup()
context, _ = await retriever.get_context("Who works at Figma?")
assert context == "", "Context should be empty on an empty graph"
context = await retriever.get_context("Who works at Figma?")
assert context == [], "Context should be empty on an empty graph"
answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
assert answer.strip(), "Answer must contain only non-empty strings"

View file

@ -5,6 +5,7 @@ from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
@ -47,17 +48,15 @@ class TestGraphCompletionCoTRetriever:
retriever = GraphCompletionCotRetriever()
context, _ = await retriever.get_context("Who works at Canva?")
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski"
assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer"
answer = await retriever.get_completion("Who works at Canva?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
assert answer.strip(), "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_graph_completion_cot_context_complex(self):
@ -124,7 +123,7 @@ class TestGraphCompletionCoTRetriever:
retriever = GraphCompletionCotRetriever(top_k=20)
context, _ = await retriever.get_context("Who works at Figma?")
context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
print(context)
@ -134,10 +133,8 @@ class TestGraphCompletionCoTRetriever:
answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
assert answer.strip(), "Answer must contain only non-empty strings"
@pytest.mark.asyncio
async def test_get_graph_completion_cot_context_on_empty_graph(self):
@ -162,12 +159,10 @@ class TestGraphCompletionCoTRetriever:
await setup()
context, _ = await retriever.get_context("Who works at Figma?")
assert context == "", "Context should be empty on an empty graph"
context = await retriever.get_context("Who works at Figma?")
assert context == [], "Context should be empty on an empty graph"
answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
assert answer.strip(), "Answer must contain only non-empty strings"

View file

@ -5,6 +5,7 @@ from typing import Optional, Union
import cognee
from cognee.low_level import setup, DataPoint
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.tasks.storage import add_data_points
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -67,7 +68,7 @@ class TestGraphCompletionRetriever:
retriever = GraphCompletionRetriever()
context, _ = await retriever.get_context("Who works at Canva?")
context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?"))
# Ensure the top-level sections are present
assert "Nodes:" in context, "Missing 'Nodes:' section in context"
@ -191,7 +192,7 @@ class TestGraphCompletionRetriever:
retriever = GraphCompletionRetriever(top_k=20)
context, _ = await retriever.get_context("Who works at Figma?")
context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?"))
print(context)
@ -222,5 +223,5 @@ class TestGraphCompletionRetriever:
await setup()
context, _ = await retriever.get_context("Who works at Figma?")
assert context == "", "Context should be empty on an empty graph"
context = await retriever.get_context("Who works at Figma?")
assert context == [], "Context should be empty on an empty graph"

View file

@ -82,7 +82,7 @@ class TestInsightsRetriever:
context = await retriever.get_context("Mike")
assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski"
assert context[0].node1.attributes["name"] == "Mike Broski", "Failed to get Mike Broski"
@pytest.mark.asyncio
async def test_insights_context_complex(self):
@ -222,7 +222,9 @@ class TestInsightsRetriever:
context = await retriever.get_context("Christina")
assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer"
assert context[0].node1.attributes["name"] == "Christina Mayer", (
"Failed to get Christina Mayer"
)
@pytest.mark.asyncio
async def test_insights_context_on_empty_graph(self):

View file

@ -1,230 +0,0 @@
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
from cognee.modules.search.methods.search import search, specific_search
from cognee.modules.search.types import SearchType
from cognee.modules.users.models import User
import sys
search_module = sys.modules.get("cognee.modules.search.methods.search")
@pytest.fixture
def mock_user():
    """Provide a mocked User carrying a freshly generated UUID as its id."""
    fake_user = MagicMock(spec=User)
    fake_user.id = uuid.uuid4()
    return fake_user
@pytest.mark.asyncio
@patch.object(search_module, "log_query")
@patch.object(search_module, "log_result")
@patch.object(search_module, "specific_search")
async def test_search(
    # NOTE: @patch decorators apply bottom-up, so the bottom-most patch
    # ("specific_search") is the FIRST mock parameter here.
    mock_specific_search,
    mock_log_result,
    mock_log_query,
    mock_user,
):
    """End-to-end check of search(): it logs the query, delegates to
    specific_search with the expected defaults, and logs the results as JSON.
    """
    # Setup
    query_text = "test query"
    query_type = SearchType.CHUNKS
    datasets = ["dataset1", "dataset2"]
    # Mock the query logging
    mock_query = MagicMock()
    mock_query.id = uuid.uuid4()
    mock_log_query.return_value = mock_query
    # Mock document IDs
    doc_id1 = uuid.uuid4()
    doc_id2 = uuid.uuid4()
    # Mock search results
    search_results = [
        {"document_id": str(doc_id1), "content": "Result 1"},
        {"document_id": str(doc_id2), "content": "Result 2"},
    ]
    mock_specific_search.return_value = search_results
    # Execute
    # NOTE(review): positional args assume the signature
    # search(query_text, query_type, datasets, user) — confirm against the
    # current search() definition if this test is revived.
    await search(query_text, query_type, datasets, mock_user)
    # Verify
    mock_log_query.assert_called_once_with(query_text, query_type.value, mock_user.id)
    # Every keyword below is a default search() is expected to forward unchanged.
    mock_specific_search.assert_called_once_with(
        query_type=query_type,
        query_text=query_text,
        user=mock_user,
        system_prompt_path="answer_simple_question.txt",
        system_prompt=None,
        top_k=10,
        node_type=NodeSet,
        node_name=None,
        save_interaction=False,
        last_k=None,
        only_context=False,
    )
    # Verify result logging
    mock_log_result.assert_called_once()
    # Check that the first argument is the query ID
    assert mock_log_result.call_args[0][0] == mock_query.id
    # The second argument should be the JSON string of the filtered results
    # We can't directly compare the JSON strings due to potential ordering differences
    # So we parse the JSON and compare the objects
    logged_results = json.loads(mock_log_result.call_args[0][1])
    assert len(logged_results) == 2
    assert logged_results[0]["document_id"] == str(doc_id1)
    assert logged_results[1]["document_id"] == str(doc_id2)
@pytest.mark.asyncio
@patch.object(search_module, "SummariesRetriever")
@patch.object(search_module, "send_telemetry")
async def test_specific_search_summaries(mock_send_telemetry, mock_summaries_retriever, mock_user):
    """A SUMMARIES search delegates to SummariesRetriever and returns its payload."""
    # Arrange: retriever whose async get_completion yields one summary.
    fake_retriever = MagicMock()
    fake_retriever.get_completion = AsyncMock(return_value=[{"content": "Summary result"}])
    mock_summaries_retriever.return_value = fake_retriever

    # Act
    results = await specific_search(SearchType.SUMMARIES, "test query", mock_user)

    # Assert: retriever constructed, queried with the text, telemetry fired.
    mock_summaries_retriever.assert_called_once()
    fake_retriever.get_completion.assert_called_once_with("test query")
    mock_send_telemetry.assert_called()
    assert len(results) == 1
    assert results[0]["content"] == "Summary result"
@pytest.mark.asyncio
@patch.object(search_module, "InsightsRetriever")
@patch.object(search_module, "send_telemetry")
async def test_specific_search_insights(mock_send_telemetry, mock_insights_retriever, mock_user):
    """An INSIGHTS search delegates to InsightsRetriever and returns its payload."""
    # Arrange: retriever whose async get_completion yields one insight.
    fake_retriever = MagicMock()
    fake_retriever.get_completion = AsyncMock(return_value=[{"content": "Insight result"}])
    mock_insights_retriever.return_value = fake_retriever

    # Act
    results = await specific_search(SearchType.INSIGHTS, "test query", mock_user)

    # Assert: retriever constructed, queried with the text, telemetry fired.
    mock_insights_retriever.assert_called_once()
    fake_retriever.get_completion.assert_called_once_with("test query")
    mock_send_telemetry.assert_called()
    assert len(results) == 1
    assert results[0]["content"] == "Insight result"
@pytest.mark.asyncio
@patch.object(search_module, "ChunksRetriever")
@patch.object(search_module, "send_telemetry")
async def test_specific_search_chunks(mock_send_telemetry, mock_chunks_retriever, mock_user):
    """A CHUNKS search delegates to ChunksRetriever and returns its payload."""
    # Arrange: retriever whose async get_completion yields one chunk.
    fake_retriever = MagicMock()
    fake_retriever.get_completion = AsyncMock(return_value=[{"content": "Chunk result"}])
    mock_chunks_retriever.return_value = fake_retriever

    # Act
    results = await specific_search(SearchType.CHUNKS, "test query", mock_user)

    # Assert: retriever constructed, queried with the text, telemetry fired.
    mock_chunks_retriever.assert_called_once()
    fake_retriever.get_completion.assert_called_once_with("test query")
    mock_send_telemetry.assert_called()
    assert len(results) == 1
    assert results[0]["content"] == "Chunk result"
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "selected_type, retriever_name, expected_content, top_k",
    [
        (SearchType.RAG_COMPLETION, "CompletionRetriever", "RAG result from lucky search", 10),
        (SearchType.CHUNKS, "ChunksRetriever", "Chunk result from lucky search", 5),
        (SearchType.SUMMARIES, "SummariesRetriever", "Summary from lucky search", 15),
        (SearchType.INSIGHTS, "InsightsRetriever", "Insight result from lucky search", 20),
    ],
)
@patch.object(search_module, "select_search_type")
@patch.object(search_module, "send_telemetry")
async def test_specific_search_feeling_lucky(
    # @patch decorators apply bottom-up: "send_telemetry" (bottom) is the
    # first mock parameter, then "select_search_type"; the remaining names
    # come from parametrize and the mock_user fixture.
    mock_send_telemetry,
    mock_select_search_type,
    selected_type,
    retriever_name,
    expected_content,
    top_k,
    mock_user,
):
    """FEELING_LUCKY routes to whichever retriever select_search_type picks.

    The retriever class name is only known per-parametrization, so it is
    patched dynamically inside the test body rather than via a decorator.
    """
    with patch.object(search_module, retriever_name) as mock_retriever_class:
        # Setup
        query = f"test query for {retriever_name}"
        query_type = SearchType.FEELING_LUCKY
        # Mock the intelligent search type selection
        mock_select_search_type.return_value = selected_type
        # Mock the retriever
        mock_retriever_instance = MagicMock()
        mock_retriever_instance.get_completion = AsyncMock(
            return_value=[{"content": expected_content}]
        )
        mock_retriever_class.return_value = mock_retriever_instance
        # Execute
        results = await specific_search(query_type, query, mock_user, top_k=top_k)
        # Verify
        mock_select_search_type.assert_called_once_with(query)
        # CompletionRetriever is the one retriever constructed with prompt
        # kwargs; every other retriever takes only top_k.
        if retriever_name == "CompletionRetriever":
            mock_retriever_class.assert_called_once_with(
                system_prompt_path="answer_simple_question.txt",
                top_k=top_k,
                system_prompt=None,
                only_context=None,
            )
        else:
            mock_retriever_class.assert_called_once_with(top_k=top_k)
        mock_retriever_instance.get_completion.assert_called_once_with(query)
        mock_send_telemetry.assert_called()
        assert len(results) == 1
        assert results[0]["content"] == expected_content
@pytest.mark.asyncio
async def test_specific_search_invalid_type(mock_user):
    """Passing something that is not a SearchType raises UnsupportedSearchTypeError."""
    bogus_type = "INVALID_TYPE"  # deliberately not a SearchType member
    with pytest.raises(UnsupportedSearchTypeError) as err:
        await specific_search(bogus_type, "test query", mock_user)
    assert "Unsupported search type" in str(err.value)

View file

@ -47,6 +47,7 @@ async def main():
query = "When was Kamala Harris in office?"
triplets = await brute_force_triplet_search(
query=query,
user=user,
top_k=3,
collections=["graphitinode_content", "graphitinode_name", "graphitinode_summary"],
)