From 3372679f7bb40c01ffd9e337ead27fe9f8981d54 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Wed, 29 Oct 2025 20:12:14 +0530 Subject: [PATCH 001/176] feat: adding last_accessed_at field to the models and updating the retrievers to update the timestamp --- .../modules/chunking/models/DocumentChunk.py | 7 +++ cognee/modules/engine/models/Entity.py | 7 ++- cognee/modules/retrieval/chunks_retriever.py | 55 +++++++---------- .../modules/retrieval/summaries_retriever.py | 28 ++++----- .../retrieval/utils/access_tracking.py | 61 +++++++++++++++++++ cognee/tasks/summarization/models.py | 8 ++- 6 files changed, 115 insertions(+), 51 deletions(-) create mode 100644 cognee/modules/retrieval/utils/access_tracking.py diff --git a/cognee/modules/chunking/models/DocumentChunk.py b/cognee/modules/chunking/models/DocumentChunk.py index 9f8c57486..c4c6a2ed3 100644 --- a/cognee/modules/chunking/models/DocumentChunk.py +++ b/cognee/modules/chunking/models/DocumentChunk.py @@ -1,5 +1,7 @@ from typing import List, Union +from pydantic import BaseModel, Field +from datetime import datetime, timezone from cognee.infrastructure.engine import DataPoint from cognee.modules.data.processing.document_types import Document from cognee.modules.engine.models import Entity @@ -22,6 +24,7 @@ class DocumentChunk(DataPoint): - cut_type: The type of cut that defined this chunk. - is_part_of: The document to which this chunk belongs. - contains: A list of entities or events contained within the chunk (default is None). + - last_accessed_at: The timestamp of the last time the chunk was accessed. - metadata: A dictionary to hold meta information related to the chunk, including index fields. 
""" @@ -32,5 +35,9 @@ class DocumentChunk(DataPoint): cut_type: str is_part_of: Document contains: List[Union[Entity, Event]] = None + + last_accessed_at: int = Field( + default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) + ) metadata: dict = {"index_fields": ["text"]} diff --git a/cognee/modules/engine/models/Entity.py b/cognee/modules/engine/models/Entity.py index 36da2e344..3e48ea02a 100644 --- a/cognee/modules/engine/models/Entity.py +++ b/cognee/modules/engine/models/Entity.py @@ -1,11 +1,14 @@ from cognee.infrastructure.engine import DataPoint from cognee.modules.engine.models.EntityType import EntityType from typing import Optional - +from datetime import datetime, timezone +from pydantic import BaseModel, Field class Entity(DataPoint): name: str is_a: Optional[EntityType] = None description: str - + last_accessed_at: int = Field( + default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) + ) metadata: dict = {"index_fields": ["name"]} diff --git a/cognee/modules/retrieval/chunks_retriever.py b/cognee/modules/retrieval/chunks_retriever.py index 94b9d3fb9..74634b71e 100644 --- a/cognee/modules/retrieval/chunks_retriever.py +++ b/cognee/modules/retrieval/chunks_retriever.py @@ -1,10 +1,11 @@ from typing import Any, Optional - +from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError +from datetime import datetime, timezone logger = get_logger("ChunksRetriever") @@ -27,38 +28,26 @@ class ChunksRetriever(BaseRetriever): ): self.top_k = top_k - async def get_context(self, query: str) -> Any: - """ - Retrieves document chunks context based on 
the query. - - Searches for document chunks relevant to the specified query using a vector engine. - Raises a NoDataError if no data is found in the system. - - Parameters: - ----------- - - - query (str): The query string to search for relevant document chunks. - - Returns: - -------- - - - Any: A list of document chunk payloads retrieved from the search. - """ - logger.info( - f"Starting chunk retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'" - ) - - vector_engine = get_vector_engine() - - try: - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) - logger.info(f"Found {len(found_chunks)} chunks from vector search") - except CollectionNotFoundError as error: - logger.error("DocumentChunk_text collection not found in vector database") - raise NoDataError("No data found in the system, please add data first.") from error - - chunk_payloads = [result.payload for result in found_chunks] - logger.info(f"Returning {len(chunk_payloads)} chunk payloads") + async def get_context(self, query: str) -> Any: + """Retrieves document chunks context based on the query.""" + logger.info( + f"Starting chunk retrieval for query: '{query[:100]}{'...' 
if len(query) > 100 else ''}'" + ) + + vector_engine = get_vector_engine() + + try: + found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) + logger.info(f"Found {len(found_chunks)} chunks from vector search") + + # NEW: Update access timestamps + await update_node_access_timestamps(found_chunks, "DocumentChunk") + except CollectionNotFoundError as error: + logger.error("DocumentChunk_text collection not found in vector database") + raise NoDataError("No data found in the system, please add data first.") from error + + chunk_payloads = [result.payload for result in found_chunks] + logger.info(f"Returning {len(chunk_payloads)} chunk payloads") return chunk_payloads async def get_completion( diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py index 87b224946..7f996274e 100644 --- a/cognee/modules/retrieval/summaries_retriever.py +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -4,6 +4,7 @@ from cognee.shared.logging_utils import get_logger from cognee.infrastructure.databases.vector import get_vector_engine from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError logger = get_logger("SummariesRetriever") @@ -47,20 +48,19 @@ class SummariesRetriever(BaseRetriever): f"Starting summary retrieval for query: '{query[:100]}{'...' 
if len(query) > 100 else ''}'" ) - vector_engine = get_vector_engine() - - try: - summaries_results = await vector_engine.search( - "TextSummary_text", query, limit=self.top_k - ) - logger.info(f"Found {len(summaries_results)} summaries from vector search") - except CollectionNotFoundError as error: - logger.error("TextSummary_text collection not found in vector database") - raise NoDataError("No data found in the system, please add data first.") from error - - summary_payloads = [summary.payload for summary in summaries_results] - logger.info(f"Returning {len(summary_payloads)} summary payloads") - return summary_payloads + vector_engine = get_vector_engine() + + try: + summaries_results = await vector_engine.search( + "TextSummary_text", query, limit=self.top_k + ) + + await update_node_access_timestamps(summaries_results, "TextSummary") + + except CollectionNotFoundError as error: + raise NoDataError("No data found in the system, please add data first.") from error + + return [summary.payload for summary in summaries_results] async def get_completion( self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs diff --git a/cognee/modules/retrieval/utils/access_tracking.py b/cognee/modules/retrieval/utils/access_tracking.py new file mode 100644 index 000000000..ca5ed88cd --- /dev/null +++ b/cognee/modules/retrieval/utils/access_tracking.py @@ -0,0 +1,61 @@ + +"""Utilities for tracking data access in retrievers.""" + +import json +from datetime import datetime, timezone +from typing import List, Any + +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.shared.logging_utils import get_logger + +logger = get_logger(__name__) + + +async def update_node_access_timestamps(items: List[Any], node_type: str): + """ + Update last_accessed_at for nodes in Kuzu graph database. 
+ + Parameters + ---------- + items : List[Any] + List of items with payload containing 'id' field (from vector search results) + node_type : str + Type of node to update (e.g., 'DocumentChunk', 'Entity', 'TextSummary') + """ + if not items: + return + + graph_engine = await get_graph_engine() + # Convert to milliseconds since epoch (matching the field format) + timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000) + + for item in items: + # Extract ID from payload (vector search results have this structure) + item_id = item.payload.get("id") if hasattr(item, 'payload') else item.get("id") + if not item_id: + continue + + try: + # Get current node properties from Kuzu's Node table + result = await graph_engine.query( + "MATCH (n:Node {id: $id}) WHERE n.type = $node_type RETURN n.properties as props", + {"id": str(item_id), "node_type": node_type} + ) + + if result and len(result) > 0 and result[0][0]: + # Parse existing properties JSON + props = json.loads(result[0][0]) if result[0][0] else {} + # Update last_accessed_at with millisecond timestamp + props["last_accessed_at"] = timestamp_ms + + # Write back to graph database + await graph_engine.query( + "MATCH (n:Node {id: $id}) WHERE n.type = $node_type SET n.properties = $props", + {"id": str(item_id), "node_type": node_type, "props": json.dumps(props)} + ) + except Exception as e: + logger.warning(f"Failed to update timestamp for {node_type} {item_id}: {e}") + continue + + logger.debug(f"Updated access timestamps for {len(items)} {node_type} nodes") + diff --git a/cognee/tasks/summarization/models.py b/cognee/tasks/summarization/models.py index 75ed82d50..46f9a8d8b 100644 --- a/cognee/tasks/summarization/models.py +++ b/cognee/tasks/summarization/models.py @@ -1,5 +1,7 @@ -from typing import Union +from pydantic import BaseModel, Field +from typing import Union +from datetime import datetime, timezone from cognee.infrastructure.engine import DataPoint from cognee.modules.chunking.models import 
DocumentChunk from cognee.shared.CodeGraphEntities import CodeFile, CodePart @@ -17,7 +19,9 @@ class TextSummary(DataPoint): text: str made_from: DocumentChunk - + last_accessed_at: int = Field( + default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) + ) metadata: dict = {"index_fields": ["text"]} From 3f27c5592b58af29369125362510e96b72c56cbc Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Wed, 29 Oct 2025 20:17:27 +0530 Subject: [PATCH 002/176] feat: adding last_accessed_at field to the models and updating the retrievers to update the timestamp --- cognee/modules/retrieval/chunks_retriever.py | 48 +++++++++++-------- .../modules/retrieval/summaries_retriever.py | 28 ++++++----- 2 files changed, 44 insertions(+), 32 deletions(-) diff --git a/cognee/modules/retrieval/chunks_retriever.py b/cognee/modules/retrieval/chunks_retriever.py index 74634b71e..f821fc902 100644 --- a/cognee/modules/retrieval/chunks_retriever.py +++ b/cognee/modules/retrieval/chunks_retriever.py @@ -29,26 +29,34 @@ class ChunksRetriever(BaseRetriever): self.top_k = top_k async def get_context(self, query: str) -> Any: - """Retrieves document chunks context based on the query.""" - logger.info( - f"Starting chunk retrieval for query: '{query[:100]}{'...' 
if len(query) > 100 else ''}'" - ) - - vector_engine = get_vector_engine() - - try: - found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) - logger.info(f"Found {len(found_chunks)} chunks from vector search") - - # NEW: Update access timestamps - await update_node_access_timestamps(found_chunks, "DocumentChunk") - except CollectionNotFoundError as error: - logger.error("DocumentChunk_text collection not found in vector database") - raise NoDataError("No data found in the system, please add data first.") from error - - chunk_payloads = [result.payload for result in found_chunks] - logger.info(f"Returning {len(chunk_payloads)} chunk payloads") - return chunk_payloads + """ + Retrieves document chunks context based on the query. + Searches for document chunks relevant to the specified query using a vector engine. + Raises a NoDataError if no data is found in the system. + Parameters: + ----------- + - query (str): The query string to search for relevant document chunks. + Returns: + -------- + - Any: A list of document chunk payloads retrieved from the search. + """ + logger.info( + f"Starting chunk retrieval for query: '{query[:100]}{'...' 
if len(query) > 100 else ''}'" + ) + + vector_engine = get_vector_engine() + + try: + found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) + logger.info(f"Found {len(found_chunks)} chunks from vector search") + await update_node_access_timestamps(found_chunks, "DocumentChunk") + + except CollectionNotFoundError as error: + logger.error("DocumentChunk_text collection not found in vector database") + raise NoDataError("No data found in the system, please add data first.") from error + + chunk_payloads = [result.payload for result in found_chunks] + logger.info(f"Returning {len(chunk_payloads)} chunk payloads") async def get_completion( self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py index 7f996274e..9ac8b096d 100644 --- a/cognee/modules/retrieval/summaries_retriever.py +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -48,19 +48,23 @@ class SummariesRetriever(BaseRetriever): f"Starting summary retrieval for query: '{query[:100]}{'...' 
if len(query) > 100 else ''}'" ) - vector_engine = get_vector_engine() - - try: - summaries_results = await vector_engine.search( - "TextSummary_text", query, limit=self.top_k - ) - + vector_engine = get_vector_engine() + + try: + summaries_results = await vector_engine.search( + "TextSummary_text", query, limit=self.top_k + ) + logger.info(f"Found {len(summaries_results)} summaries from vector search") + await update_node_access_timestamps(summaries_results, "TextSummary") - - except CollectionNotFoundError as error: - raise NoDataError("No data found in the system, please add data first.") from error - - return [summary.payload for summary in summaries_results] + + except CollectionNotFoundError as error: + logger.error("TextSummary_text collection not found in vector database") + raise NoDataError("No data found in the system, please add data first.") from error + + summary_payloads = [summary.payload for summary in summaries_results] + logger.info(f"Returning {len(summary_payloads)} summary payloads") + return summary_payloads async def get_completion( self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None, **kwargs From 5f6f0502c832d129749b453121c6f5be565044bc Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Fri, 31 Oct 2025 00:00:18 +0530 Subject: [PATCH 003/176] fix: removing last_acessed_at from individual model and adding it to DataPoint --- cognee/infrastructure/engine/models/DataPoint.py | 3 +++ cognee/modules/chunking/models/DocumentChunk.py | 5 ----- cognee/modules/engine/models/Entity.py | 3 --- cognee/tasks/summarization/models.py | 3 --- 4 files changed, 3 insertions(+), 11 deletions(-) diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index 812380eaa..3178713c8 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -43,6 +43,9 @@ class DataPoint(BaseModel): updated_at: int = Field( 
default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) ) + last_accessed_at: int = Field( + default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) + ) ontology_valid: bool = False version: int = 1 # Default version topological_rank: Optional[int] = 0 diff --git a/cognee/modules/chunking/models/DocumentChunk.py b/cognee/modules/chunking/models/DocumentChunk.py index c4c6a2ed3..601454802 100644 --- a/cognee/modules/chunking/models/DocumentChunk.py +++ b/cognee/modules/chunking/models/DocumentChunk.py @@ -35,9 +35,4 @@ class DocumentChunk(DataPoint): cut_type: str is_part_of: Document contains: List[Union[Entity, Event]] = None - - last_accessed_at: int = Field( - default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) - ) - metadata: dict = {"index_fields": ["text"]} diff --git a/cognee/modules/engine/models/Entity.py b/cognee/modules/engine/models/Entity.py index 3e48ea02a..4083cd2e6 100644 --- a/cognee/modules/engine/models/Entity.py +++ b/cognee/modules/engine/models/Entity.py @@ -8,7 +8,4 @@ class Entity(DataPoint): name: str is_a: Optional[EntityType] = None description: str - last_accessed_at: int = Field( - default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) - ) metadata: dict = {"index_fields": ["name"]} diff --git a/cognee/tasks/summarization/models.py b/cognee/tasks/summarization/models.py index 46f9a8d8b..8cee2ade3 100644 --- a/cognee/tasks/summarization/models.py +++ b/cognee/tasks/summarization/models.py @@ -19,9 +19,6 @@ class TextSummary(DataPoint): text: str made_from: DocumentChunk - last_accessed_at: int = Field( - default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) - ) metadata: dict = {"index_fields": ["text"]} From 6f06e4a5eb1143ddcb2ad08132486630b8a2deae Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Fri, 31 Oct 2025 00:17:13 +0530 Subject: [PATCH 004/176] fix: removing node_type and try except --- cognee/modules/retrieval/chunks_retriever.py 
| 2 +- .../modules/retrieval/summaries_retriever.py | 2 +- .../retrieval/utils/access_tracking.py | 55 ++++++++++--------- 3 files changed, 31 insertions(+), 28 deletions(-) diff --git a/cognee/modules/retrieval/chunks_retriever.py b/cognee/modules/retrieval/chunks_retriever.py index f821fc902..be1f95811 100644 --- a/cognee/modules/retrieval/chunks_retriever.py +++ b/cognee/modules/retrieval/chunks_retriever.py @@ -49,7 +49,7 @@ class ChunksRetriever(BaseRetriever): try: found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) logger.info(f"Found {len(found_chunks)} chunks from vector search") - await update_node_access_timestamps(found_chunks, "DocumentChunk") + await update_node_access_timestamps(found_chunks) except CollectionNotFoundError as error: logger.error("DocumentChunk_text collection not found in vector database") diff --git a/cognee/modules/retrieval/summaries_retriever.py b/cognee/modules/retrieval/summaries_retriever.py index 9ac8b096d..0df750d22 100644 --- a/cognee/modules/retrieval/summaries_retriever.py +++ b/cognee/modules/retrieval/summaries_retriever.py @@ -56,7 +56,7 @@ class SummariesRetriever(BaseRetriever): ) logger.info(f"Found {len(summaries_results)} summaries from vector search") - await update_node_access_timestamps(summaries_results, "TextSummary") + await update_node_access_timestamps(summaries_results) except CollectionNotFoundError as error: logger.error("TextSummary_text collection not found in vector database") diff --git a/cognee/modules/retrieval/utils/access_tracking.py b/cognee/modules/retrieval/utils/access_tracking.py index ca5ed88cd..79afd25db 100644 --- a/cognee/modules/retrieval/utils/access_tracking.py +++ b/cognee/modules/retrieval/utils/access_tracking.py @@ -1,4 +1,4 @@ - + """Utilities for tracking data access in retrievers.""" import json @@ -11,51 +11,54 @@ from cognee.shared.logging_utils import get_logger logger = get_logger(__name__) -async def 
update_node_access_timestamps(items: List[Any], node_type: str): +async def update_node_access_timestamps(items: List[Any]): """ Update last_accessed_at for nodes in Kuzu graph database. + Automatically determines node type from the graph database. Parameters ---------- items : List[Any] List of items with payload containing 'id' field (from vector search results) - node_type : str - Type of node to update (e.g., 'DocumentChunk', 'Entity', 'TextSummary') """ if not items: return graph_engine = await get_graph_engine() - # Convert to milliseconds since epoch (matching the field format) timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000) for item in items: - # Extract ID from payload (vector search results have this structure) + # Extract ID from payload item_id = item.payload.get("id") if hasattr(item, 'payload') else item.get("id") if not item_id: continue - try: - # Get current node properties from Kuzu's Node table - result = await graph_engine.query( - "MATCH (n:Node {id: $id}) WHERE n.type = $node_type RETURN n.properties as props", - {"id": str(item_id), "node_type": node_type} + # try: + # Query to get both node type and properties in one call + result = await graph_engine.query( + "MATCH (n:Node {id: $id}) RETURN n.type as node_type, n.properties as props", + {"id": str(item_id)} + ) + + if result and len(result) > 0 and result[0]: + node_type = result[0][0] # First column: node_type + props_json = result[0][1] # Second column: properties + + # Parse existing properties JSON + props = json.loads(props_json) if props_json else {} + # Update last_accessed_at with millisecond timestamp + props["last_accessed_at"] = timestamp_ms + + # Write back to graph database + await graph_engine.query( + "MATCH (n:Node {id: $id}) SET n.properties = $props", + {"id": str(item_id), "props": json.dumps(props)} ) - if result and len(result) > 0 and result[0][0]: - # Parse existing properties JSON - props = json.loads(result[0][0]) if result[0][0] else {} - # 
Update last_accessed_at with millisecond timestamp - props["last_accessed_at"] = timestamp_ms + logger.debug(f"Updated access timestamp for {node_type} node {item_id}") - # Write back to graph database - await graph_engine.query( - "MATCH (n:Node {id: $id}) WHERE n.type = $node_type SET n.properties = $props", - {"id": str(item_id), "node_type": node_type, "props": json.dumps(props)} - ) - except Exception as e: - logger.warning(f"Failed to update timestamp for {node_type} {item_id}: {e}") - continue + # except Exception as e: + # logger.error(f"Failed to update timestamp for node {item_id}: {e}") + # continue - logger.debug(f"Updated access timestamps for {len(items)} {node_type} nodes") - + logger.debug(f"Updated access timestamps for {len(items)} nodes") From f1afd1f0a2a5433dc341c485b08ce33d1bc16252 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Fri, 31 Oct 2025 15:49:34 +0530 Subject: [PATCH 005/176] feat: adding cleanup function and adding update_node_acess_timestamps in completion retriever and graph_completion retriever --- .../modules/retrieval/completion_retriever.py | 3 +- .../retrieval/graph_completion_retriever.py | 13 +- cognee/tasks/cleanup/cleanup_unused_data.py | 232 ++++++++++++++++++ 3 files changed, 246 insertions(+), 2 deletions(-) create mode 100644 cognee/tasks/cleanup/cleanup_unused_data.py diff --git a/cognee/modules/retrieval/completion_retriever.py b/cognee/modules/retrieval/completion_retriever.py index bb568924d..fc8ef747f 100644 --- a/cognee/modules/retrieval/completion_retriever.py +++ b/cognee/modules/retrieval/completion_retriever.py @@ -8,6 +8,7 @@ from cognee.modules.retrieval.utils.session_cache import ( save_conversation_history, get_conversation_history, ) +from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.exceptions.exceptions import NoDataError from 
cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError @@ -65,7 +66,7 @@ class CompletionRetriever(BaseRetriever): if len(found_chunks) == 0: return "" - + await update_node_access_timestamps(found_chunks) # Combine all chunks text returned from vector search (number of chunks is determined by top_k chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks] combined_context = "\n".join(chunks_payload) diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index b7ab4edae..ac7e45e3c 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -16,6 +16,7 @@ from cognee.modules.retrieval.utils.session_cache import ( ) from cognee.shared.logging_utils import get_logger from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node +from cognee.modules.retrieval.utils.access_tracking import update_node_access_timestamps from cognee.modules.retrieval.utils.models import CogneeUserInteraction from cognee.modules.engine.models.node_set import NodeSet from cognee.infrastructure.databases.graph import get_graph_engine @@ -138,7 +139,17 @@ class GraphCompletionRetriever(BaseGraphRetriever): return [] # context = await self.resolve_edges_to_text(triplets) - + entity_nodes = [] + seen_ids = set() + for triplet in triplets: + if hasattr(triplet, 'node1') and triplet.node1 and triplet.node1.id not in seen_ids: + entity_nodes.append({"id": str(triplet.node1.id)}) + seen_ids.add(triplet.node1.id) + if hasattr(triplet, 'node2') and triplet.node2 and triplet.node2.id not in seen_ids: + entity_nodes.append({"id": str(triplet.node2.id)}) + seen_ids.add(triplet.node2.id) + + await update_node_access_timestamps(entity_nodes) return triplets async def get_completion( diff --git a/cognee/tasks/cleanup/cleanup_unused_data.py b/cognee/tasks/cleanup/cleanup_unused_data.py new file 
mode 100644 index 000000000..e97692bb4 --- /dev/null +++ b/cognee/tasks/cleanup/cleanup_unused_data.py @@ -0,0 +1,232 @@ +""" +Task for automatically deleting unused data from the memify pipeline. + +This task identifies and removes data (chunks, entities, summaries) that hasn't +been accessed by retrievers for a specified period, helping maintain system +efficiency and storage optimization. +""" + +import json +from datetime import datetime, timezone, timedelta +from typing import Optional, Dict, Any +from uuid import UUID + +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.shared.logging_utils import get_logger + +logger = get_logger(__name__) + + +async def cleanup_unused_data( + minutes_threshold: int = 30, + dry_run: bool = True, + user_id: Optional[UUID] = None +) -> Dict[str, Any]: + """ + Identify and remove unused data from the memify pipeline. + + Parameters + ---------- + minutes_threshold : int + Minutes since last access to consider data unused (default: 30) + dry_run : bool + If True, only report what would be deleted without actually deleting (default: True) + user_id : UUID, optional + Limit cleanup to specific user's data (default: None) + + Returns + ------- + Dict[str, Any] + Cleanup results with status, counts, and timestamp + """ + logger.info( + "Starting cleanup task", + minutes_threshold=minutes_threshold, + dry_run=dry_run, + user_id=str(user_id) if user_id else None + ) + + # Calculate cutoff timestamp in milliseconds + cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold) + cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000) + + logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)") + + # Find unused nodes + unused_nodes = await _find_unused_nodes(cutoff_timestamp_ms, user_id) + + total_unused = sum(len(nodes) for nodes in unused_nodes.values()) + logger.info(f"Found 
{total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()}) + + if dry_run: + return { + "status": "dry_run", + "unused_count": total_unused, + "deleted_count": { + "data_items": 0, + "chunks": 0, + "entities": 0, + "summaries": 0, + "associations": 0 + }, + "cleanup_date": datetime.now(timezone.utc).isoformat(), + "preview": { + "chunks": len(unused_nodes["DocumentChunk"]), + "entities": len(unused_nodes["Entity"]), + "summaries": len(unused_nodes["TextSummary"]) + } + } + + # Delete unused nodes + deleted_counts = await _delete_unused_nodes(unused_nodes) + + logger.info("Cleanup completed", deleted_counts=deleted_counts) + + return { + "status": "completed", + "unused_count": total_unused, + "deleted_count": { + "data_items": 0, + "chunks": deleted_counts["DocumentChunk"], + "entities": deleted_counts["Entity"], + "summaries": deleted_counts["TextSummary"], + "associations": deleted_counts["associations"] + }, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } + + +async def _find_unused_nodes( + cutoff_timestamp_ms: int, + user_id: Optional[UUID] = None +) -> Dict[str, list]: + """ + Query Kuzu for nodes with old last_accessed_at timestamps. 
+ + Parameters + ---------- + cutoff_timestamp_ms : int + Cutoff timestamp in milliseconds since epoch + user_id : UUID, optional + Filter by user ID if provided + + Returns + ------- + Dict[str, list] + Dictionary mapping node types to lists of unused node IDs + """ + graph_engine = await get_graph_engine() + + # Query all nodes with their properties + query = "MATCH (n:Node) RETURN n.id, n.type, n.properties" + results = await graph_engine.query(query) + + unused_nodes = { + "DocumentChunk": [], + "Entity": [], + "TextSummary": [] + } + + for node_id, node_type, props_json in results: + # Only process tracked node types + if node_type not in unused_nodes: + continue + + # Parse properties JSON + if props_json: + try: + props = json.loads(props_json) + last_accessed = props.get("last_accessed_at") + + # Check if node is unused (never accessed or accessed before cutoff) + if last_accessed is None or last_accessed < cutoff_timestamp_ms: + # TODO: Add user_id filtering when user ownership is implemented + unused_nodes[node_type].append(node_id) + logger.debug( + f"Found unused {node_type}", + node_id=node_id, + last_accessed=last_accessed + ) + except json.JSONDecodeError: + logger.warning(f"Failed to parse properties for node {node_id}") + continue + + return unused_nodes + + +async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]: + """ + Delete unused nodes from graph and vector databases. 
+ + Parameters + ---------- + unused_nodes : Dict[str, list] + Dictionary mapping node types to lists of node IDs to delete + + Returns + ------- + Dict[str, int] + Count of deleted items by type + """ + graph_engine = await get_graph_engine() + vector_engine = get_vector_engine() + + deleted_counts = { + "DocumentChunk": 0, + "Entity": 0, + "TextSummary": 0, + "associations": 0 + } + + # Count associations before deletion + for node_type, node_ids in unused_nodes.items(): + if not node_ids: + continue + + # Count edges connected to these nodes + for node_id in node_ids: + result = await graph_engine.query( + "MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)", + {"id": node_id} + ) + if result and len(result) > 0: + deleted_counts["associations"] += result[0][0] + + # Delete from graph database (uses DETACH DELETE, so edges are automatically removed) + for node_type, node_ids in unused_nodes.items(): + if not node_ids: + continue + + logger.info(f"Deleting {len(node_ids)} {node_type} nodes from graph database") + + # Delete nodes in batches + await graph_engine.delete_nodes(node_ids) + deleted_counts[node_type] = len(node_ids) + + # Delete from vector database + vector_collections = { + "DocumentChunk": "DocumentChunk_text", + "Entity": "Entity_name", + "TextSummary": "TextSummary_text" + } + + for node_type, collection_name in vector_collections.items(): + node_ids = unused_nodes[node_type] + if not node_ids: + continue + + logger.info(f"Deleting {len(node_ids)} {node_type} embeddings from vector database") + + try: + # Delete from vector collection + if await vector_engine.has_collection(collection_name): + for node_id in node_ids: + try: + await vector_engine.delete(collection_name, {"id": str(node_id)}) + except Exception as e: + logger.warning(f"Failed to delete {node_id} from {collection_name}: {e}") + except Exception as e: + logger.error(f"Error deleting from vector collection {collection_name}: {e}") + + return deleted_counts From 
5080e8f8a5c20d092b917b66eb52a577fe899231 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Mon, 3 Nov 2025 00:59:04 +0530 Subject: [PATCH 006/176] feat: genarlizing getting entities from triplets --- cognee/modules/graph/utils/__init__.py | 1 + .../graph/utils/get_entity_nodes_from_triplets.py | 13 +++++++++++++ .../modules/retrieval/graph_completion_retriever.py | 12 +++--------- 3 files changed, 17 insertions(+), 9 deletions(-) create mode 100644 cognee/modules/graph/utils/get_entity_nodes_from_triplets.py diff --git a/cognee/modules/graph/utils/__init__.py b/cognee/modules/graph/utils/__init__.py index ebc648495..4c0b29d47 100644 --- a/cognee/modules/graph/utils/__init__.py +++ b/cognee/modules/graph/utils/__init__.py @@ -5,3 +5,4 @@ from .retrieve_existing_edges import retrieve_existing_edges from .convert_node_to_data_point import convert_node_to_data_point from .deduplicate_nodes_and_edges import deduplicate_nodes_and_edges from .resolve_edges_to_text import resolve_edges_to_text +from .get_entity_nodes_from_triplets import get_entity_nodes_from_triplets diff --git a/cognee/modules/graph/utils/get_entity_nodes_from_triplets.py b/cognee/modules/graph/utils/get_entity_nodes_from_triplets.py new file mode 100644 index 000000000..598a36854 --- /dev/null +++ b/cognee/modules/graph/utils/get_entity_nodes_from_triplets.py @@ -0,0 +1,13 @@ + +def get_entity_nodes_from_triplets(triplets): + entity_nodes = [] + seen_ids = set() + for triplet in triplets: + if hasattr(triplet, 'node1') and triplet.node1 and triplet.node1.id not in seen_ids: + entity_nodes.append({"id": str(triplet.node1.id)}) + seen_ids.add(triplet.node1.id) + if hasattr(triplet, 'node2') and triplet.node2 and triplet.node2.id not in seen_ids: + entity_nodes.append({"id": str(triplet.node2.id)}) + seen_ids.add(triplet.node2.id) + + return entity_nodes diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index ac7e45e3c..122cc943f 
100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -22,6 +22,7 @@ from cognee.modules.engine.models.node_set import NodeSet from cognee.infrastructure.databases.graph import get_graph_engine from cognee.context_global_variables import session_user from cognee.infrastructure.databases.cache.config import CacheConfig +from cognee.modules.graph.utils import get_entity_nodes_from_triplets logger = get_logger("GraphCompletionRetriever") @@ -139,15 +140,8 @@ class GraphCompletionRetriever(BaseGraphRetriever): return [] # context = await self.resolve_edges_to_text(triplets) - entity_nodes = [] - seen_ids = set() - for triplet in triplets: - if hasattr(triplet, 'node1') and triplet.node1 and triplet.node1.id not in seen_ids: - entity_nodes.append({"id": str(triplet.node1.id)}) - seen_ids.add(triplet.node1.id) - if hasattr(triplet, 'node2') and triplet.node2 and triplet.node2.id not in seen_ids: - entity_nodes.append({"id": str(triplet.node2.id)}) - seen_ids.add(triplet.node2.id) + + entity_nodes = get_entity_nodes_from_triplets(triplets) await update_node_access_timestamps(entity_nodes) return triplets From 90d10e6f9af50c85fbbf282dd961719d5da7f922 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Mon, 3 Nov 2025 15:31:09 +0100 Subject: [PATCH 007/176] test: Add docs tests. Initial commit, still WIP. 
--- .github/workflows/docs_tests.yml | 18 ++++++ .../tests/docs/guides/custom_data_models.py | 38 +++++++++++++ cognee/tests/docs/guides/custom_prompts.py | 30 ++++++++++ .../docs/guides/custom_tasks_and_pipelines.py | 53 +++++++++++++++++ .../tests/docs/guides/graph_visualization.py | 13 +++++ cognee/tests/docs/guides/low_level_llm.py | 31 ++++++++++ cognee/tests/docs/guides/memify_quickstart.py | 29 ++++++++++ .../tests/docs/guides/ontology_quickstart.py | 30 ++++++++++ cognee/tests/docs/guides/s3_storage.py | 25 ++++++++ cognee/tests/docs/guides/search_basics.py | 17 ++++++ cognee/tests/docs/guides/temporal_cognify.py | 57 +++++++++++++++++++ 11 files changed, 341 insertions(+) create mode 100644 .github/workflows/docs_tests.yml create mode 100644 cognee/tests/docs/guides/custom_data_models.py create mode 100644 cognee/tests/docs/guides/custom_prompts.py create mode 100644 cognee/tests/docs/guides/custom_tasks_and_pipelines.py create mode 100644 cognee/tests/docs/guides/graph_visualization.py create mode 100644 cognee/tests/docs/guides/low_level_llm.py create mode 100644 cognee/tests/docs/guides/memify_quickstart.py create mode 100644 cognee/tests/docs/guides/ontology_quickstart.py create mode 100644 cognee/tests/docs/guides/s3_storage.py create mode 100644 cognee/tests/docs/guides/search_basics.py create mode 100644 cognee/tests/docs/guides/temporal_cognify.py diff --git a/.github/workflows/docs_tests.yml b/.github/workflows/docs_tests.yml new file mode 100644 index 000000000..b3c538668 --- /dev/null +++ b/.github/workflows/docs_tests.yml @@ -0,0 +1,18 @@ +name: Docs Test Suite +permissions: + contents: read + +on: + release: + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + RUNTIME__LOG_LEVEL: ERROR + ENV: 'dev' + +jobs: + diff --git a/cognee/tests/docs/guides/custom_data_models.py b/cognee/tests/docs/guides/custom_data_models.py new file mode 100644 index 000000000..0eb314227 --- 
/dev/null +++ b/cognee/tests/docs/guides/custom_data_models.py @@ -0,0 +1,38 @@ +import asyncio +from typing import Any +from pydantic import SkipValidation + +import cognee +from cognee.infrastructure.engine import DataPoint +from cognee.infrastructure.engine.models.Edge import Edge +from cognee.tasks.storage import add_data_points + + +class Person(DataPoint): + name: str + # Keep it simple for forward refs / mixed values + knows: SkipValidation[Any] = None # single Person or list[Person] + # Recommended: specify which fields to index for search + metadata: dict = {"index_fields": ["name"]} + + +async def main(): + # Start clean (optional in your app) + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + alice = Person(name="Alice") + bob = Person(name="Bob") + charlie = Person(name="Charlie") + + # Create relationships - field name becomes edge label + alice.knows = bob + # You can also do lists: alice.knows = [bob, charlie] + + # Optional: add weights and custom relationship types + bob.knows = (Edge(weight=0.9, relationship_type="friend_of"), charlie) + + await add_data_points([alice, bob, charlie]) + + +asyncio.run(main()) diff --git a/cognee/tests/docs/guides/custom_prompts.py b/cognee/tests/docs/guides/custom_prompts.py new file mode 100644 index 000000000..0d0a55a80 --- /dev/null +++ b/cognee/tests/docs/guides/custom_prompts.py @@ -0,0 +1,30 @@ +import asyncio +import cognee +from cognee.api.v1.search import SearchType + +custom_prompt = """ +Extract only people and cities as entities. +Connect people to cities with the relationship "lives_in". +Ignore all other entities. +""" + + +async def main(): + await cognee.add( + [ + "Alice moved to Paris in 2010, while Bob has always lived in New York.", + "Andreas was born in Venice, but later settled in Lisbon.", + "Diana and Tom were born and raised in Helsingy. 
Diana currently resides in Berlin, while Tom never moved.", + ] + ) + await cognee.cognify(custom_prompt=custom_prompt) + + res = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text="Where does Alice live?", + ) + print(res) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/cognee/tests/docs/guides/custom_tasks_and_pipelines.py b/cognee/tests/docs/guides/custom_tasks_and_pipelines.py new file mode 100644 index 000000000..202bb128a --- /dev/null +++ b/cognee/tests/docs/guides/custom_tasks_and_pipelines.py @@ -0,0 +1,53 @@ +import asyncio +from typing import Any, Dict, List +from pydantic import BaseModel, SkipValidation + +import cognee +from cognee.modules.engine.operations.setup import setup +from cognee.infrastructure.llm.LLMGateway import LLMGateway +from cognee.infrastructure.engine import DataPoint +from cognee.tasks.storage import add_data_points +from cognee.modules.pipelines import Task, run_pipeline + + +class Person(DataPoint): + name: str + # Optional relationships (we'll let the LLM populate this) + knows: List["Person"] = [] + # Make names searchable in the vector store + metadata: Dict[str, Any] = {"index_fields": ["name"]} + + +class People(BaseModel): + persons: List[Person] + + +async def extract_people(text: str) -> List[Person]: + system_prompt = ( + "Extract people mentioned in the text. " + "Return as `persons: Person[]` with each Person having `name` and optional `knows` relations. " + "If the text says someone knows someone set `knows` accordingly. " + "Only include facts explicitly stated." + ) + people = await LLMGateway.acreate_structured_output(text, system_prompt, People) + return people.persons + + +async def main(): + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + text = "Alice knows Bob." 
+ + tasks = [ + Task(extract_people), # input: text -> output: list[Person] + Task(add_data_points), # input: list[Person] -> output: list[Person] + ] + + async for _ in run_pipeline(tasks=tasks, data=text, datasets=["people_demo"]): + pass + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/cognee/tests/docs/guides/graph_visualization.py b/cognee/tests/docs/guides/graph_visualization.py new file mode 100644 index 000000000..d463cbb56 --- /dev/null +++ b/cognee/tests/docs/guides/graph_visualization.py @@ -0,0 +1,13 @@ +import asyncio +import cognee +from cognee.api.v1.visualize.visualize import visualize_graph + + +async def main(): + await cognee.add(["Alice knows Bob.", "NLP is a subfield of CS."]) + await cognee.cognify() + + await visualize_graph("./graph_after_cognify.html") + + +asyncio.run(main()) diff --git a/cognee/tests/docs/guides/low_level_llm.py b/cognee/tests/docs/guides/low_level_llm.py new file mode 100644 index 000000000..454f53f44 --- /dev/null +++ b/cognee/tests/docs/guides/low_level_llm.py @@ -0,0 +1,31 @@ +import asyncio + +from pydantic import BaseModel +from typing import List +from cognee.infrastructure.llm.LLMGateway import LLMGateway + + +class MiniEntity(BaseModel): + name: str + type: str + + +class MiniGraph(BaseModel): + nodes: List[MiniEntity] + + +async def main(): + system_prompt = ( + "Extract entities as nodes with name and type. " + "Use concise, literal values present in the text." + ) + + text = "Apple develops iPhone; Audi produces the R8." 
+ + result = await LLMGateway.acreate_structured_output(text, system_prompt, MiniGraph) + print(result) + # MiniGraph(nodes=[MiniEntity(name='Apple', type='Organization'), ...]) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/cognee/tests/docs/guides/memify_quickstart.py b/cognee/tests/docs/guides/memify_quickstart.py new file mode 100644 index 000000000..040654350 --- /dev/null +++ b/cognee/tests/docs/guides/memify_quickstart.py @@ -0,0 +1,29 @@ +import asyncio +import cognee +from cognee import SearchType + + +async def main(): + # 1) Add two short chats and build a graph + await cognee.add( + [ + "We follow PEP8. Add type hints and docstrings.", + "Releases should not be on Friday. Susan must review PRs.", + ], + dataset_name="rules_demo", + ) + await cognee.cognify(datasets=["rules_demo"]) # builds graph + + # 2) Enrich the graph (uses default memify tasks) + await cognee.memify(dataset="rules_demo") + + # 3) Query the new coding rules + rules = await cognee.search( + query_type=SearchType.CODING_RULES, + query_text="List coding rules", + node_name=["coding_agent_rules"], + ) + print("Rules:", rules) + + +asyncio.run(main()) diff --git a/cognee/tests/docs/guides/ontology_quickstart.py b/cognee/tests/docs/guides/ontology_quickstart.py new file mode 100644 index 000000000..2784dab19 --- /dev/null +++ b/cognee/tests/docs/guides/ontology_quickstart.py @@ -0,0 +1,30 @@ +import asyncio +import cognee + + +async def main(): + texts = ["Audi produces the R8 and e-tron.", "Apple develops iPhone and MacBook."] + + await cognee.add(texts) + # or: await cognee.add("/path/to/folder/of/files") + + import os + from cognee.modules.ontology.ontology_config import Config + from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver + + ontology_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "ontology_input_example/basic_ontology.owl" + ) + + # Create full config structure manually + config: Config = { + 
"ontology_config": { + "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_path) + } + } + + await cognee.cognify(config=config) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/cognee/tests/docs/guides/s3_storage.py b/cognee/tests/docs/guides/s3_storage.py new file mode 100644 index 000000000..1044e05b4 --- /dev/null +++ b/cognee/tests/docs/guides/s3_storage.py @@ -0,0 +1,25 @@ +import asyncio +import cognee + + +async def main(): + # Single file + await cognee.add("s3://cognee-temp/2024-11-04.md") + + # Folder/prefix (recursively expands) + await cognee.add("s3://cognee-temp") + + # Mixed list + await cognee.add( + [ + "s3://cognee-temp/2024-11-04.md", + "Some inline text to ingest", + ] + ) + + # Process the data + await cognee.cognify() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/cognee/tests/docs/guides/search_basics.py b/cognee/tests/docs/guides/search_basics.py new file mode 100644 index 000000000..67d0c938d --- /dev/null +++ b/cognee/tests/docs/guides/search_basics.py @@ -0,0 +1,17 @@ +import asyncio +import cognee + + +async def main(): + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + text = "First rule of coding: Do not talk about coding." + + # Make sure you've already run cognee.cognify(...) so the graph has content + answers = await cognee.search(query_text="What are the main themes in my data?") + for answer in answers: + print(answer) + + +asyncio.run(main()) diff --git a/cognee/tests/docs/guides/temporal_cognify.py b/cognee/tests/docs/guides/temporal_cognify.py new file mode 100644 index 000000000..34c1ee33c --- /dev/null +++ b/cognee/tests/docs/guides/temporal_cognify.py @@ -0,0 +1,57 @@ +import asyncio +import cognee + + +async def main(): + text = """ + In 1998 the project launched. In 2001 version 1.0 shipped. In 2004 the team merged + with another group. In 2010 support for v1 ended. 
+ """ + + await cognee.add(text, dataset_name="timeline_demo") + + await cognee.cognify(datasets=["timeline_demo"], temporal_cognify=True) + + from cognee.api.v1.search import SearchType + + # Before / after queries + result = await cognee.search( + query_type=SearchType.TEMPORAL, query_text="What happened before 2000?", top_k=10 + ) + + assert result != [] + + result = await cognee.search( + query_type=SearchType.TEMPORAL, query_text="What happened after 2010?", top_k=10 + ) + + assert result != [] + + # Between queries + result = await cognee.search( + query_type=SearchType.TEMPORAL, query_text="Events between 2001 and 2004", top_k=10 + ) + + assert result != [] + + # Scoped descriptions + result = await cognee.search( + query_type=SearchType.TEMPORAL, + query_text="Key project milestones between 1998 and 2010", + top_k=10, + ) + + assert result != [] + + result = await cognee.search( + query_type=SearchType.TEMPORAL, + query_text="What happened after 2004?", + datasets=["timeline_demo"], + top_k=10, + ) + + assert result != [] + + +if __name__ == "__main__": + asyncio.run(main()) From d34fd9237bf41c6b421bd556541b50ea68246e45 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Tue, 4 Nov 2025 22:04:32 +0530 Subject: [PATCH 008/176] feat: adding last_acessed in the Data model --- .../e1ec1dcb50b6_add_last_accessed_to_data.py | 30 ++++++ cognee/modules/data/models/Data.py | 1 + .../retrieval/utils/access_tracking.py | 102 ++++++++++++------ 3 files changed, 100 insertions(+), 33 deletions(-) create mode 100644 alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py diff --git a/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py b/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py new file mode 100644 index 000000000..0ccefa63b --- /dev/null +++ b/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py @@ -0,0 +1,30 @@ +"""add_last_accessed_to_data + +Revision ID: e1ec1dcb50b6 +Revises: 211ab850ef3d +Create Date: 2025-11-04 21:45:52.642322 + +""" 
+from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'e1ec1dcb50b6' +down_revision: Union[str, None] = '211ab850ef3d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('data', + sa.Column('last_accessed', sa.DateTime(timezone=True), nullable=True) + ) + # Optionally initialize with created_at values for existing records + op.execute("UPDATE data SET last_accessed = created_at") + + +def downgrade() -> None: + op.drop_column('data', 'last_accessed') diff --git a/cognee/modules/data/models/Data.py b/cognee/modules/data/models/Data.py index ef228f2e1..27ab7481e 100644 --- a/cognee/modules/data/models/Data.py +++ b/cognee/modules/data/models/Data.py @@ -36,6 +36,7 @@ class Data(Base): data_size = Column(Integer, nullable=True) # File size in bytes created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)) + last_accessed = Column(DateTime(timezone=True), nullable=True) datasets = relationship( "Dataset", diff --git a/cognee/modules/retrieval/utils/access_tracking.py b/cognee/modules/retrieval/utils/access_tracking.py index 79afd25db..621e09e27 100644 --- a/cognee/modules/retrieval/utils/access_tracking.py +++ b/cognee/modules/retrieval/utils/access_tracking.py @@ -1,20 +1,27 @@ - """Utilities for tracking data access in retrievers.""" import json from datetime import datetime, timezone from typing import List, Any +from uuid import UUID from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.data.models import Data from cognee.shared.logging_utils import get_logger +from sqlalchemy import update logger = get_logger(__name__) async def 
update_node_access_timestamps(items: List[Any]): """ - Update last_accessed_at for nodes in Kuzu graph database. - Automatically determines node type from the graph database. + Update last_accessed_at for nodes in graph database and corresponding Data records in SQL. + + This function: + 1. Updates last_accessed_at in the graph database nodes (in properties JSON) + 2. Traverses to find origin TextDocument nodes + 3. Updates last_accessed in the SQL Data table for those documents Parameters ---------- @@ -26,39 +33,68 @@ async def update_node_access_timestamps(items: List[Any]): graph_engine = await get_graph_engine() timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000) + timestamp_dt = datetime.now(timezone.utc) + # Extract node IDs + node_ids = [] for item in items: - # Extract ID from payload item_id = item.payload.get("id") if hasattr(item, 'payload') else item.get("id") - if not item_id: - continue - - # try: - # Query to get both node type and properties in one call - result = await graph_engine.query( - "MATCH (n:Node {id: $id}) RETURN n.type as node_type, n.properties as props", - {"id": str(item_id)} - ) - - if result and len(result) > 0 and result[0]: - node_type = result[0][0] # First column: node_type - props_json = result[0][1] # Second column: properties - - # Parse existing properties JSON - props = json.loads(props_json) if props_json else {} - # Update last_accessed_at with millisecond timestamp - props["last_accessed_at"] = timestamp_ms - - # Write back to graph database - await graph_engine.query( - "MATCH (n:Node {id: $id}) SET n.properties = $props", - {"id": str(item_id), "props": json.dumps(props)} + if item_id: + node_ids.append(str(item_id)) + + if not node_ids: + return + + try: + # Step 1: Batch update graph nodes + for node_id in node_ids: + result = await graph_engine.query( + "MATCH (n:Node {id: $id}) RETURN n.properties", + {"id": node_id} ) - logger.debug(f"Updated access timestamp for {node_type} node {item_id}") + if 
result and result[0]: + props = json.loads(result[0][0]) if result[0][0] else {} + props["last_accessed_at"] = timestamp_ms - # except Exception as e: - # logger.error(f"Failed to update timestamp for node {item_id}: {e}") - # continue - - logger.debug(f"Updated access timestamps for {len(items)} nodes") + await graph_engine.query( + "MATCH (n:Node {id: $id}) SET n.properties = $props", + {"id": node_id, "props": json.dumps(props)} + ) + + logger.debug(f"Updated access timestamps for {len(node_ids)} graph nodes") + + # Step 2: Find origin TextDocument nodes + origin_query = """ + UNWIND $node_ids AS node_id + MATCH (n:Node {id: node_id}) + OPTIONAL MATCH (n)-[e:EDGE]-(chunk:Node) + WHERE (e.relationship_name = 'contains' OR e.relationship_name = 'made_from') + AND chunk.type = 'DocumentChunk' + OPTIONAL MATCH (chunk)-[e2:EDGE]->(doc:Node) + WHERE e2.relationship_name = 'is_part_of' + AND doc.type IN ['TextDocument', 'PdfDocument', 'AudioDocument', 'ImageDocument', 'UnstructuredDocument'] + RETURN DISTINCT doc.id as doc_id + """ + + result = await graph_engine.query(origin_query, {"node_ids": node_ids}) + + # Extract document IDs + doc_ids = [row[0] for row in result if row and row[0]] if result else [] + + # Step 3: Update SQL Data table + if doc_ids: + db_engine = get_relational_engine() + async with db_engine.get_async_session() as session: + stmt = update(Data).where( + Data.id.in_([UUID(doc_id) for doc_id in doc_ids]) + ).values(last_accessed=timestamp_dt) + + await session.execute(stmt) + await session.commit() + + logger.debug(f"Updated last_accessed for {len(doc_ids)} Data records in SQL") + + except Exception as e: + logger.error(f"Failed to update timestamps: {e}") + raise From 3c0e915812a4ffb8662419647572c6229ed963a9 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Wed, 5 Nov 2025 12:25:51 +0530 Subject: [PATCH 009/176] fix: removing hard relations --- .../modules/retrieval/utils/access_tracking.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 
deletions(-) diff --git a/cognee/modules/retrieval/utils/access_tracking.py b/cognee/modules/retrieval/utils/access_tracking.py index 621e09e27..36c0b7f50 100644 --- a/cognee/modules/retrieval/utils/access_tracking.py +++ b/cognee/modules/retrieval/utils/access_tracking.py @@ -20,7 +20,7 @@ async def update_node_access_timestamps(items: List[Any]): This function: 1. Updates last_accessed_at in the graph database nodes (in properties JSON) - 2. Traverses to find origin TextDocument nodes + 2. Traverses to find origin TextDocument nodes (without hardcoded relationship names) 3. Updates last_accessed in the SQL Data table for those documents Parameters @@ -64,23 +64,21 @@ async def update_node_access_timestamps(items: List[Any]): logger.debug(f"Updated access timestamps for {len(node_ids)} graph nodes") - # Step 2: Find origin TextDocument nodes + # Step 2: Find origin TextDocument nodes (without hardcoded relationship names) origin_query = """ UNWIND $node_ids AS node_id MATCH (n:Node {id: node_id}) OPTIONAL MATCH (n)-[e:EDGE]-(chunk:Node) - WHERE (e.relationship_name = 'contains' OR e.relationship_name = 'made_from') - AND chunk.type = 'DocumentChunk' - OPTIONAL MATCH (chunk)-[e2:EDGE]->(doc:Node) - WHERE e2.relationship_name = 'is_part_of' - AND doc.type IN ['TextDocument', 'PdfDocument', 'AudioDocument', 'ImageDocument', 'UnstructuredDocument'] + WHERE chunk.type = 'DocumentChunk' + OPTIONAL MATCH (chunk)-[e2:EDGE]-(doc:Node) + WHERE doc.type IN ['TextDocument', 'PdfDocument', 'AudioDocument', 'ImageDocument', 'UnstructuredDocument'] RETURN DISTINCT doc.id as doc_id """ result = await graph_engine.query(origin_query, {"node_ids": node_ids}) - # Extract document IDs - doc_ids = [row[0] for row in result if row and row[0]] if result else [] + # Extract and deduplicate document IDs + doc_ids = list(set([row[0] for row in result if row and row[0]])) if result else [] # Step 3: Update SQL Data table if doc_ids: From 9041a804ecc2d0be1903c2de0ac875f32fcc553c Mon Sep 17 
00:00:00 2001 From: chinu0609 Date: Wed, 5 Nov 2025 18:32:49 +0530 Subject: [PATCH 010/176] fix: add text_doc flag --- cognee/tasks/cleanup/cleanup_unused_data.py | 520 ++++++++++++-------- 1 file changed, 312 insertions(+), 208 deletions(-) diff --git a/cognee/tasks/cleanup/cleanup_unused_data.py b/cognee/tasks/cleanup/cleanup_unused_data.py index e97692bb4..c9c711fe2 100644 --- a/cognee/tasks/cleanup/cleanup_unused_data.py +++ b/cognee/tasks/cleanup/cleanup_unused_data.py @@ -1,232 +1,336 @@ -""" -Task for automatically deleting unused data from the memify pipeline. - -This task identifies and removes data (chunks, entities, summaries) that hasn't -been accessed by retrievers for a specified period, helping maintain system -efficiency and storage optimization. -""" - -import json -from datetime import datetime, timezone, timedelta -from typing import Optional, Dict, Any -from uuid import UUID - -from cognee.infrastructure.databases.graph import get_graph_engine -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.shared.logging_utils import get_logger - -logger = get_logger(__name__) +""" +Task for automatically deleting unused data from the memify pipeline. + +This task identifies and removes data (chunks, entities, summaries) that hasn't +been accessed by retrievers for a specified period, helping maintain system +efficiency and storage optimization. 
+""" + +import json +from datetime import datetime, timezone, timedelta +from typing import Optional, Dict, Any +from uuid import UUID + +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.data.models import Data, DatasetData +from cognee.shared.logging_utils import get_logger +from sqlalchemy import select, or_ +import cognee + +logger = get_logger(__name__) + + +async def cleanup_unused_data( + minutes_threshold: int = 30, + dry_run: bool = True, + user_id: Optional[UUID] = None, + text_doc: bool = False +) -> Dict[str, Any]: + """ + Identify and remove unused data from the memify pipeline. + + Parameters + ---------- + minutes_threshold : int + Minutes since last access to consider data unused (default: 30) + dry_run : bool + If True, only report what would be deleted without actually deleting (default: True) + user_id : UUID, optional + Limit cleanup to specific user's data (default: None) + text_doc : bool + If True, use SQL-based filtering to find unused TextDocuments and call cognee.delete() + for proper whole-document deletion (default: False) + + Returns + ------- + Dict[str, Any] + Cleanup results with status, counts, and timestamp + """ + logger.info( + "Starting cleanup task", + minutes_threshold=minutes_threshold, + dry_run=dry_run, + user_id=str(user_id) if user_id else None, + text_doc=text_doc + ) + + # Calculate cutoff timestamp + cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold) + + if text_doc: + # SQL-based approach: Find unused TextDocuments and use cognee.delete() + return await _cleanup_via_sql(cutoff_date, dry_run, user_id) + else: + # Graph-based approach: Find unused nodes directly from graph + cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000) + logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} 
({cutoff_timestamp_ms}ms)") + + # Find unused nodes + unused_nodes = await _find_unused_nodes(cutoff_timestamp_ms, user_id) + + total_unused = sum(len(nodes) for nodes in unused_nodes.values()) + logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()}) + + if dry_run: + return { + "status": "dry_run", + "unused_count": total_unused, + "deleted_count": { + "data_items": 0, + "chunks": 0, + "entities": 0, + "summaries": 0, + "associations": 0 + }, + "cleanup_date": datetime.now(timezone.utc).isoformat(), + "preview": { + "chunks": len(unused_nodes["DocumentChunk"]), + "entities": len(unused_nodes["Entity"]), + "summaries": len(unused_nodes["TextSummary"]) + } + } + + # Delete unused nodes + deleted_counts = await _delete_unused_nodes(unused_nodes) + + logger.info("Cleanup completed", deleted_counts=deleted_counts) + + return { + "status": "completed", + "unused_count": total_unused, + "deleted_count": { + "data_items": 0, + "chunks": deleted_counts["DocumentChunk"], + "entities": deleted_counts["Entity"], + "summaries": deleted_counts["TextSummary"], + "associations": deleted_counts["associations"] + }, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } -async def cleanup_unused_data( - minutes_threshold: int = 30, - dry_run: bool = True, +async def _cleanup_via_sql( + cutoff_date: datetime, + dry_run: bool, user_id: Optional[UUID] = None ) -> Dict[str, Any]: """ - Identify and remove unused data from the memify pipeline. + SQL-based cleanup: Query Data table for unused documents and use cognee.delete(). 
Parameters ---------- - minutes_threshold : int - Minutes since last access to consider data unused (default: 30) + cutoff_date : datetime + Cutoff date for last_accessed filtering dry_run : bool - If True, only report what would be deleted without actually deleting (default: True) - user_id : UUID, optional - Limit cleanup to specific user's data (default: None) - - Returns - ------- - Dict[str, Any] - Cleanup results with status, counts, and timestamp - """ - logger.info( - "Starting cleanup task", - minutes_threshold=minutes_threshold, - dry_run=dry_run, - user_id=str(user_id) if user_id else None - ) - - # Calculate cutoff timestamp in milliseconds - cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold) - cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000) - - logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)") - - # Find unused nodes - unused_nodes = await _find_unused_nodes(cutoff_timestamp_ms, user_id) - - total_unused = sum(len(nodes) for nodes in unused_nodes.values()) - logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()}) - - if dry_run: - return { - "status": "dry_run", - "unused_count": total_unused, - "deleted_count": { - "data_items": 0, - "chunks": 0, - "entities": 0, - "summaries": 0, - "associations": 0 - }, - "cleanup_date": datetime.now(timezone.utc).isoformat(), - "preview": { - "chunks": len(unused_nodes["DocumentChunk"]), - "entities": len(unused_nodes["Entity"]), - "summaries": len(unused_nodes["TextSummary"]) - } - } - - # Delete unused nodes - deleted_counts = await _delete_unused_nodes(unused_nodes) - - logger.info("Cleanup completed", deleted_counts=deleted_counts) - - return { - "status": "completed", - "unused_count": total_unused, - "deleted_count": { - "data_items": 0, - "chunks": deleted_counts["DocumentChunk"], - "entities": deleted_counts["Entity"], - "summaries": deleted_counts["TextSummary"], - 
"associations": deleted_counts["associations"] - }, - "cleanup_date": datetime.now(timezone.utc).isoformat() - } - - -async def _find_unused_nodes( - cutoff_timestamp_ms: int, - user_id: Optional[UUID] = None -) -> Dict[str, list]: - """ - Query Kuzu for nodes with old last_accessed_at timestamps. - - Parameters - ---------- - cutoff_timestamp_ms : int - Cutoff timestamp in milliseconds since epoch + If True, only report what would be deleted user_id : UUID, optional Filter by user ID if provided Returns ------- - Dict[str, list] - Dictionary mapping node types to lists of unused node IDs + Dict[str, Any] + Cleanup results """ - graph_engine = await get_graph_engine() + db_engine = get_relational_engine() - # Query all nodes with their properties - query = "MATCH (n:Node) RETURN n.id, n.type, n.properties" - results = await graph_engine.query(query) - - unused_nodes = { - "DocumentChunk": [], - "Entity": [], - "TextSummary": [] - } - - for node_id, node_type, props_json in results: - # Only process tracked node types - if node_type not in unused_nodes: - continue - - # Parse properties JSON - if props_json: - try: - props = json.loads(props_json) - last_accessed = props.get("last_accessed_at") - - # Check if node is unused (never accessed or accessed before cutoff) - if last_accessed is None or last_accessed < cutoff_timestamp_ms: - # TODO: Add user_id filtering when user ownership is implemented - unused_nodes[node_type].append(node_id) - logger.debug( - f"Found unused {node_type}", - node_id=node_id, - last_accessed=last_accessed - ) - except json.JSONDecodeError: - logger.warning(f"Failed to parse properties for node {node_id}") - continue - - return unused_nodes - - -async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]: - """ - Delete unused nodes from graph and vector databases. 
- - Parameters - ---------- - unused_nodes : Dict[str, list] - Dictionary mapping node types to lists of node IDs to delete - - Returns - ------- - Dict[str, int] - Count of deleted items by type - """ - graph_engine = await get_graph_engine() - vector_engine = get_vector_engine() - - deleted_counts = { - "DocumentChunk": 0, - "Entity": 0, - "TextSummary": 0, - "associations": 0 - } - - # Count associations before deletion - for node_type, node_ids in unused_nodes.items(): - if not node_ids: - continue - - # Count edges connected to these nodes - for node_id in node_ids: - result = await graph_engine.query( - "MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)", - {"id": node_id} + async with db_engine.get_async_session() as session: + # Query for Data records with old last_accessed timestamps + query = select(Data, DatasetData).join( + DatasetData, Data.id == DatasetData.data_id + ).where( + or_( + Data.last_accessed < cutoff_date, + Data.last_accessed.is_(None) ) - if result and len(result) > 0: - deleted_counts["associations"] += result[0][0] + ) + + if user_id: + from cognee.modules.data.models import Dataset + query = query.join(Dataset, DatasetData.dataset_id == Dataset.id).where( + Dataset.owner_id == user_id + ) + + result = await session.execute(query) + unused_data = result.all() - # Delete from graph database (uses DETACH DELETE, so edges are automatically removed) - for node_type, node_ids in unused_nodes.items(): - if not node_ids: - continue - - logger.info(f"Deleting {len(node_ids)} {node_type} nodes from graph database") - - # Delete nodes in batches - await graph_engine.delete_nodes(node_ids) - deleted_counts[node_type] = len(node_ids) + logger.info(f"Found {len(unused_data)} unused documents in SQL") - # Delete from vector database - vector_collections = { - "DocumentChunk": "DocumentChunk_text", - "Entity": "Entity_name", - "TextSummary": "TextSummary_text" - } + if dry_run: + return { + "status": "dry_run", + "unused_count": len(unused_data), 
+ "deleted_count": { + "data_items": 0, + "documents": 0 + }, + "cleanup_date": datetime.now(timezone.utc).isoformat(), + "preview": { + "documents": len(unused_data) + } + } - for node_type, collection_name in vector_collections.items(): - node_ids = unused_nodes[node_type] - if not node_ids: - continue - - logger.info(f"Deleting {len(node_ids)} {node_type} embeddings from vector database") - + # Delete each document using cognee.delete() + deleted_count = 0 + from cognee.modules.users.methods import get_default_user + user = await get_default_user() if user_id is None else None + + for data, dataset_data in unused_data: try: - # Delete from vector collection - if await vector_engine.has_collection(collection_name): - for node_id in node_ids: - try: - await vector_engine.delete(collection_name, {"id": str(node_id)}) - except Exception as e: - logger.warning(f"Failed to delete {node_id} from {collection_name}: {e}") + await cognee.delete( + data_id=data.id, + dataset_id=dataset_data.dataset_id, + mode="hard", # Use hard mode to also remove orphaned entities + user=user + ) + deleted_count += 1 + logger.info(f"Deleted document {data.id} from dataset {dataset_data.dataset_id}") except Exception as e: - logger.error(f"Error deleting from vector collection {collection_name}: {e}") + logger.error(f"Failed to delete document {data.id}: {e}") + logger.info("Cleanup completed", deleted_count=deleted_count) + + return { + "status": "completed", + "unused_count": len(unused_data), + "deleted_count": { + "data_items": deleted_count, + "documents": deleted_count + }, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } + + +async def _find_unused_nodes( + cutoff_timestamp_ms: int, + user_id: Optional[UUID] = None +) -> Dict[str, list]: + """ + Query Kuzu for nodes with old last_accessed_at timestamps. 
+ + Parameters + ---------- + cutoff_timestamp_ms : int + Cutoff timestamp in milliseconds since epoch + user_id : UUID, optional + Filter by user ID if provided + + Returns + ------- + Dict[str, list] + Dictionary mapping node types to lists of unused node IDs + """ + graph_engine = await get_graph_engine() + + # Query all nodes with their properties + query = "MATCH (n:Node) RETURN n.id, n.type, n.properties" + results = await graph_engine.query(query) + + unused_nodes = { + "DocumentChunk": [], + "Entity": [], + "TextSummary": [] + } + + for node_id, node_type, props_json in results: + # Only process tracked node types + if node_type not in unused_nodes: + continue + + # Parse properties JSON + if props_json: + try: + props = json.loads(props_json) + last_accessed = props.get("last_accessed_at") + + # Check if node is unused (never accessed or accessed before cutoff) + if last_accessed is None or last_accessed < cutoff_timestamp_ms: + unused_nodes[node_type].append(node_id) + logger.debug( + f"Found unused {node_type}", + node_id=node_id, + last_accessed=last_accessed + ) + except json.JSONDecodeError: + logger.warning(f"Failed to parse properties for node {node_id}") + continue + + return unused_nodes + + +async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]: + """ + Delete unused nodes from graph and vector databases. 
+ + Parameters + ---------- + unused_nodes : Dict[str, list] + Dictionary mapping node types to lists of node IDs to delete + + Returns + ------- + Dict[str, int] + Count of deleted items by type + """ + graph_engine = await get_graph_engine() + vector_engine = get_vector_engine() + + deleted_counts = { + "DocumentChunk": 0, + "Entity": 0, + "TextSummary": 0, + "associations": 0 + } + + # Count associations before deletion + for node_type, node_ids in unused_nodes.items(): + if not node_ids: + continue + + # Count edges connected to these nodes + for node_id in node_ids: + result = await graph_engine.query( + "MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)", + {"id": node_id} + ) + if result and len(result) > 0: + deleted_counts["associations"] += result[0][0] + + # Delete from graph database (uses DETACH DELETE, so edges are automatically removed) + for node_type, node_ids in unused_nodes.items(): + if not node_ids: + continue + + logger.info(f"Deleting {len(node_ids)} {node_type} nodes from graph database") + + # Delete nodes in batches + await graph_engine.delete_nodes(node_ids) + deleted_counts[node_type] = len(node_ids) + + # Delete from vector database + vector_collections = { + "DocumentChunk": "DocumentChunk_text", + "Entity": "Entity_name", + "TextSummary": "TextSummary_text" + } + + for node_type, collection_name in vector_collections.items(): + node_ids = unused_nodes[node_type] + if not node_ids: + continue + + logger.info(f"Deleting {len(node_ids)} {node_type} embeddings from vector database") + + try: + # Delete from vector collection + if await vector_engine.has_collection(collection_name): + for node_id in node_ids: + try: + await vector_engine.delete(collection_name, {"id": str(node_id)}) + except Exception as e: + logger.warning(f"Failed to delete {node_id} from {collection_name}: {e}") + except Exception as e: + logger.error(f"Error deleting from vector collection {collection_name}: {e}") + return deleted_counts From 
ff263c0132b170b3c03961606db56c2a174d2b90 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Wed, 5 Nov 2025 18:40:58 +0530 Subject: [PATCH 011/176] fix: add column check in migration --- .../e1ec1dcb50b6_add_last_accessed_to_data.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py b/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py index 0ccefa63b..267e11fb2 100644 --- a/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +++ b/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py @@ -17,14 +17,30 @@ down_revision: Union[str, None] = '211ab850ef3d' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None +def _get_column(inspector, table, name, schema=None): + for col in inspector.get_columns(table, schema=schema): + if col["name"] == name: + return col + return None + def upgrade() -> None: - op.add_column('data', - sa.Column('last_accessed', sa.DateTime(timezone=True), nullable=True) - ) - # Optionally initialize with created_at values for existing records - op.execute("UPDATE data SET last_accessed = created_at") + conn = op.get_bind() + insp = sa.inspect(conn) + + last_accessed_column = _get_column(insp, "data", "last_accessed") + if not last_accessed_column: + op.add_column('data', + sa.Column('last_accessed', sa.DateTime(timezone=True), nullable=True) + ) + # Optionally initialize with created_at values for existing records + op.execute("UPDATE data SET last_accessed = created_at") def downgrade() -> None: - op.drop_column('data', 'last_accessed') + conn = op.get_bind() + insp = sa.inspect(conn) + + last_accessed_column = _get_column(insp, "data", "last_accessed") + if last_accessed_column: + op.drop_column('data', 'last_accessed') From c5f0c4af87ff13bf8e3cbe0f4e9163ece44c3094 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Wed, 5 Nov 2025 20:22:17 +0530 Subject: [PATCH 012/176] fix: add text_doc flag 
--- cognee/modules/retrieval/utils/access_tracking.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/cognee/modules/retrieval/utils/access_tracking.py b/cognee/modules/retrieval/utils/access_tracking.py index 36c0b7f50..65d597a93 100644 --- a/cognee/modules/retrieval/utils/access_tracking.py +++ b/cognee/modules/retrieval/utils/access_tracking.py @@ -67,12 +67,9 @@ async def update_node_access_timestamps(items: List[Any]): # Step 2: Find origin TextDocument nodes (without hardcoded relationship names) origin_query = """ UNWIND $node_ids AS node_id - MATCH (n:Node {id: node_id}) - OPTIONAL MATCH (n)-[e:EDGE]-(chunk:Node) - WHERE chunk.type = 'DocumentChunk' - OPTIONAL MATCH (chunk)-[e2:EDGE]-(doc:Node) - WHERE doc.type IN ['TextDocument', 'PdfDocument', 'AudioDocument', 'ImageDocument', 'UnstructuredDocument'] - RETURN DISTINCT doc.id as doc_id + MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node) + WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document'] + RETURN DISTINCT doc.id """ result = await graph_engine.query(origin_query, {"node_ids": node_ids}) From fdf037b3d0117bd29f0c541ed027895c070678df Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Thu, 6 Nov 2025 23:00:56 +0530 Subject: [PATCH 013/176] fix: min to days --- cognee/tasks/cleanup/cleanup_unused_data.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cognee/tasks/cleanup/cleanup_unused_data.py b/cognee/tasks/cleanup/cleanup_unused_data.py index c9c711fe2..4df622a2c 100644 --- a/cognee/tasks/cleanup/cleanup_unused_data.py +++ b/cognee/tasks/cleanup/cleanup_unused_data.py @@ -23,7 +23,7 @@ logger = get_logger(__name__) async def cleanup_unused_data( - minutes_threshold: int = 30, + days_threshold: Optional[int], dry_run: bool = True, user_id: Optional[UUID] = None, text_doc: bool = False @@ -33,8 +33,8 @@ async def cleanup_unused_data( Parameters ---------- - minutes_threshold : int - Minutes since last access to consider data 
unused (default: 30)
+    days_threshold : int
+        Days since last access to consider data unused
     dry_run : bool
         If True, only report what would be deleted without actually deleting (default: True)
     user_id : UUID, optional
@@ -50,14 +50,14 @@
     """
     logger.info(
         "Starting cleanup task",
-        minutes_threshold=minutes_threshold,
+        days_threshold=days_threshold,
         dry_run=dry_run,
         user_id=str(user_id) if user_id else None,
         text_doc=text_doc
     )
 
     # Calculate cutoff timestamp
-    cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold)
+    cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_threshold)
 
     if text_doc:
         # SQL-based approach: Find unused TextDocuments and use cognee.delete()

From 84c8e07ddd980af7c11b89c7e510b38e5c44f119 Mon Sep 17 00:00:00 2001
From: chinu0609
Date: Fri, 7 Nov 2025 12:03:17 +0530
Subject: [PATCH 014/176] fix: remove unnecessary imports

---
 cognee/modules/chunking/models/DocumentChunk.py | 2 --
 cognee/modules/engine/models/Entity.py          | 2 --
 2 files changed, 4 deletions(-)

diff --git a/cognee/modules/chunking/models/DocumentChunk.py b/cognee/modules/chunking/models/DocumentChunk.py
index a9fb08a9e..e2b216a9b 100644
--- a/cognee/modules/chunking/models/DocumentChunk.py
+++ b/cognee/modules/chunking/models/DocumentChunk.py
@@ -1,7 +1,5 @@
 from typing import List, Union
 
-from pydantic import BaseModel, Field
-from datetime import datetime, timezone
 from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.engine.models.Edge import Edge
 from cognee.modules.data.processing.document_types import Document
diff --git a/cognee/modules/engine/models/Entity.py b/cognee/modules/engine/models/Entity.py
index 4083cd2e6..a34a6503c 100644
--- a/cognee/modules/engine/models/Entity.py
+++ b/cognee/modules/engine/models/Entity.py
@@ -1,8 +1,6 @@
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.engine.models.EntityType import EntityType
 from typing import Optional
-from datetime 
import datetime, timezone
-from pydantic import BaseModel, Field
 
 class Entity(DataPoint):
     name: str

From 84bd2f38f7513c244ed1040937a1e5a5297cec2e Mon Sep 17 00:00:00 2001
From: chinu0609
Date: Fri, 7 Nov 2025 12:12:46 +0530
Subject: [PATCH 015/176] fix: remove unnecessary imports

---
 cognee/tasks/summarization/models.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/cognee/tasks/summarization/models.py b/cognee/tasks/summarization/models.py
index 8cee2ade3..8420cfaa5 100644
--- a/cognee/tasks/summarization/models.py
+++ b/cognee/tasks/summarization/models.py
@@ -1,7 +1,5 @@
-from pydantic import BaseModel, Field
 from typing import Union
 
-from datetime import datetime, timezone
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.chunking.models import DocumentChunk
 from cognee.shared.CodeGraphEntities import CodeFile, CodePart

From d351c9a009d12a8a8a4869afa7aee38c61482e21 Mon Sep 17 00:00:00 2001
From: chinu0609
Date: Mon, 10 Nov 2025 21:58:01 +0530
Subject: [PATCH 016/176] fix: return chunk payload

---
 cognee/modules/retrieval/chunks_retriever.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/cognee/modules/retrieval/chunks_retriever.py b/cognee/modules/retrieval/chunks_retriever.py
index be1f95811..b7a90238a 100644
--- a/cognee/modules/retrieval/chunks_retriever.py
+++ b/cognee/modules/retrieval/chunks_retriever.py
@@ -57,6 +57,7 @@ class ChunksRetriever(BaseRetriever):
 
         chunk_payloads = [result.payload for result in found_chunks]
         logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
+        return chunk_payloads
 
     async def get_completion(
         self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None

From ac3300760b7a521aebe452d041bb7ceaa35f8052 Mon Sep 17 00:00:00 2001
From: Andrej Milicevic
Date: Wed, 12 Nov 2025 12:26:28 +0100
Subject: [PATCH 017/176] test: add search tests docs

---
 cognee/tests/docs/guides/search_basics.py | 46 +++++++++++++++++++++--
 1 file changed, 43 insertions(+), 3 deletions(-)

diff --git 
a/cognee/tests/docs/guides/search_basics.py b/cognee/tests/docs/guides/search_basics.py index 67d0c938d..09dee3f92 100644 --- a/cognee/tests/docs/guides/search_basics.py +++ b/cognee/tests/docs/guides/search_basics.py @@ -1,17 +1,57 @@ import asyncio import cognee +from cognee.modules.search.types import SearchType, CombinedSearchResult + async def main(): await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) - text = "First rule of coding: Do not talk about coding." + text = """ + Natural language processing (NLP) is an interdisciplinary + subfield of computer science and information retrieval. + First rule of coding: Do not talk about coding. + """ + + text2 = """ + Sandwiches are best served toasted with cheese, ham, mayo, + lettuce, mustard, and salt & pepper. + """ + + await cognee.add(text, dataset_name="NLP_coding") + await cognee.add(text2, dataset_name="Sandwiches") + await cognee.add(text2) + + await cognee.cognify() # Make sure you've already run cognee.cognify(...) 
so the graph has content answers = await cognee.search(query_text="What are the main themes in my data?") - for answer in answers: - print(answer) + assert len(answers) > 0 + answers = await cognee.search( + query_text="List coding guidelines", + query_type=SearchType.CODING_RULES, + ) + assert len(answers) > 0 + + answers = await cognee.search( + query_text="Give me a confident answer: What is NLP?", + system_prompt="Answer succinctly and state confidence at the end.", + ) + assert len(answers) > 0 + + answers = await cognee.search( + query_text="Tell me about NLP", + only_context=True, + ) + assert len(answers) > 0 + + answers = await cognee.search( + query_text="Quarterly financial highlights", + datasets=["NLP_coding", "Sandwiches"], + use_combined_context=True, + ) + assert isinstance(answers, CombinedSearchResult) asyncio.run(main()) From 503bdc34f38e18e2ec3dccb6e47aaff669702f55 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Wed, 12 Nov 2025 13:20:23 +0100 Subject: [PATCH 018/176] test: add tests to workflows --- .github/workflows/docs_tests.yml | 276 ++++++++++++++++++++++++++++- .github/workflows/release_test.yml | 5 + 2 files changed, 274 insertions(+), 7 deletions(-) diff --git a/.github/workflows/docs_tests.yml b/.github/workflows/docs_tests.yml index b3c538668..7f7282bb2 100644 --- a/.github/workflows/docs_tests.yml +++ b/.github/workflows/docs_tests.yml @@ -1,18 +1,280 @@ -name: Docs Test Suite +name: Docs Tests + permissions: contents: read on: - release: workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true + workflow_call: + secrets: + LLM_PROVIDER: + required: true + LLM_MODEL: + required: true + LLM_ENDPOINT: + required: true + LLM_API_KEY: + required: true + LLM_API_VERSION: + required: true + EMBEDDING_PROVIDER: + required: true + EMBEDDING_MODEL: + required: true + EMBEDDING_ENDPOINT: + required: true + EMBEDDING_API_KEY: + required: true + EMBEDDING_API_VERSION: + required: true 
env: - RUNTIME__LOG_LEVEL: ERROR ENV: 'dev' jobs: + test-search-basics: + name: Test Search Basics + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Search Basics Test + env: + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/docs/guides/search_basics.py + + test-temporal-cognify: + name: Test Temporal Cognify + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Temporal Cognify Test + env: + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/docs/guides/temporal_cognify.py + + test-ontology-quickstart: + name: Test Temporal Cognify + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Temporal Cognify Test + env: + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ 
secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/docs/guides/temporal_cognify.py + + test-s3-storage: + name: Test S3 Docs Guide + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + extra-dependencies: "aws" + + - name: Run S3 Docs Guide Test + env: + ENABLE_BACKEND_ACCESS_CONTROL: True + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + STORAGE_BACKEND: s3 + AWS_REGION: eu-west-1 + AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_DEV_USER_SECRET_KEY }} + run: uv run python ./cognee/tests/docs/guides/s3_storage.py + + test-graph-visualization: + name: Test Graph Visualization + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Graph Visualization Test + env: + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + 
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/docs/guides/graph_visualization.py + + test-low-level-llm: + name: Test Low Level LLM + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Low Level LLM Test + env: + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/docs/guides/low_level_llm.py + + test-memify-quickstart: + name: Test Memify Quickstart + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Memify Quickstart Test + env: + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/docs/guides/memify_quickstart.py + + test-custom-data-models: + name: Test Custom Data Models + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Custom Data Models Test + env: + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ 
secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/docs/guides/custom_data_models.py + + test-custom-tasks-and-pipelines: + name: Test Custom Tasks and Pipelines + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Custom Tasks and Pipelines Test + env: + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/docs/guides/custom_tasks_and_pipelines.py + + test-custom-prompts: + name: Test Custom Prompts + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Custom Prompts Test + env: + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/docs/guides/custom_prompts.py \ No newline at end of file diff --git 
a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml index 6ac3ca515..89540fcfb 100644 --- a/.github/workflows/release_test.yml +++ b/.github/workflows/release_test.yml @@ -14,4 +14,9 @@ jobs: load-tests: name: Load Tests uses: ./.github/workflows/load_tests.yml + secrets: inherit + + docs-tests: + name: Docs Tests + uses: ./.github/workflows/docs_tests.yml secrets: inherit \ No newline at end of file From 1e56d6dc389e1f33c08a7ee897689a941a7b8a9f Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Wed, 12 Nov 2025 13:42:53 +0100 Subject: [PATCH 019/176] chore: ruff format --- cognee/tests/docs/guides/search_basics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cognee/tests/docs/guides/search_basics.py b/cognee/tests/docs/guides/search_basics.py index 09dee3f92..f1847ad4b 100644 --- a/cognee/tests/docs/guides/search_basics.py +++ b/cognee/tests/docs/guides/search_basics.py @@ -54,4 +54,5 @@ async def main(): ) assert isinstance(answers, CombinedSearchResult) + asyncio.run(main()) From 82d48663bb5d722ec9310f5745998299793fabaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Pe=C3=B1a=20del=20R=C3=ADo?= Date: Mon, 17 Nov 2025 23:27:31 +0100 Subject: [PATCH 020/176] Add custom label support to Data model (#1769) --- .../a1b2c3d4e5f6_add_label_column_to_data.py | 26 +++++++++++++++++++ .../datasets/routers/get_datasets_router.py | 1 + cognee/modules/data/models/Data.py | 3 ++- cognee/tasks/ingestion/data_item.py | 8 ++++++ cognee/tasks/ingestion/ingest_data.py | 14 ++++++++-- uv.lock | 12 +++++++-- 6 files changed, 59 insertions(+), 5 deletions(-) create mode 100644 alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py create mode 100644 cognee/tasks/ingestion/data_item.py diff --git a/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py b/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py new file mode 100644 index 000000000..8e7bc19b1 --- /dev/null +++ b/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py @@ 
-0,0 +1,26 @@
+"""Add label column to data
+
+Revision ID: a1b2c3d4e5f6
+Revises: 211ab850ef3d
+Create Date: 2025-11-17 17:54:32.123456
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision: str = "a1b2c3d4e5f6"
+down_revision: Union[str, None] = "211ab850ef3d"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+def upgrade() -> None:
+    op.add_column(
+        "data",
+        sa.Column("label", sa.String(), nullable=True)),
+
+def downgrade() -> None:
+    op.drop_column("data", "label")
\ No newline at end of file
diff --git a/cognee/api/v1/datasets/routers/get_datasets_router.py b/cognee/api/v1/datasets/routers/get_datasets_router.py
index eff87b3af..84ba3126b 100644
--- a/cognee/api/v1/datasets/routers/get_datasets_router.py
+++ b/cognee/api/v1/datasets/routers/get_datasets_router.py
@@ -44,6 +44,7 @@ class DatasetDTO(OutDTO):
 class DataDTO(OutDTO):
     id: UUID
     name: str
+    label: Optional[str] = None
     created_at: datetime
     updated_at: Optional[datetime] = None
     extension: str
diff --git a/cognee/modules/data/models/Data.py b/cognee/modules/data/models/Data.py
index ef228f2e1..3cdead9d9 100644
--- a/cognee/modules/data/models/Data.py
+++ b/cognee/modules/data/models/Data.py
@@ -13,7 +13,7 @@ class Data(Base):
     __tablename__ = "data"
 
     id = Column(UUID, primary_key=True, default=uuid4)
-
+    label = Column(String, nullable=True)
     name = Column(String)
     extension = Column(String)
     mime_type = Column(String)
@@ -49,6 +49,7 @@ class Data(Base):
         return {
             "id": str(self.id),
             "name": self.name,
+            "label": self.label,
             "extension": self.extension,
             "mimeType": self.mime_type,
             "rawDataLocation": self.raw_data_location,
diff --git a/cognee/tasks/ingestion/data_item.py b/cognee/tasks/ingestion/data_item.py
new file mode 100644
index 000000000..23570bf77
--- /dev/null
+++ b/cognee/tasks/ingestion/data_item.py
@@ -0,0 +1,8 @@
+from dataclasses import 
dataclass +from typing import Any, Dict, Optional + +@dataclass +class DataItem: + data: Any + label: Optional[str] = None + diff --git a/cognee/tasks/ingestion/ingest_data.py b/cognee/tasks/ingestion/ingest_data.py index 0572d0f1e..3f38dc6db 100644 --- a/cognee/tasks/ingestion/ingest_data.py +++ b/cognee/tasks/ingestion/ingest_data.py @@ -20,7 +20,7 @@ from cognee.modules.data.methods import ( from .save_data_item_to_storage import save_data_item_to_storage from .data_item_to_text_file import data_item_to_text_file - +from .data_item import DataItem async def ingest_data( data: Any, @@ -78,8 +78,16 @@ async def ingest_data( dataset_data_map = {str(data.id): True for data in dataset_data} for data_item in data: + # Support for DataItem (custom label + data wrapper) + current_label = None + underlying_data = data_item + + if isinstance(data_item, DataItem): + underlying_data = data_item.data + current_label = data_item.label + # Get file path of data item or create a file if it doesn't exist - original_file_path = await save_data_item_to_storage(data_item) + original_file_path = await save_data_item_to_storage(underlying_data) # Transform file path to be OS usable actual_file_path = get_data_file_path(original_file_path) @@ -139,6 +147,7 @@ async def ingest_data( data_point.external_metadata = ext_metadata data_point.node_set = json.dumps(node_set) if node_set else None data_point.tenant_id = user.tenant_id if user.tenant_id else None + data_point.label = current_label # Check if data is already in dataset if str(data_point.id) in dataset_data_map: @@ -169,6 +178,7 @@ async def ingest_data( tenant_id=user.tenant_id if user.tenant_id else None, pipeline_status={}, token_count=-1, + label = current_label ) new_datapoints.append(data_point) diff --git a/uv.lock b/uv.lock index e2fc1df83..3ed54543d 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10, <3.14" resolution-markers = [ "python_full_version >= 
'3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'", @@ -929,7 +929,7 @@ wheels = [ [[package]] name = "cognee" -version = "0.3.9" +version = "0.4.0" source = { editable = "." } dependencies = [ { name = "aiofiles" }, @@ -2529,6 +2529,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7f/91/ae2eb6b7979e2f9b035a9f612cf70f1bf54aad4e1d125129bef1eae96f19/greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d", size = 584358, upload-time = "2025-08-07T13:18:23.708Z" }, { url = "https://files.pythonhosted.org/packages/f7/85/433de0c9c0252b22b16d413c9407e6cb3b41df7389afc366ca204dbc1393/greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5", size = 1113550, upload-time = "2025-08-07T13:42:37.467Z" }, { url = "https://files.pythonhosted.org/packages/a1/8d/88f3ebd2bc96bf7747093696f4335a0a8a4c5acfcf1b757717c0d2474ba3/greenlet-3.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8854167e06950ca75b898b104b63cc646573aa5fef1353d4508ecdd1ee76254f", size = 1137126, upload-time = "2025-08-07T13:18:20.239Z" }, + { url = "https://files.pythonhosted.org/packages/f1/29/74242b7d72385e29bcc5563fba67dad94943d7cd03552bac320d597f29b2/greenlet-3.2.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f47617f698838ba98f4ff4189aef02e7343952df3a615f847bb575c3feb177a7", size = 1544904, upload-time = "2025-11-04T12:42:04.763Z" }, + { url = "https://files.pythonhosted.org/packages/c8/e2/1572b8eeab0f77df5f6729d6ab6b141e4a84ee8eb9bc8c1e7918f94eda6d/greenlet-3.2.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af41be48a4f60429d5cad9d22175217805098a9ef7c40bfef44f7669fb9d74d8", size = 1611228, upload-time = "2025-11-04T12:42:08.423Z" }, { url = 
"https://files.pythonhosted.org/packages/d6/6f/b60b0291d9623c496638c582297ead61f43c4b72eef5e9c926ef4565ec13/greenlet-3.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:73f49b5368b5359d04e18d15828eecc1806033db5233397748f4ca813ff1056c", size = 298654, upload-time = "2025-08-07T13:50:00.469Z" }, { url = "https://files.pythonhosted.org/packages/a4/de/f28ced0a67749cac23fecb02b694f6473f47686dff6afaa211d186e2ef9c/greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2", size = 272305, upload-time = "2025-08-07T13:15:41.288Z" }, { url = "https://files.pythonhosted.org/packages/09/16/2c3792cba130000bf2a31c5272999113f4764fd9d874fb257ff588ac779a/greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246", size = 632472, upload-time = "2025-08-07T13:42:55.044Z" }, @@ -2538,6 +2540,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/8e/abdd3f14d735b2929290a018ecf133c901be4874b858dd1c604b9319f064/greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8", size = 587684, upload-time = "2025-08-07T13:18:25.164Z" }, { url = "https://files.pythonhosted.org/packages/5d/65/deb2a69c3e5996439b0176f6651e0052542bb6c8f8ec2e3fba97c9768805/greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52", size = 1116647, upload-time = "2025-08-07T13:42:38.655Z" }, { url = "https://files.pythonhosted.org/packages/3f/cc/b07000438a29ac5cfb2194bfc128151d52f333cee74dd7dfe3fb733fc16c/greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa", size = 1142073, upload-time = "2025-08-07T13:18:21.737Z" }, + { url = 
"https://files.pythonhosted.org/packages/67/24/28a5b2fa42d12b3d7e5614145f0bd89714c34c08be6aabe39c14dd52db34/greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c", size = 1548385, upload-time = "2025-11-04T12:42:11.067Z" }, + { url = "https://files.pythonhosted.org/packages/6a/05/03f2f0bdd0b0ff9a4f7b99333d57b53a7709c27723ec8123056b084e69cd/greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5", size = 1613329, upload-time = "2025-11-04T12:42:12.928Z" }, { url = "https://files.pythonhosted.org/packages/d8/0f/30aef242fcab550b0b3520b8e3561156857c94288f0332a79928c31a52cf/greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9", size = 299100, upload-time = "2025-08-07T13:44:12.287Z" }, { url = "https://files.pythonhosted.org/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079, upload-time = "2025-08-07T13:15:45.033Z" }, { url = "https://files.pythonhosted.org/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997, upload-time = "2025-08-07T13:42:56.234Z" }, @@ -2547,6 +2551,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" }, { url = 
"https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" }, { url = "https://files.pythonhosted.org/packages/3f/c7/12381b18e21aef2c6bd3a636da1088b888b97b7a0362fac2e4de92405f97/greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f", size = 1151142, upload-time = "2025-08-07T13:18:22.981Z" }, + { url = "https://files.pythonhosted.org/packages/27/45/80935968b53cfd3f33cf99ea5f08227f2646e044568c9b1555b58ffd61c2/greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0", size = 1564846, upload-time = "2025-11-04T12:42:15.191Z" }, + { url = "https://files.pythonhosted.org/packages/69/02/b7c30e5e04752cb4db6202a3858b149c0710e5453b71a3b2aec5d78a1aab/greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d", size = 1633814, upload-time = "2025-11-04T12:42:17.175Z" }, { url = "https://files.pythonhosted.org/packages/e9/08/b0814846b79399e585f974bbeebf5580fbe59e258ea7be64d9dfb253c84f/greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02", size = 299899, upload-time = "2025-08-07T13:38:53.448Z" }, { url = "https://files.pythonhosted.org/packages/49/e8/58c7f85958bda41dafea50497cbd59738c5c43dbbea5ee83d651234398f4/greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31", size = 272814, upload-time = "2025-08-07T13:15:50.011Z" }, { url = 
"https://files.pythonhosted.org/packages/62/dd/b9f59862e9e257a16e4e610480cfffd29e3fae018a68c2332090b53aac3d/greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945", size = 641073, upload-time = "2025-08-07T13:42:57.23Z" }, @@ -2556,6 +2562,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" }, { url = "https://files.pythonhosted.org/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" }, { url = "https://files.pythonhosted.org/packages/a2/15/0d5e4e1a66fab130d98168fe984c509249c833c1a3c16806b90f253ce7b9/greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae", size = 1149210, upload-time = "2025-08-07T13:18:24.072Z" }, + { url = "https://files.pythonhosted.org/packages/1c/53/f9c440463b3057485b8594d7a638bed53ba531165ef0ca0e6c364b5cc807/greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b", size = 1564759, upload-time = "2025-11-04T12:42:19.395Z" }, + { url = "https://files.pythonhosted.org/packages/47/e4/3bb4240abdd0a8d23f4f88adec746a3099f0d86bfedb623f063b2e3b4df0/greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929", size = 1634288, upload-time = "2025-11-04T12:42:21.174Z" }, { url = 
"https://files.pythonhosted.org/packages/0b/55/2321e43595e6801e105fcfdee02b34c0f996eb71e6ddffca6b10b7e1d771/greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b", size = 299685, upload-time = "2025-08-07T13:24:38.824Z" }, ] From 8ea83e4a26a598f8bc65aa27b15fad68e91b0b2b Mon Sep 17 00:00:00 2001 From: apenade <166741079+apenade@users.noreply.github.com> Date: Tue, 18 Nov 2025 10:44:49 +0100 Subject: [PATCH 021/176] Update alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py b/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py index 8e7bc19b1..bffe61b46 100644 --- a/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +++ b/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py @@ -20,7 +20,8 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: op.add_column( "data", - sa.Column("label", sa.String(), nullable=True)), + sa.Column("label", sa.String(), nullable=True) + ) def downgrade() -> None: op.drop_column("data", "label") \ No newline at end of file From a451fb8c5a22503dc0666af7f9383372344eeb7a Mon Sep 17 00:00:00 2001 From: apenade <166741079+apenade@users.noreply.github.com> Date: Tue, 18 Nov 2025 10:45:31 +0100 Subject: [PATCH 022/176] Update cognee/tasks/ingestion/data_item.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- cognee/tasks/ingestion/data_item.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/tasks/ingestion/data_item.py b/cognee/tasks/ingestion/data_item.py index 23570bf77..23285d677 100644 --- a/cognee/tasks/ingestion/data_item.py +++ b/cognee/tasks/ingestion/data_item.py @@ -1,5 +1,5 @@ from 
dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 
 @dataclass
 class DataItem:

From 7bd7079aac9fcb003bcc20e118bc65d066e9029c Mon Sep 17 00:00:00 2001
From: chinu0609
Date: Tue, 18 Nov 2025 22:17:23 +0530
Subject: [PATCH 023/176] fix: vector_engine.delete_data_points

---
 cognee/tasks/cleanup/cleanup_unused_data.py | 33 ++++++++++-----------
 1 file changed, 16 insertions(+), 17 deletions(-)

diff --git a/cognee/tasks/cleanup/cleanup_unused_data.py b/cognee/tasks/cleanup/cleanup_unused_data.py
index 4df622a2c..fd4b68204 100644
--- a/cognee/tasks/cleanup/cleanup_unused_data.py
+++ b/cognee/tasks/cleanup/cleanup_unused_data.py
@@ -315,22 +315,21 @@ async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]:
         "TextSummary": "TextSummary_text"
     }
 
-    for node_type, collection_name in vector_collections.items():
-        node_ids = unused_nodes[node_type]
-        if not node_ids:
-            continue
-
-        logger.info(f"Deleting {len(node_ids)} {node_type} embeddings from vector database")
-
-        try:
-            # Delete from vector collection
-            if await vector_engine.has_collection(collection_name):
-                for node_id in node_ids:
-                    try:
-                        await vector_engine.delete(collection_name, {"id": str(node_id)})
-                    except Exception as e:
-                        logger.warning(f"Failed to delete {node_id} from {collection_name}: {e}")
-        except Exception as e:
-            logger.error(f"Error deleting from vector collection {collection_name}: {e}")
+
+    for node_type, collection_name in vector_collections.items():
+        node_ids = unused_nodes[node_type]
+        if not node_ids:
+            continue
+
+        logger.info(f"Deleting {len(node_ids)} {node_type} embeddings from vector database")
+
+        try:
+            if await vector_engine.has_collection(collection_name):
+                await vector_engine.delete_data_points(
+                    collection_name,
+                    [str(node_id) for node_id in node_ids]
+                )
+        except Exception as e:
+            logger.error(f"Error deleting from vector collection {collection_name}: {e}")
 
     return deleted_counts
 

From 
5fac3b40b94e4c81a7d9828ca9d2d84ab5e82bc1 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Tue, 18 Nov 2025 22:26:59 +0530 Subject: [PATCH 024/176] fix: test file for cleanup unused data --- cognee/tests/test_cleanup_unused_data.py | 244 +++++++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 cognee/tests/test_cleanup_unused_data.py diff --git a/cognee/tests/test_cleanup_unused_data.py b/cognee/tests/test_cleanup_unused_data.py new file mode 100644 index 000000000..c21b9f5ea --- /dev/null +++ b/cognee/tests/test_cleanup_unused_data.py @@ -0,0 +1,244 @@ +import os +import pathlib +import cognee +from datetime import datetime, timezone, timedelta +from uuid import UUID +from sqlalchemy import select, update +from cognee.modules.data.models import Data, DatasetData +from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.users.methods import get_default_user +from cognee.shared.logging_utils import get_logger +from cognee.modules.search.types import SearchType + +logger = get_logger() + + +async def test_textdocument_cleanup_with_sql(): + """ + End-to-end test for TextDocument cleanup based on last_accessed timestamps. + + Tests: + 1. Add and cognify a document + 2. Perform search to populate last_accessed timestamp + 3. Verify last_accessed is set in SQL Data table + 4. Manually age the timestamp beyond cleanup threshold + 5. Run cleanup with text_doc=True + 6. 
Verify document was deleted from all databases (relational, graph, and vector) + """ + # Setup test directories + data_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_cleanup") + ).resolve() + ) + cognee_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_cleanup") + ).resolve() + ) + + cognee.config.data_root_directory(data_directory_path) + cognee.config.system_root_directory(cognee_directory_path) + + # Initialize database + from cognee.modules.engine.operations.setup import setup + + # Clean slate + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + logger.info("🧪 Testing TextDocument cleanup based on last_accessed") + + # Step 1: Add and cognify a test document + dataset_name = "test_cleanup_dataset" + test_text = """ + Machine learning is a subset of artificial intelligence that enables systems to learn + and improve from experience without being explicitly programmed. Deep learning uses + neural networks with multiple layers to process data. + """ + + await setup() + user = await get_default_user() + await cognee.add([test_text], dataset_name=dataset_name, user=user) + + cognify_result = await cognee.cognify([dataset_name], user=user) + + # Extract dataset_id from cognify result (ds_id is already a UUID) + dataset_id = None + for ds_id, pipeline_result in cognify_result.items(): + dataset_id = ds_id # Don't wrap in UUID() - it's already a UUID object + break + + assert dataset_id is not None, "Failed to get dataset_id from cognify result" + logger.info(f"✅ Document added and cognified. 
Dataset ID: {dataset_id}") + + # Step 2: Perform search to trigger last_accessed update + logger.info("Triggering search to update last_accessed...") + search_results = await cognee.search( + query_type=SearchType.CHUNKS, + query_text="machine learning", + datasets=[dataset_name], + user=user + ) + logger.info(f"✅ Search completed, found {len(search_results)} results") + + # Step 3: Verify last_accessed was set in SQL Data table + db_engine = get_relational_engine() + async with db_engine.get_async_session() as session: + # Get the Data record for this dataset + result = await session.execute( + select(Data, DatasetData) + .join(DatasetData, Data.id == DatasetData.data_id) + .where(DatasetData.dataset_id == dataset_id) + ) + data_records = result.all() + assert len(data_records) > 0, "No Data records found for the dataset" + data_record = data_records[0][0] + data_id = data_record.id + + # Verify last_accessed is set (should be set by search operation) + assert data_record.last_accessed is not None, ( + "last_accessed should be set after search operation" + ) + + original_last_accessed = data_record.last_accessed + logger.info(f"✅ last_accessed verified: {original_last_accessed}") + + # Step 4: Manually age the timestamp to be older than cleanup threshold + days_threshold = 30 + aged_timestamp = datetime.now(timezone.utc) - timedelta(days=days_threshold + 10) + + async with db_engine.get_async_session() as session: + stmt = update(Data).where(Data.id == data_id).values(last_accessed=aged_timestamp) + await session.execute(stmt) + await session.commit() + + # Query in a NEW session to avoid cached values + async with db_engine.get_async_session() as session: + result = await session.execute(select(Data).where(Data.id == data_id)) + updated_data = result.scalar_one_or_none() + + # Make both timezone-aware for comparison + retrieved_timestamp = updated_data.last_accessed + if retrieved_timestamp.tzinfo is None: + # If database returned naive datetime, make it 
UTC-aware + retrieved_timestamp = retrieved_timestamp.replace(tzinfo=timezone.utc) + + assert retrieved_timestamp == aged_timestamp, ( + f"Timestamp should be updated to aged value. " + f"Expected: {aged_timestamp}, Got: {retrieved_timestamp}" + ) + + # Step 5: Test cleanup with text_doc=True + from cognee.tasks.cleanup.cleanup_unused_data import cleanup_unused_data + + # First do a dry run + logger.info("Testing dry run with text_doc=True...") + dry_run_result = await cleanup_unused_data( + days_threshold=30, + dry_run=True, + user_id=user.id, + text_doc=True + ) + + assert dry_run_result['status'] == 'dry_run', "Status should be 'dry_run'" + assert dry_run_result['unused_count'] > 0, ( + "Should find at least one unused document" + ) + logger.info(f"✅ Dry run found {dry_run_result['unused_count']} unused documents") + + # Now run actual cleanup + logger.info("Executing cleanup with text_doc=True...") + cleanup_result = await cleanup_unused_data( + days_threshold=30, + dry_run=False, + user_id=user.id, + text_doc=True + ) + + assert cleanup_result["status"] == "completed", "Cleanup should complete successfully" + assert cleanup_result["deleted_count"]["documents"] > 0, ( + "At least one document should be deleted" + ) + logger.info(f"✅ Cleanup completed. 
Deleted {cleanup_result['deleted_count']['documents']} documents") + + # Step 6: Verify the document was actually deleted from SQL + async with db_engine.get_async_session() as session: + deleted_data = ( + await session.execute(select(Data).where(Data.id == data_id)) + ).scalar_one_or_none() + + assert deleted_data is None, ( + "Data record should be deleted after cleanup" + ) + logger.info("✅ Confirmed: Data record was deleted from SQL database") + + # Verify the dataset-data link was also removed + async with db_engine.get_async_session() as session: + dataset_data_link = ( + await session.execute( + select(DatasetData).where( + DatasetData.data_id == data_id, + DatasetData.dataset_id == dataset_id + ) + ) + ).scalar_one_or_none() + + assert dataset_data_link is None, ( + "DatasetData link should be deleted after cleanup" + ) + logger.info("✅ Confirmed: DatasetData link was deleted") + + # Verify graph nodes were cleaned up + from cognee.infrastructure.databases.graph import get_graph_engine + + graph_engine = await get_graph_engine() + + # Try to find the TextDocument node - it should not exist + result = await graph_engine.query( + "MATCH (n:Node {id: $id}) RETURN n", + {"id": str(data_id)} + ) + + assert len(result) == 0, ( + "TextDocument node should be deleted from graph database" + ) + logger.info("✅ Confirmed: TextDocument node was deleted from graph database") + + # Verify vector database was cleaned up + from cognee.infrastructure.databases.vector import get_vector_engine + + vector_engine = get_vector_engine() + + # Check each collection that should have been cleaned up + vector_collections = [ + "DocumentChunk_text", + "Entity_name", + "TextSummary_text" + ] + + for collection_name in vector_collections: + if await vector_engine.has_collection(collection_name): + # Try to retrieve the deleted data points + try: + results = await vector_engine.retrieve(collection_name, [str(data_id)]) + assert len(results) == 0, ( + f"Data points should be deleted from 
{collection_name} collection"
+                )
+                logger.info(f"✅ Confirmed: {collection_name} collection is clean")
+            except Exception as e:
+                # Collection might be empty or not exist, which is fine
+                logger.info(f"✅ Confirmed: {collection_name} collection is empty or doesn't exist")
+                pass
+
+    logger.info("✅ Confirmed: Vector database entries were deleted")
+
+    logger.info("🎉 All cleanup tests passed!")
+
+    return True
+
+
+if __name__ == "__main__":
+    import asyncio
+    success = asyncio.run(test_textdocument_cleanup_with_sql())
+    exit(0 if success else 1)

From a072773995734b8496b7e1a57a2a7abfbfe6faa7 Mon Sep 17 00:00:00 2001
From: apenade <166741079+apenade@users.noreply.github.com>
Date: Wed, 19 Nov 2025 16:02:27 +0100
Subject: [PATCH 025/176] Update
 alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
---
 alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py b/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py
index bffe61b46..814467954 100644
--- a/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py
+++ b/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py
@@ -1,4 +1,4 @@
-"""Add sync_operations table
+"""Add label column to data table
 
 Revision ID: a1b2c3d4e5f6
 Revises: 211ab850ef3d

From 43290af1b23d24d6ab8b5d57c243abe1cee8787e Mon Sep 17 00:00:00 2001
From: chinu0609
Date: Wed, 19 Nov 2025 21:00:16 +0530
Subject: [PATCH 026/176] fix: set last_accessed to current timestamp

---
 alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py b/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py
index 267e11fb2..a16c99e9f 100644
--- a/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py
+++ 
b/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py @@ -34,7 +34,7 @@ def upgrade() -> None: sa.Column('last_accessed', sa.DateTime(timezone=True), nullable=True) ) # Optionally initialize with created_at values for existing records - op.execute("UPDATE data SET last_accessed = created_at") + op.execute("UPDATE data SET last_accessed = CURRENT_TIMESTAMP") def downgrade() -> None: From b52c1a1e25e6edffe112462836ab315b36bec567 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Mon, 24 Nov 2025 12:50:39 +0530 Subject: [PATCH 027/176] fix: flag to enable and disable last_accessed --- .../e1ec1dcb50b6_add_last_accessed_to_data.py | 88 ++++++++++--------- .../retrieval/utils/access_tracking.py | 7 +- cognee/tasks/cleanup/cleanup_unused_data.py | 40 ++++++++- 3 files changed, 90 insertions(+), 45 deletions(-) diff --git a/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py b/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py index a16c99e9f..f1a36ae59 100644 --- a/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py +++ b/alembic/versions/e1ec1dcb50b6_add_last_accessed_to_data.py @@ -1,46 +1,52 @@ -"""add_last_accessed_to_data - -Revision ID: e1ec1dcb50b6 -Revises: 211ab850ef3d -Create Date: 2025-11-04 21:45:52.642322 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. 
-revision: str = 'e1ec1dcb50b6' -down_revision: Union[str, None] = '211ab850ef3d' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -def _get_column(inspector, table, name, schema=None): - for col in inspector.get_columns(table, schema=schema): - if col["name"] == name: - return col - return None +"""add_last_accessed_to_data + +Revision ID: e1ec1dcb50b6 +Revises: 211ab850ef3d +Create Date: 2025-11-04 21:45:52.642322 + +""" +import os +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa -def upgrade() -> None: - conn = op.get_bind() - insp = sa.inspect(conn) - - last_accessed_column = _get_column(insp, "data", "last_accessed") - if not last_accessed_column: - op.add_column('data', - sa.Column('last_accessed', sa.DateTime(timezone=True), nullable=True) - ) - # Optionally initialize with created_at values for existing records - op.execute("UPDATE data SET last_accessed = CURRENT_TIMESTAMP") +# revision identifiers, used by Alembic. 
+revision: str = 'e1ec1dcb50b6' +down_revision: Union[str, None] = '211ab850ef3d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None -def downgrade() -> None: - conn = op.get_bind() - insp = sa.inspect(conn) - - last_accessed_column = _get_column(insp, "data", "last_accessed") - if last_accessed_column: +def _get_column(inspector, table, name, schema=None): + for col in inspector.get_columns(table, schema=schema): + if col["name"] == name: + return col + return None + + +def upgrade() -> None: + conn = op.get_bind() + insp = sa.inspect(conn) + + last_accessed_column = _get_column(insp, "data", "last_accessed") + if not last_accessed_column: + # Always create the column for schema consistency + op.add_column('data', + sa.Column('last_accessed', sa.DateTime(timezone=True), nullable=True) + ) + + # Only initialize existing records if feature is enabled + enable_last_accessed = os.getenv("ENABLE_LAST_ACCESSED", "false").lower() == "true" + if enable_last_accessed: + op.execute("UPDATE data SET last_accessed = CURRENT_TIMESTAMP") + + +def downgrade() -> None: + conn = op.get_bind() + insp = sa.inspect(conn) + + last_accessed_column = _get_column(insp, "data", "last_accessed") + if last_accessed_column: op.drop_column('data', 'last_accessed') diff --git a/cognee/modules/retrieval/utils/access_tracking.py b/cognee/modules/retrieval/utils/access_tracking.py index 65d597a93..6df0284ec 100644 --- a/cognee/modules/retrieval/utils/access_tracking.py +++ b/cognee/modules/retrieval/utils/access_tracking.py @@ -4,7 +4,7 @@ import json from datetime import datetime, timezone from typing import List, Any from uuid import UUID - +import os from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.relational import get_relational_engine from cognee.modules.data.models import Data @@ -27,7 +27,10 @@ async def update_node_access_timestamps(items: List[Any]): ---------- items : 
List[Any] List of items with payload containing 'id' field (from vector search results) - """ + """ + if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true": + return + if not items: return diff --git a/cognee/tasks/cleanup/cleanup_unused_data.py b/cognee/tasks/cleanup/cleanup_unused_data.py index fd4b68204..175452a0a 100644 --- a/cognee/tasks/cleanup/cleanup_unused_data.py +++ b/cognee/tasks/cleanup/cleanup_unused_data.py @@ -10,7 +10,7 @@ import json from datetime import datetime, timezone, timedelta from typing import Optional, Dict, Any from uuid import UUID - +import os from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.relational import get_relational_engine @@ -47,7 +47,43 @@ async def cleanup_unused_data( ------- Dict[str, Any] Cleanup results with status, counts, and timestamp - """ + """ + # Check 1: Environment variable must be enabled + if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true": + logger.warning( + "Cleanup skipped: ENABLE_LAST_ACCESSED is not enabled." + ) + return { + "status": "skipped", + "reason": "ENABLE_LAST_ACCESSED not enabled", + "unused_count": 0, + "deleted_count": {}, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } + + # Check 2: Verify tracking has actually been running + db_engine = get_relational_engine() + async with db_engine.get_async_session() as session: + # Count records with non-NULL last_accessed + tracked_count = await session.execute( + select(sa.func.count(Data.id)).where(Data.last_accessed.isnot(None)) + ) + tracked_records = tracked_count.scalar() + + if tracked_records == 0: + logger.warning( + "Cleanup skipped: No records have been tracked yet. " + "ENABLE_LAST_ACCESSED may have been recently enabled. " + "Wait for retrievers to update timestamps before running cleanup." 
+ ) + return { + "status": "skipped", + "reason": "No tracked records found - tracking may be newly enabled", + "unused_count": 0, + "deleted_count": {}, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } + logger.info( "Starting cleanup task", days_threshold=days_threshold, From 5cb6510205742e7a5abf2afe23d2527b229931d0 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Mon, 24 Nov 2025 13:12:46 +0530 Subject: [PATCH 028/176] fix: import --- cognee/tasks/cleanup/cleanup_unused_data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cognee/tasks/cleanup/cleanup_unused_data.py b/cognee/tasks/cleanup/cleanup_unused_data.py index 175452a0a..a90d96b5c 100644 --- a/cognee/tasks/cleanup/cleanup_unused_data.py +++ b/cognee/tasks/cleanup/cleanup_unused_data.py @@ -18,6 +18,7 @@ from cognee.modules.data.models import Data, DatasetData from cognee.shared.logging_utils import get_logger from sqlalchemy import select, or_ import cognee +import sqlalchemy as sa logger = get_logger(__name__) From 02b17786588b7be4c582a2fdfe93a5412c074cda Mon Sep 17 00:00:00 2001 From: rajeevrajeshuni Date: Tue, 25 Nov 2025 12:22:15 +0530 Subject: [PATCH 029/176] Adding support for audio/image transcription for all other providers --- cognee/infrastructure/llm/LLMGateway.py | 13 -- .../llm/anthropic/adapter.py | 27 ++-- .../litellm_instructor/llm/gemini/adapter.py | 59 +++---- .../llm/generic_llm_api/adapter.py | 132 +++++++++++++++- .../litellm_instructor/llm/get_llm_client.py | 13 +- .../litellm_instructor/llm/llm_interface.py | 47 +++++- .../litellm_instructor/llm/mistral/adapter.py | 66 ++++++-- .../litellm_instructor/llm/ollama/adapter.py | 118 ++------------ .../litellm_instructor/llm/openai/adapter.py | 148 +++--------------- uv.lock | 2 +- 10 files changed, 313 insertions(+), 312 deletions(-) diff --git a/cognee/infrastructure/llm/LLMGateway.py b/cognee/infrastructure/llm/LLMGateway.py index ab5bb35d7..66a364110 100644 --- a/cognee/infrastructure/llm/LLMGateway.py +++ 
b/cognee/infrastructure/llm/LLMGateway.py @@ -34,19 +34,6 @@ class LLMGateway: text_input=text_input, system_prompt=system_prompt, response_model=response_model ) - @staticmethod - def create_structured_output( - text_input: str, system_prompt: str, response_model: Type[BaseModel] - ) -> BaseModel: - from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import ( - get_llm_client, - ) - - llm_client = get_llm_client() - return llm_client.create_structured_output( - text_input=text_input, system_prompt=system_prompt, response_model=response_model - ) - @staticmethod def create_transcript(input) -> Coroutine: from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import ( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py index dbf0dfbea..818d3adb7 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py @@ -3,7 +3,9 @@ from typing import Type from pydantic import BaseModel import litellm import instructor +import anthropic from cognee.shared.logging_utils import get_logger +from cognee.modules.observability.get_observe import get_observe from tenacity import ( retry, stop_after_delay, @@ -12,27 +14,32 @@ from tenacity import ( before_sleep_log, ) -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, ) from cognee.infrastructure.llm.config import get_llm_config logger = get_logger() +observe = get_observe() -class AnthropicAdapter(LLMInterface): +class 
AnthropicAdapter(GenericAPIAdapter): """ Adapter for interfacing with the Anthropic API, enabling structured output generation and prompt display. """ - name = "Anthropic" - model: str default_instructor_mode = "anthropic_tools" - def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None): - import anthropic - + def __init__( + self, api_key: str, model: str, max_completion_tokens: int, instructor_mode: str = None + ): + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Anthropic", + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.patch( @@ -40,9 +47,7 @@ class AnthropicAdapter(LLMInterface): mode=instructor.Mode(self.instructor_mode), ) - self.model = model - self.max_completion_tokens = max_completion_tokens - + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py index 226f291d7..bae665052 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py @@ -1,4 +1,4 @@ -"""Adapter for Generic API LLM provider API""" +"""Adapter for Gemini API LLM provider""" import litellm import instructor @@ -8,12 +8,7 @@ from openai import ContentFilterFinishReasonError from litellm.exceptions import ContentPolicyViolationError from instructor.core import InstructorRetryException -from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, -) import logging -from 
cognee.shared.logging_utils import get_logger from tenacity import ( retry, stop_after_delay, @@ -22,55 +17,65 @@ from tenacity import ( before_sleep_log, ) +from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, +) +from cognee.shared.logging_utils import get_logger +from cognee.modules.observability.get_observe import get_observe + logger = get_logger() +observe = get_observe() -class GeminiAdapter(LLMInterface): +class GeminiAdapter(GenericAPIAdapter): """ Adapter for Gemini API LLM provider. This class initializes the API adapter with necessary credentials and configurations for interacting with the gemini LLM models. It provides methods for creating structured outputs - based on user input and system prompts. + based on user input and system prompts, as well as multimodal processing capabilities. Public methods: - - acreate_structured_output(text_input: str, system_prompt: str, response_model: - Type[BaseModel]) -> BaseModel + - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel + - create_transcript(input) -> BaseModel: Transcribe audio files to text + - transcribe_image(input) -> BaseModel: Inherited from GenericAPIAdapter """ - name: str - model: str - api_key: str default_instructor_mode = "json_mode" def __init__( self, - endpoint, api_key: str, model: str, - api_version: str, max_completion_tokens: int, + endpoint: str = None, + api_version: str = None, + transcription_model: str = None, instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, ): - self.model = model - self.api_key = api_key - self.endpoint = endpoint - self.api_version = api_version - self.max_completion_tokens = max_completion_tokens - - self.fallback_model = fallback_model - self.fallback_api_key = fallback_api_key 
- self.fallback_endpoint = fallback_endpoint - + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Gemini", + endpoint=endpoint, + api_version=api_version, + transcription_model=transcription_model, + fallback_model=fallback_model, + fallback_api_key=fallback_api_key, + fallback_endpoint=fallback_endpoint, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.from_litellm( litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -118,7 +123,7 @@ class GeminiAdapter(LLMInterface): }, ], api_key=self.api_key, - max_retries=5, + max_retries=self.MAX_RETRIES, api_base=self.endpoint, api_version=self.api_version, response_model=response_model, @@ -152,7 +157,7 @@ class GeminiAdapter(LLMInterface): "content": system_prompt, }, ], - max_retries=5, + max_retries=self.MAX_RETRIES, api_key=self.fallback_api_key, api_base=self.fallback_endpoint, response_model=response_model, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index 9d7f25fc5..9987711b9 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -1,8 +1,10 @@ """Adapter for Generic API LLM provider API""" +import base64 +import mimetypes import litellm import instructor -from typing import Type +from typing import Type, Optional from pydantic import BaseModel from openai import ContentFilterFinishReasonError from litellm.exceptions import ContentPolicyViolationError @@ -12,6 +14,8 @@ from cognee.infrastructure.llm.exceptions import 
ContentPolicyFilterError from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( LLMInterface, ) +from cognee.infrastructure.files.utils.open_data_file import open_data_file +from cognee.modules.observability.get_observe import get_observe import logging from cognee.shared.logging_utils import get_logger from tenacity import ( @@ -23,6 +27,7 @@ from tenacity import ( ) logger = get_logger() +observe = get_observe() class GenericAPIAdapter(LLMInterface): @@ -38,18 +43,19 @@ class GenericAPIAdapter(LLMInterface): Type[BaseModel]) -> BaseModel """ - name: str - model: str - api_key: str + MAX_RETRIES = 5 default_instructor_mode = "json_mode" def __init__( self, - endpoint, api_key: str, model: str, - name: str, max_completion_tokens: int, + name: str, + endpoint: str = None, + api_version: str = None, + transcription_model: str = None, + image_transcribe_model: str = None, instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, @@ -58,9 +64,11 @@ class GenericAPIAdapter(LLMInterface): self.name = name self.model = model self.api_key = api_key + self.api_version = api_version self.endpoint = endpoint self.max_completion_tokens = max_completion_tokens - + self.transcription_model = transcription_model or model + self.image_transcribe_model = image_transcribe_model or model self.fallback_model = fallback_model self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint @@ -71,6 +79,7 @@ class GenericAPIAdapter(LLMInterface): litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -170,3 +179,112 @@ class GenericAPIAdapter(LLMInterface): raise ContentPolicyFilterError( f"The provided input contains content that is not aligned with our content policy: {text_input}" ) from error + + @observe(as_type="transcription") + @retry( + 
stop=stop_after_delay(128), + wait=wait_exponential_jitter(2, 128), + retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, + ) + async def create_transcript(self, input) -> Optional[BaseModel]: + """ + Generate an audio transcript from a user query. + + This method creates a transcript from the specified audio file, raising a + FileNotFoundError if the file does not exist. The audio file is processed and the + transcription is retrieved from the API. + + Parameters: + ----------- + - input: The path to the audio file that needs to be transcribed. + + Returns: + -------- + The generated transcription of the audio file. + """ + async with open_data_file(input, mode="rb") as audio_file: + encoded_string = base64.b64encode(audio_file.read()).decode("utf-8") + mime_type, _ = mimetypes.guess_type(input) + if not mime_type or not mime_type.startswith("audio/"): + raise ValueError( + f"Could not determine MIME type for audio file: {input}. Is the extension correct?" + ) + return litellm.completion( + model=self.transcription_model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "file", + "file": {"file_data": f"data:{mime_type};base64,{encoded_string}"}, + }, + {"type": "text", "text": "Transcribe the following audio precisely."}, + ], + } + ], + api_key=self.api_key, + api_version=self.api_version, + max_completion_tokens=self.max_completion_tokens, + api_base=self.endpoint, + max_retries=self.MAX_RETRIES, + ) + + @observe(as_type="transcribe_image") + @retry( + stop=stop_after_delay(128), + wait=wait_exponential_jitter(2, 128), + retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, + ) + async def transcribe_image(self, input) -> Optional[BaseModel]: + """ + Generate a transcription of an image from a user query. 
+ + This method encodes the image and sends a request to the API to obtain a + description of the contents of the image. + + Parameters: + ----------- + - input: The path to the image file that needs to be transcribed. + + Returns: + -------- + - BaseModel: A structured output generated by the model, returned as an instance of + BaseModel. + """ + async with open_data_file(input, mode="rb") as image_file: + encoded_image = base64.b64encode(image_file.read()).decode("utf-8") + mime_type, _ = mimetypes.guess_type(input) + if not mime_type or not mime_type.startswith("image/"): + raise ValueError( + f"Could not determine MIME type for image file: {input}. Is the extension correct?" + ) + return litellm.completion( + model=self.image_transcribe_model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this image?", + }, + { + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{encoded_image}", + }, + }, + ], + } + ], + api_key=self.api_key, + api_base=self.endpoint, + api_version=self.api_version, + max_completion_tokens=300, + max_retries=self.MAX_RETRIES, + ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index 39558f36d..de6cfaf19 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -97,11 +97,10 @@ def get_llm_client(raise_api_key_error: bool = True): ) return OllamaAPIAdapter( - llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, - "Ollama", - max_completion_tokens=max_completion_tokens, + max_completion_tokens, + llm_config.llm_endpoint, instructor_mode=llm_config.llm_instructor_mode.lower(), ) @@ -111,8 +110,9 @@ def get_llm_client(raise_api_key_error: bool = True): ) 
return AnthropicAdapter( - max_completion_tokens=max_completion_tokens, - model=llm_config.llm_model, + llm_config.llm_api_key, + llm_config.llm_model, + max_completion_tokens, instructor_mode=llm_config.llm_instructor_mode.lower(), ) @@ -125,11 +125,10 @@ def get_llm_client(raise_api_key_error: bool = True): ) return GenericAPIAdapter( - llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, + max_completion_tokens, "Custom", - max_completion_tokens=max_completion_tokens, instructor_mode=llm_config.llm_instructor_mode.lower(), fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py index b02105484..f8352737d 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py @@ -1,6 +1,6 @@ """LLM Interface""" -from typing import Type, Protocol +from typing import Type, Protocol, Optional from abc import abstractmethod from pydantic import BaseModel from cognee.infrastructure.llm.LLMGateway import LLMGateway @@ -8,13 +8,12 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway class LLMInterface(Protocol): """ - Define an interface for LLM models with methods for structured output and prompt - display. + Define an interface for LLM models with methods for structured output, multimodal processing, and prompt display. 
Methods: - - acreate_structured_output(text_input: str, system_prompt: str, response_model: - Type[BaseModel]) - - show_prompt(text_input: str, system_prompt: str) + - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) + - create_transcript(input): Transcribe audio files to text + - transcribe_image(input): Analyze image files and return text description """ @abstractmethod @@ -36,3 +35,39 @@ class LLMInterface(Protocol): output. """ raise NotImplementedError + + @abstractmethod + async def create_transcript(self, input) -> Optional[BaseModel]: + """ + Transcribe audio content to text. + + This method should be implemented by subclasses that support audio transcription. + If not implemented, returns None and should be handled gracefully by callers. + + Parameters: + ----------- + - input: The path to the audio file that needs to be transcribed. + + Returns: + -------- + - BaseModel: A structured output containing the transcription, or None if not supported. + """ + raise NotImplementedError + + @abstractmethod + async def transcribe_image(self, input) -> Optional[BaseModel]: + """ + Analyze image content and return text description. + + This method should be implemented by subclasses that support image analysis. + If not implemented, returns None and should be handled gracefully by callers. + + Parameters: + ----------- + - input: The path to the image file that needs to be analyzed. + + Returns: + -------- + - BaseModel: A structured output containing the image description, or None if not supported. 
+ """ + raise NotImplementedError diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py index 355cdae0b..0fa35923f 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -2,12 +2,12 @@ import litellm import instructor from pydantic import BaseModel from typing import Type -from litellm import JSONSchemaValidationError +from litellm import JSONSchemaValidationError, transcription from cognee.shared.logging_utils import get_logger from cognee.modules.observability.get_observe import get_observe -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, ) from cognee.infrastructure.llm.config import get_llm_config @@ -19,12 +19,13 @@ from tenacity import ( retry_if_not_exception_type, before_sleep_log, ) +from mistralai import Mistral logger = get_logger() observe = get_observe() -class MistralAdapter(LLMInterface): +class MistralAdapter(GenericAPIAdapter): """ Adapter for Mistral AI API, for structured output generation and prompt display. 
@@ -33,10 +34,6 @@ class MistralAdapter(LLMInterface): - show_prompt """ - name = "Mistral" - model: str - api_key: str - max_completion_tokens: int default_instructor_mode = "mistral_tools" def __init__( @@ -45,12 +42,21 @@ class MistralAdapter(LLMInterface): model: str, max_completion_tokens: int, endpoint: str = None, + transcription_model: str = None, + image_transcribe_model: str = None, instructor_mode: str = None, ): from mistralai import Mistral - self.model = model - self.max_completion_tokens = max_completion_tokens + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Mistral", + endpoint=endpoint, + transcription_model=transcription_model, + image_transcribe_model=image_transcribe_model, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode @@ -60,6 +66,7 @@ class MistralAdapter(LLMInterface): api_key=get_llm_config().llm_api_key, ) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -117,3 +124,42 @@ class MistralAdapter(LLMInterface): logger.error(f"Schema validation failed: {str(e)}") logger.debug(f"Raw response: {e.raw_response}") raise ValueError(f"Response failed schema validation: {str(e)}") + + @observe(as_type="transcription") + @retry( + stop=stop_after_delay(128), + wait=wait_exponential_jitter(2, 128), + retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, + ) + async def create_transcript(self, input): + """ + Generate an audio transcript from a user query. + + This method creates a transcript from the specified audio file. + The audio file is processed and the transcription is retrieved from the API. + + Parameters: + ----------- + - input: The path to the audio file that needs to be transcribed. + + Returns: + -------- + The generated transcription of the audio file. 
+ """ + transcription_model = self.transcription_model + if self.transcription_model.startswith("mistral"): + transcription_model = self.transcription_model.split("/")[-1] + file_name = input.split("/")[-1] + client = Mistral(api_key=self.api_key) + with open(input, "rb") as f: + transcription_response = client.audio.transcriptions.complete( + model=transcription_model, + file={ + "content": f, + "file_name": file_name, + }, + ) + # TODO: We need to standardize return type of create_transcript across different models. + return transcription_response diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py index aabd19867..163637a95 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py @@ -5,12 +5,12 @@ import instructor from typing import Type from openai import OpenAI from pydantic import BaseModel - -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, -) from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.shared.logging_utils import get_logger +from cognee.modules.observability.get_observe import get_observe +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, +) from tenacity import ( retry, stop_after_delay, @@ -20,9 +20,10 @@ from tenacity import ( ) logger = get_logger() +observe = get_observe() -class OllamaAPIAdapter(LLMInterface): +class OllamaAPIAdapter(GenericAPIAdapter): """ Adapter for a Generic API LLM provider using instructor with an OpenAI backend. 
@@ -46,18 +47,20 @@ class OllamaAPIAdapter(LLMInterface): def __init__( self, - endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int, + endpoint: str, instructor_mode: str = None, ): - self.name = name - self.model = model - self.api_key = api_key - self.endpoint = endpoint - self.max_completion_tokens = max_completion_tokens + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Ollama", + endpoint=endpoint, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode @@ -66,6 +69,7 @@ class OllamaAPIAdapter(LLMInterface): mode=instructor.Mode(self.instructor_mode), ) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -113,95 +117,3 @@ class OllamaAPIAdapter(LLMInterface): ) return response - - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - async def create_transcript(self, input_file: str) -> str: - """ - Generate an audio transcript from a user query. - - This synchronous method takes an input audio file and returns its transcription. Raises - a FileNotFoundError if the input file does not exist, and raises a ValueError if - transcription fails or returns no text. - - Parameters: - ----------- - - - input_file (str): The path to the audio file to be transcribed. - - Returns: - -------- - - - str: The transcription of the audio as a string. - """ - - async with open_data_file(input_file, mode="rb") as audio_file: - transcription = self.aclient.audio.transcriptions.create( - model="whisper-1", # Ensure the correct model for transcription - file=audio_file, - language="en", - ) - - # Ensure the response contains a valid transcript - if not hasattr(transcription, "text"): - raise ValueError("Transcription failed. 
No text returned.") - - return transcription.text - - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - async def transcribe_image(self, input_file: str) -> str: - """ - Transcribe content from an image using base64 encoding. - - This synchronous method takes an input image file, encodes it as base64, and returns the - transcription of its content. Raises a FileNotFoundError if the input file does not - exist, and raises a ValueError if the transcription fails or no valid response is - received. - - Parameters: - ----------- - - - input_file (str): The path to the image file to be transcribed. - - Returns: - -------- - - - str: The transcription of the image's content as a string. - """ - - async with open_data_file(input_file, mode="rb") as image_file: - encoded_image = base64.b64encode(image_file.read()).decode("utf-8") - - response = self.aclient.chat.completions.create( - model=self.model, - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}, - }, - ], - } - ], - max_completion_tokens=300, - ) - - # Ensure response is valid before accessing .choices[0].message.content - if not hasattr(response, "choices") or not response.choices: - raise ValueError("Image transcription failed. 
No response received.") - - return response.choices[0].message.content diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index 778c8eec7..e9943c335 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -1,4 +1,3 @@ -import base64 import litellm import instructor from typing import Type @@ -16,8 +15,8 @@ from tenacity import ( before_sleep_log, ) -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, ) from cognee.infrastructure.llm.exceptions import ( ContentPolicyFilterError, @@ -31,7 +30,7 @@ logger = get_logger() observe = get_observe() -class OpenAIAdapter(LLMInterface): +class OpenAIAdapter(GenericAPIAdapter): """ Adapter for OpenAI's GPT-3, GPT-4 API. 
@@ -52,12 +51,7 @@ class OpenAIAdapter(LLMInterface): - MAX_RETRIES """ - name = "OpenAI" - model: str - api_key: str - api_version: str default_instructor_mode = "json_schema_mode" - MAX_RETRIES = 5 """Adapter for OpenAI's GPT-3, GPT=4 API""" @@ -65,17 +59,29 @@ class OpenAIAdapter(LLMInterface): def __init__( self, api_key: str, - endpoint: str, - api_version: str, model: str, - transcription_model: str, max_completion_tokens: int, + endpoint: str = None, + api_version: str = None, + transcription_model: str = None, instructor_mode: str = None, streaming: bool = False, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, ): + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="OpenAI", + endpoint=endpoint, + api_version=api_version, + transcription_model=transcription_model, + fallback_model=fallback_model, + fallback_api_key=fallback_api_key, + fallback_endpoint=fallback_endpoint, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs. # Make sure all new gpt models will work with this mode as well. 
@@ -90,18 +96,8 @@ class OpenAIAdapter(LLMInterface): self.aclient = instructor.from_litellm(litellm.acompletion) self.client = instructor.from_litellm(litellm.completion) - self.transcription_model = transcription_model - self.model = model - self.api_key = api_key - self.endpoint = endpoint - self.api_version = api_version - self.max_completion_tokens = max_completion_tokens self.streaming = streaming - self.fallback_model = fallback_model - self.fallback_api_key = fallback_api_key - self.fallback_endpoint = fallback_endpoint - @observe(as_type="generation") @retry( stop=stop_after_delay(128), @@ -174,7 +170,7 @@ class OpenAIAdapter(LLMInterface): }, ], api_key=self.fallback_api_key, - # api_base=self.fallback_endpoint, + api_base=self.fallback_endpoint, response_model=response_model, max_retries=self.MAX_RETRIES, ) @@ -193,57 +189,7 @@ class OpenAIAdapter(LLMInterface): f"The provided input contains content that is not aligned with our content policy: {text_input}" ) from error - @observe - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - def create_structured_output( - self, text_input: str, system_prompt: str, response_model: Type[BaseModel] - ) -> BaseModel: - """ - Generate a response from a user query. - - This method creates structured output by sending a synchronous request to the OpenAI API - using the provided parameters to generate a completion based on the user input and - system prompt. - - Parameters: - ----------- - - - text_input (str): The input text provided by the user for generating a response. - - system_prompt (str): The system's prompt to guide the model's response. - - response_model (Type[BaseModel]): The expected model type for the response. 
- - Returns: - -------- - - - BaseModel: A structured output generated by the model, returned as an instance of - BaseModel. - """ - - return self.client.chat.completions.create( - model=self.model, - messages=[ - { - "role": "user", - "content": f"""{text_input}""", - }, - { - "role": "system", - "content": system_prompt, - }, - ], - api_key=self.api_key, - api_base=self.endpoint, - api_version=self.api_version, - response_model=response_model, - max_retries=self.MAX_RETRIES, - ) - + @observe(as_type="transcription") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -282,56 +228,4 @@ class OpenAIAdapter(LLMInterface): return transcription - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - async def transcribe_image(self, input) -> BaseModel: - """ - Generate a transcription of an image from a user query. - - This method encodes the image and sends a request to the OpenAI API to obtain a - description of the contents of the image. - - Parameters: - ----------- - - - input: The path to the image file that needs to be transcribed. - - Returns: - -------- - - - BaseModel: A structured output generated by the model, returned as an instance of - BaseModel. 
- """ - async with open_data_file(input, mode="rb") as image_file: - encoded_image = base64.b64encode(image_file.read()).decode("utf-8") - - return litellm.completion( - model=self.model, - messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What's in this image?", - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{encoded_image}", - }, - }, - ], - } - ], - api_key=self.api_key, - api_base=self.endpoint, - api_version=self.api_version, - max_completion_tokens=300, - max_retries=self.MAX_RETRIES, - ) + # transcribe image inherited from GenericAdapter diff --git a/uv.lock b/uv.lock index cc66c3d7e..d8fb3805b 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10, <3.14" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'", From 09fbf2276828043b8ed1458f50b3ab7efcaa04d2 Mon Sep 17 00:00:00 2001 From: rajeevrajeshuni Date: Tue, 25 Nov 2025 12:24:30 +0530 Subject: [PATCH 030/176] uv lock version revert --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index d8fb3805b..cc66c3d7e 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10, <3.14" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'", From 12ce80005ceccafac38a63da458e6df376776b61 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Wed, 26 Nov 2025 17:32:50 +0530 Subject: [PATCH 031/176] fix: generalized queries --- .../retrieval/utils/access_tracking.py | 147 ++-- cognee/tasks/cleanup/cleanup_unused_data.py | 778 ++++++++++-------- 2 files changed, 516 insertions(+), 409 deletions(-) diff --git a/cognee/modules/retrieval/utils/access_tracking.py b/cognee/modules/retrieval/utils/access_tracking.py index 6df0284ec..12a66f8bc 100644 --- 
a/cognee/modules/retrieval/utils/access_tracking.py +++ b/cognee/modules/retrieval/utils/access_tracking.py @@ -13,24 +13,10 @@ from sqlalchemy import update logger = get_logger(__name__) - async def update_node_access_timestamps(items: List[Any]): - """ - Update last_accessed_at for nodes in graph database and corresponding Data records in SQL. - - This function: - 1. Updates last_accessed_at in the graph database nodes (in properties JSON) - 2. Traverses to find origin TextDocument nodes (without hardcoded relationship names) - 3. Updates last_accessed in the SQL Data table for those documents - - Parameters - ---------- - items : List[Any] - List of items with payload containing 'id' field (from vector search results) - """ if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true": return - + if not items: return @@ -49,50 +35,95 @@ async def update_node_access_timestamps(items: List[Any]): return try: - # Step 1: Batch update graph nodes - for node_id in node_ids: - result = await graph_engine.query( - "MATCH (n:Node {id: $id}) RETURN n.properties", - {"id": node_id} - ) + # Detect database provider and use appropriate queries + provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower() + + if provider == "kuzu": + await _update_kuzu_nodes(graph_engine, node_ids, timestamp_ms) + elif provider == "neo4j": + await _update_neo4j_nodes(graph_engine, node_ids, timestamp_ms) + elif provider == "neptune": + await _update_neptune_nodes(graph_engine, node_ids, timestamp_ms) + else: + logger.warning(f"Unsupported graph provider: {provider}") + return - if result and result[0]: - props = json.loads(result[0][0]) if result[0][0] else {} - props["last_accessed_at"] = timestamp_ms - - await graph_engine.query( - "MATCH (n:Node {id: $id}) SET n.properties = $props", - {"id": node_id, "props": json.dumps(props)} - ) - - logger.debug(f"Updated access timestamps for {len(node_ids)} graph nodes") - - # Step 2: Find origin TextDocument nodes (without hardcoded 
relationship names) - origin_query = """ - UNWIND $node_ids AS node_id - MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node) - WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document'] - RETURN DISTINCT doc.id - """ - - result = await graph_engine.query(origin_query, {"node_ids": node_ids}) - - # Extract and deduplicate document IDs - doc_ids = list(set([row[0] for row in result if row and row[0]])) if result else [] - - # Step 3: Update SQL Data table + # Find origin documents and update SQL + doc_ids = await _find_origin_documents(graph_engine, node_ids, provider) if doc_ids: - db_engine = get_relational_engine() - async with db_engine.get_async_session() as session: - stmt = update(Data).where( - Data.id.in_([UUID(doc_id) for doc_id in doc_ids]) - ).values(last_accessed=timestamp_dt) - - await session.execute(stmt) - await session.commit() - - logger.debug(f"Updated last_accessed for {len(doc_ids)} Data records in SQL") - + await _update_sql_records(doc_ids, timestamp_dt) + except Exception as e: logger.error(f"Failed to update timestamps: {e}") - raise + raise + +async def _update_kuzu_nodes(graph_engine, node_ids, timestamp_ms): + """Kuzu-specific node updates""" + for node_id in node_ids: + result = await graph_engine.query( + "MATCH (n:Node {id: $id}) RETURN n.properties", + {"id": node_id} + ) + + if result and result[0]: + props = json.loads(result[0][0]) if result[0][0] else {} + props["last_accessed_at"] = timestamp_ms + + await graph_engine.query( + "MATCH (n:Node {id: $id}) SET n.properties = $props", + {"id": node_id, "props": json.dumps(props)} + ) + +async def _update_neo4j_nodes(graph_engine, node_ids, timestamp_ms): + """Neo4j-specific node updates""" + for node_id in node_ids: + await graph_engine.query( + "MATCH (n:__Node__ {id: $id}) SET n.last_accessed_at = $timestamp", + {"id": node_id, "timestamp": timestamp_ms} + ) + +async def _update_neptune_nodes(graph_engine, node_ids, timestamp_ms): + """Neptune-specific node 
updates""" + for node_id in node_ids: + await graph_engine.query( + "MATCH (n:Node {id: $id}) SET n.last_accessed_at = $timestamp", + {"id": node_id, "timestamp": timestamp_ms} + ) + +async def _find_origin_documents(graph_engine, node_ids, provider): + """Find origin documents with provider-specific queries""" + if provider == "kuzu": + query = """ + UNWIND $node_ids AS node_id + MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node) + WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document'] + RETURN DISTINCT doc.id + """ + elif provider == "neo4j": + query = """ + UNWIND $node_ids AS node_id + MATCH (chunk:__Node__ {id: node_id})-[e:EDGE]-(doc:__Node__) + WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document'] + RETURN DISTINCT doc.id + """ + elif provider == "neptune": + query = """ + UNWIND $node_ids AS node_id + MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node) + WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document'] + RETURN DISTINCT doc.id + """ + + result = await graph_engine.query(query, {"node_ids": node_ids}) + return list(set([row[0] for row in result if row and row[0]])) if result else [] + +async def _update_sql_records(doc_ids, timestamp_dt): + """Update SQL Data table (same for all providers)""" + db_engine = get_relational_engine() + async with db_engine.get_async_session() as session: + stmt = update(Data).where( + Data.id.in_([UUID(doc_id) for doc_id in doc_ids]) + ).values(last_accessed=timestamp_dt) + + await session.execute(stmt) + await session.commit() diff --git a/cognee/tasks/cleanup/cleanup_unused_data.py b/cognee/tasks/cleanup/cleanup_unused_data.py index a90d96b5c..b89c939a8 100644 --- a/cognee/tasks/cleanup/cleanup_unused_data.py +++ b/cognee/tasks/cleanup/cleanup_unused_data.py @@ -1,372 +1,448 @@ -""" -Task for automatically deleting unused data from the memify pipeline. 
- -This task identifies and removes data (chunks, entities, summaries) that hasn't -been accessed by retrievers for a specified period, helping maintain system -efficiency and storage optimization. -""" - -import json -from datetime import datetime, timezone, timedelta -from typing import Optional, Dict, Any -from uuid import UUID -import os -from cognee.infrastructure.databases.graph import get_graph_engine -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.infrastructure.databases.relational import get_relational_engine -from cognee.modules.data.models import Data, DatasetData -from cognee.shared.logging_utils import get_logger -from sqlalchemy import select, or_ -import cognee -import sqlalchemy as sa - -logger = get_logger(__name__) +""" +Task for automatically deleting unused data from the memify pipeline. + +This task identifies and removes data (chunks, entities, summaries)) that hasn't +been accessed by retrievers for a specified period, helping maintain system +efficiency and storage optimization. +""" + +import json +from datetime import datetime, timezone, timedelta +from typing import Optional, Dict, Any +from uuid import UUID +import os +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.data.models import Data, DatasetData +from cognee.shared.logging_utils import get_logger +from sqlalchemy import select, or_ +import cognee +import sqlalchemy as sa + +logger = get_logger(__name__) + + +async def cleanup_unused_data( + minutes_threshold: Optional[int], + dry_run: bool = True, + user_id: Optional[UUID] = None, + text_doc: bool = False +) -> Dict[str, Any]: + """ + Identify and remove unused data from the memify pipeline. 
+ + Parameters + ---------- + minutes_threshold : int + days since last access to consider data unused + dry_run : bool + If True, only report what would be delete without actually deleting (default: True) + user_id : UUID, optional + Limit cleanup to specific user's data (default: None) + text_doc : bool + If True, use SQL-based filtering to find unused TextDocuments and call cognee.delete() + for proper whole-document deletion (default: False) + + Returns + ------- + Dict[str, Any] + Cleanup results with status, counts, and timestamp + """ + # Check 1: Environment variable must be enabled + if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true": + logger.warning( + "Cleanup skipped: ENABLE_LAST_ACCESSED is not enabled." + ) + return { + "status": "skipped", + "reason": "ENABLE_LAST_ACCESSED not enabled", + "unused_count": 0, + "deleted_count": {}, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } + + # Check 2: Verify tracking has actually been running + db_engine = get_relational_engine() + async with db_engine.get_async_session() as session: + # Count records with non-NULL last_accessed + tracked_count = await session.execute( + select(sa.func.count(Data.id)).where(Data.last_accessed.isnot(None)) + ) + tracked_records = tracked_count.scalar() + + if tracked_records == 0: + logger.warning( + "Cleanup skipped: No records have been tracked yet. " + "ENABLE_LAST_ACCESSED may have been recently enabled. " + "Wait for retrievers to update timestamps before running cleanup." 
+ ) + return { + "status": "skipped", + "reason": "No tracked records found - tracking may be newly enabled", + "unused_count": 0, + "deleted_count": {}, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } + + logger.info( + "Starting cleanup task", + minutes_threshold=minutes_threshold, + dry_run=dry_run, + user_id=str(user_id) if user_id else None, + text_doc=text_doc + ) + + # Calculate cutoff timestamp + cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold) + + if text_doc: + # SQL-based approach: Find unused TextDocuments and use cognee.delete() + return await _cleanup_via_sql(cutoff_date, dry_run, user_id) + else: + # Graph-based approach: Find unused nodes directly from graph + cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000) + logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)") + + # Detect database provider and find unused nodes + provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower() + unused_nodes = await _find_unused_nodes(cutoff_timestamp_ms, user_id, provider) + + total_unused = sum(len(nodes) for nodes in unused_nodes.values()) + logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()}) + + if dry_run: + return { + "status": "dry_run", + "unused_count": total_unused, + "deleted_count": { + "data_items": 0, + "chunks": 0, + "entities": 0, + "summaries": 0, + "associations": 0 + }, + "cleanup_date": datetime.now(timezone.utc).isoformat(), + "preview": { + "chunks": len(unused_nodes["DocumentChunk"]), + "entities": len(unused_nodes["Entity"]), + "summaries": len(unused_nodes["TextSummary"]) + } + } + + # Delete unused nodes with provider-specific logic + deleted_counts = await _delete_unused_nodes(unused_nodes, provider) + + logger.info("Cleanup completed", deleted_counts=deleted_counts) + + return { + "status": "completed", + "unused_count": total_unused, + "deleted_count": { + "data_items": 0, + "chunks": 
deleted_counts["DocumentChunk"], + "entities": deleted_counts["Entity"], + "summaries": deleted_counts["TextSummary"], + "associations": deleted_counts["associations"] + }, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } -async def cleanup_unused_data( - days_threshold: Optional[int], - dry_run: bool = True, - user_id: Optional[UUID] = None, - text_doc: bool = False +async def _cleanup_via_sql( + cutoff_date: datetime, + dry_run: bool, + user_id: Optional[UUID] = None ) -> Dict[str, Any]: """ - Identify and remove unused data from the memify pipeline. + SQL-based cleanup: Query Data table for unused documents and use cognee.delete(). Parameters ---------- - days_threshold : int - days since last access to consider data unused + cutoff_date : datetime + Cutoff date for last_accessed filtering dry_run : bool - If True, only report what would be deleted without actually deleting (default: True) - user_id : UUID, optional - Limit cleanup to specific user's data (default: None) - text_doc : bool - If True, use SQL-based filtering to find unused TextDocuments and call cognee.delete() - for proper whole-document deletion (default: False) - - Returns - ------- - Dict[str, Any] - Cleanup results with status, counts, and timestamp - """ - # Check 1: Environment variable must be enabled - if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true": - logger.warning( - "Cleanup skipped: ENABLE_LAST_ACCESSED is not enabled." 
- ) - return { - "status": "skipped", - "reason": "ENABLE_LAST_ACCESSED not enabled", - "unused_count": 0, - "deleted_count": {}, - "cleanup_date": datetime.now(timezone.utc).isoformat() - } - - # Check 2: Verify tracking has actually been running - db_engine = get_relational_engine() - async with db_engine.get_async_session() as session: - # Count records with non-NULL last_accessed - tracked_count = await session.execute( - select(sa.func.count(Data.id)).where(Data.last_accessed.isnot(None)) - ) - tracked_records = tracked_count.scalar() - - if tracked_records == 0: - logger.warning( - "Cleanup skipped: No records have been tracked yet. " - "ENABLE_LAST_ACCESSED may have been recently enabled. " - "Wait for retrievers to update timestamps before running cleanup." - ) - return { - "status": "skipped", - "reason": "No tracked records found - tracking may be newly enabled", - "unused_count": 0, - "deleted_count": {}, - "cleanup_date": datetime.now(timezone.utc).isoformat() - } - - logger.info( - "Starting cleanup task", - days_threshold=days_threshold, - dry_run=dry_run, - user_id=str(user_id) if user_id else None, - text_doc=text_doc - ) - - # Calculate cutoff timestamp - cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_threshold) - - if text_doc: - # SQL-based approach: Find unused TextDocuments and use cognee.delete() - return await _cleanup_via_sql(cutoff_date, dry_run, user_id) - else: - # Graph-based approach: Find unused nodes directly from graph - cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000) - logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)") - - # Find unused nodes - unused_nodes = await _find_unused_nodes(cutoff_timestamp_ms, user_id) - - total_unused = sum(len(nodes) for nodes in unused_nodes.values()) - logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()}) - - if dry_run: - return { - "status": "dry_run", - "unused_count": total_unused, 
- "deleted_count": { - "data_items": 0, - "chunks": 0, - "entities": 0, - "summaries": 0, - "associations": 0 - }, - "cleanup_date": datetime.now(timezone.utc).isoformat(), - "preview": { - "chunks": len(unused_nodes["DocumentChunk"]), - "entities": len(unused_nodes["Entity"]), - "summaries": len(unused_nodes["TextSummary"]) - } - } - - # Delete unused nodes - deleted_counts = await _delete_unused_nodes(unused_nodes) - - logger.info("Cleanup completed", deleted_counts=deleted_counts) - - return { - "status": "completed", - "unused_count": total_unused, - "deleted_count": { - "data_items": 0, - "chunks": deleted_counts["DocumentChunk"], - "entities": deleted_counts["Entity"], - "summaries": deleted_counts["TextSummary"], - "associations": deleted_counts["associations"] - }, - "cleanup_date": datetime.now(timezone.utc).isoformat() - } - - -async def _cleanup_via_sql( - cutoff_date: datetime, - dry_run: bool, - user_id: Optional[UUID] = None -) -> Dict[str, Any]: - """ - SQL-based cleanup: Query Data table for unused documents and use cognee.delete(). 
- - Parameters - ---------- - cutoff_date : datetime - Cutoff date for last_accessed filtering - dry_run : bool - If True, only report what would be deleted - user_id : UUID, optional - Filter by user ID if provided - - Returns - ------- - Dict[str, Any] - Cleanup results - """ - db_engine = get_relational_engine() - - async with db_engine.get_async_session() as session: - # Query for Data records with old last_accessed timestamps - query = select(Data, DatasetData).join( - DatasetData, Data.id == DatasetData.data_id - ).where( - or_( - Data.last_accessed < cutoff_date, - Data.last_accessed.is_(None) - ) - ) - - if user_id: - from cognee.modules.data.models import Dataset - query = query.join(Dataset, DatasetData.dataset_id == Dataset.id).where( - Dataset.owner_id == user_id - ) - - result = await session.execute(query) - unused_data = result.all() - - logger.info(f"Found {len(unused_data)} unused documents in SQL") - - if dry_run: - return { - "status": "dry_run", - "unused_count": len(unused_data), - "deleted_count": { - "data_items": 0, - "documents": 0 - }, - "cleanup_date": datetime.now(timezone.utc).isoformat(), - "preview": { - "documents": len(unused_data) - } - } - - # Delete each document using cognee.delete() - deleted_count = 0 - from cognee.modules.users.methods import get_default_user - user = await get_default_user() if user_id is None else None - - for data, dataset_data in unused_data: - try: - await cognee.delete( - data_id=data.id, - dataset_id=dataset_data.dataset_id, - mode="hard", # Use hard mode to also remove orphaned entities - user=user - ) - deleted_count += 1 - logger.info(f"Deleted document {data.id} from dataset {dataset_data.dataset_id}") - except Exception as e: - logger.error(f"Failed to delete document {data.id}: {e}") - - logger.info("Cleanup completed", deleted_count=deleted_count) - - return { - "status": "completed", - "unused_count": len(unused_data), - "deleted_count": { - "data_items": deleted_count, - "documents": 
deleted_count - }, - "cleanup_date": datetime.now(timezone.utc).isoformat() - } - - -async def _find_unused_nodes( - cutoff_timestamp_ms: int, - user_id: Optional[UUID] = None -) -> Dict[str, list]: - """ - Query Kuzu for nodes with old last_accessed_at timestamps. - - Parameters - ---------- - cutoff_timestamp_ms : int - Cutoff timestamp in milliseconds since epoch + If True, only report what would be deleted user_id : UUID, optional Filter by user ID if provided Returns ------- - Dict[str, list] - Dictionary mapping node types to lists of unused node IDs + Dict[str, Any] + Cleanup results """ - graph_engine = await get_graph_engine() + db_engine = get_relational_engine() - # Query all nodes with their properties - query = "MATCH (n:Node) RETURN n.id, n.type, n.properties" - results = await graph_engine.query(query) - - unused_nodes = { - "DocumentChunk": [], - "Entity": [], - "TextSummary": [] - } - - for node_id, node_type, props_json in results: - # Only process tracked node types - if node_type not in unused_nodes: - continue - - # Parse properties JSON - if props_json: - try: - props = json.loads(props_json) - last_accessed = props.get("last_accessed_at") - - # Check if node is unused (never accessed or accessed before cutoff) - if last_accessed is None or last_accessed < cutoff_timestamp_ms: - unused_nodes[node_type].append(node_id) - logger.debug( - f"Found unused {node_type}", - node_id=node_id, - last_accessed=last_accessed - ) - except json.JSONDecodeError: - logger.warning(f"Failed to parse properties for node {node_id}") - continue - - return unused_nodes - - -async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]: - """ - Delete unused nodes from graph and vector databases. 
- - Parameters - ---------- - unused_nodes : Dict[str, list] - Dictionary mapping node types to lists of node IDs to delete - - Returns - ------- - Dict[str, int] - Count of deleted items by type - """ - graph_engine = await get_graph_engine() - vector_engine = get_vector_engine() - - deleted_counts = { - "DocumentChunk": 0, - "Entity": 0, - "TextSummary": 0, - "associations": 0 - } - - # Count associations before deletion - for node_type, node_ids in unused_nodes.items(): - if not node_ids: - continue - - # Count edges connected to these nodes - for node_id in node_ids: - result = await graph_engine.query( - "MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)", - {"id": node_id} + async with db_engine.get_async_session() as session: + # Query for Data records with old last_accessed timestamps + query = select(Data, DatasetData).join( + DatasetData, Data.id == DatasetData.data_id + ).where( + or_( + Data.last_accessed < cutoff_date, + Data.last_accessed.is_(None) ) - if result and len(result) > 0: - deleted_counts["associations"] += result[0][0] - - # Delete from graph database (uses DETACH DELETE, so edges are automatically removed) - for node_type, node_ids in unused_nodes.items(): - if not node_ids: - continue + ) - logger.info(f"Deleting {len(node_ids)} {node_type} nodes from graph database") + if user_id: + from cognee.modules.data.models import Dataset + query = query.join(Dataset, DatasetData.dataset_id == Dataset.id).where( + Dataset.owner_id == user_id + ) - # Delete nodes in batches - await graph_engine.delete_nodes(node_ids) - deleted_counts[node_type] = len(node_ids) + result = await session.execute(query) + unused_data = result.all() - # Delete from vector database - vector_collections = { - "DocumentChunk": "DocumentChunk_text", - "Entity": "Entity_name", - "TextSummary": "TextSummary_text" + logger.info(f"Found {len(unused_data)} unused documents in SQL") + + if dry_run: + return { + "status": "dry_run", + "unused_count": len(unused_data), + 
"deleted_count": { + "data_items": 0, + "documents": 0 + }, + "cleanup_date": datetime.now(timezone.utc).isoformat(), + "preview": { + "documents": len(unused_data) + } + } + + # Delete each document using cognee.delete() + deleted_count = 0 + from cognee.modules.users.methods import get_default_user + user = await get_default_user() if user_id is None else None + + for data, dataset_data in unused_data: + try: + await cognee.delete( + data_id=data.id, + dataset_id=dataset_data.dataset_id, + mode="hard", # Use hard mode to also remove orphaned entities + user=user + ) + deleted_count += 1 + logger.info(f"Deleted document {data.id} from dataset {dataset_data.dataset_id}") + except Exception as e: + logger.error(f"Failed to delete document {data.id}: {e}") + + logger.info("Cleanup completed", deleted_count=deleted_count) + + return { + "status": "completed", + "unused_count": len(unused_data), + "deleted_count": { + "data_items": deleted_count, + "documents": deleted_count + }, + "cleanup_date": datetime.now(timezone.utc).isoformat() } - - - for node_type, collection_name in vector_collections.items(): - node_ids = unused_nodes[node_type] - if not node_ids: - continue - - logger.info(f"Deleting {len(node_ids)} {node_type} embeddings from vector database") - - try: - if await vector_engine.has_collection(collection_name): - await vector_engine.delete_data_points( - collection_name, - [str(node_id) for node_id in node_ids] - ) - except Exception as e: - logger.error(f"Error deleting from vector collection {collection_name}: {e}") - + + +async def _find_unused_nodes( + cutoff_timestamp_ms: int, + user_id: Optional[UUID] = None, + provider: str = "kuzu" +) -> Dict[str, list]: + """ + Find unused nodes with provider-specific queries. 
+ + Parameters + ---------- + cutoff_timestamp_ms : int + Cutoff timestamp in milliseconds since epoch + user_id : UUID, optional + Filter by user ID if provided + provider : str + Graph database provider (kuzu, neo4j, neptune) + + Returns + ------- + Dict[str, list] + Dictionary mapping node types to lists of unused node IDs + """ + graph_engine = await get_graph_engine() + + if provider == "kuzu": + return await _find_unused_nodes_kuzu(graph_engine, cutoff_timestamp_ms) + elif provider == "neo4j": + return await _find_unused_nodes_neo4j(graph_engine, cutoff_timestamp_ms) + elif provider == "neptune": + return await _find_unused_nodes_neptune(graph_engine, cutoff_timestamp_ms) + else: + logger.warning(f"Unsupported graph provider: {provider}") + return {"DocumentChunk": [], "Entity": [], "TextSummary": []} + + +async def _find_unused_nodes_kuzu(graph_engine, cutoff_timestamp_ms): + """Kuzu-specific unused node detection""" + query = "MATCH (n:Node) RETURN n.id, n.type, n.properties" + results = await graph_engine.query(query) + + unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []} + + for node_id, node_type, props_json in results: + if node_type not in unused_nodes: + continue + + if props_json: + try: + props = json.loads(props_json) + last_accessed = props.get("last_accessed_at") + + if last_accessed is None or last_accessed < cutoff_timestamp_ms: + unused_nodes[node_type].append(node_id) + logger.debug( + f"Found unused {node_type}", + node_id=node_id, + last_accessed=last_accessed + ) + except json.JSONDecodeError: + logger.warning(f"Failed to parse properties for node {node_id}") + continue + + return unused_nodes + + +async def _find_unused_nodes_neo4j(graph_engine, cutoff_timestamp_ms): + """Neo4j-specific unused node detection""" + query = "MATCH (n:__Node__) RETURN n.id, n.type, n.last_accessed_at" + results = await graph_engine.query(query) + + unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []} + + for row in results: 
+ node_id = row["n"]["id"] + node_type = row["n"]["type"] + last_accessed = row["n"].get("last_accessed_at") + + if node_type not in unused_nodes: + continue + + if last_accessed is None or last_accessed < cutoff_timestamp_ms: + unused_nodes[node_type].append(node_id) + logger.debug( + f"Found unused {node_type}", + node_id=node_id, + last_accessed=last_accessed + ) + + return unused_nodes + + +async def _find_unused_nodes_neptune(graph_engine, cutoff_timestamp_ms): + """Neptune-specific unused node detection""" + query = "MATCH (n:Node) RETURN n.id, n.type, n.last_accessed_at" + results = await graph_engine.query(query) + + unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []} + + for row in results: + node_id = row["n"]["id"] + node_type = row["n"]["type"] + last_accessed = row["n"].get("last_accessed_at") + + if node_type not in unused_nodes: + continue + + if last_accessed is None or last_accessed < cutoff_timestamp_ms: + unused_nodes[node_type].append(node_id) + logger.debug( + f"Found unused {node_type}", + node_id=node_id, + last_accessed=last_accessed + ) + + return unused_nodes + + +async def _delete_unused_nodes(unused_nodes: Dict[str, list], provider: str) -> Dict[str, int]: + """ + Delete unused nodes from graph and vector databases. 
+ + Parameters + ---------- + unused_nodes : Dict[str, list] + Dictionary mapping node types to lists of node IDs to delete + provider : str + Graph database provider (kuzu, neo4j, neptune) + + Returns + ------- + Dict[str, int] + Count of deleted items by type + """ + graph_engine = await get_graph_engine() + vector_engine = get_vector_engine() + + deleted_counts = { + "DocumentChunk": 0, + "Entity": 0, + "TextSummary": 0, + "associations": 0 + } + + # Count associations before deletion + for node_type, node_ids in unused_nodes.items(): + if not node_ids: + continue + + # Count edges connected to these nodes + for node_id in node_ids: + if provider == "kuzu": + result = await graph_engine.query( + "MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)", + {"id": node_id} + ) + elif provider == "neo4j": + result = await graph_engine.query( + "MATCH (n:__Node__ {id: $id})-[r:EDGE]-() RETURN count(r)", + {"id": node_id} + ) + elif provider == "neptune": + result = await graph_engine.query( + "MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)", + {"id": node_id} + ) + + if result and len(result) > 0: + count = result[0][0] if provider == "kuzu" else result[0]["count_count(r)"] + deleted_counts["associations"] += count + + # Delete from graph database (uses DETACH DELETE, so edges are automatically removed) + for node_type, node_ids in unused_nodes.items(): + if not node_ids: + continue + + logger.info(f"Deleting {len(node_ids)} {node_type} nodes from graph database") + + # Delete nodes in batches + await graph_engine.delete_nodes(node_ids) + deleted_counts[node_type] = len(node_ids) + + # Delete from vector database + vector_collections = { + "DocumentChunk": "DocumentChunk_text", + "Entity": "Entity_name", + "TextSummary": "TextSummary_text" + } + + + for node_type, collection_name in vector_collections.items(): + node_ids = unused_nodes[node_type] + if not node_ids: + continue + + logger.info(f"Deleting {len(node_ids)} {node_type} embeddings from vector database") 
+ + try: + if await vector_engine.has_collection(collection_name): + await vector_engine.delete_data_points( + collection_name, + [str(node_id) for node_id in node_ids] + ) + except Exception as e: + logger.error(f"Error deleting from vector collection {collection_name}: {e}") + return deleted_counts From 73d84129de9405a8d158a46c00c7d7802225f1ed Mon Sep 17 00:00:00 2001 From: Mike Potter Date: Fri, 28 Nov 2025 12:34:20 -0500 Subject: [PATCH 032/176] fix(api): pass run_in_background parameter to memify function The run_in_background parameter was defined in MemifyPayloadDTO but was never passed to the cognee_memify function call, making the parameter effectively unused. This fix passes the parameter so users can actually run memify operations in the background. Signed-off-by: Mike Potter --- cognee/api/v1/memify/routers/get_memify_router.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cognee/api/v1/memify/routers/get_memify_router.py b/cognee/api/v1/memify/routers/get_memify_router.py index cc07a3a0c..c63e4a394 100644 --- a/cognee/api/v1/memify/routers/get_memify_router.py +++ b/cognee/api/v1/memify/routers/get_memify_router.py @@ -90,6 +90,7 @@ def get_memify_router() -> APIRouter: dataset=payload.dataset_id if payload.dataset_id else payload.dataset_name, node_name=payload.node_name, user=user, + run_in_background=payload.run_in_background, ) if isinstance(memify_run, PipelineRunErrored): From 6a4d31356bb613e5cf74e7972445f804796ee6d4 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Tue, 2 Dec 2025 18:55:47 +0530 Subject: [PATCH 033/176] fix: using graph projection instead of conditions --- .../retrieval/utils/access_tracking.py | 156 ++-- cognee/tasks/cleanup/cleanup_unused_data.py | 759 ++++++++---------- 2 files changed, 418 insertions(+), 497 deletions(-) diff --git a/cognee/modules/retrieval/utils/access_tracking.py b/cognee/modules/retrieval/utils/access_tracking.py index 12a66f8bc..935c47157 100644 --- a/cognee/modules/retrieval/utils/access_tracking.py 
+++ b/cognee/modules/retrieval/utils/access_tracking.py @@ -4,118 +4,116 @@ import json from datetime import datetime, timezone from typing import List, Any from uuid import UUID -import os +import os from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.relational import get_relational_engine from cognee.modules.data.models import Data from cognee.shared.logging_utils import get_logger from sqlalchemy import update +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph logger = get_logger(__name__) async def update_node_access_timestamps(items: List[Any]): if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true": return - + if not items: return - + graph_engine = await get_graph_engine() timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000) timestamp_dt = datetime.now(timezone.utc) - + # Extract node IDs node_ids = [] for item in items: item_id = item.payload.get("id") if hasattr(item, 'payload') else item.get("id") if item_id: node_ids.append(str(item_id)) - + if not node_ids: return - - try: - # Detect database provider and use appropriate queries - provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower() - if provider == "kuzu": - await _update_kuzu_nodes(graph_engine, node_ids, timestamp_ms) - elif provider == "neo4j": - await _update_neo4j_nodes(graph_engine, node_ids, timestamp_ms) - elif provider == "neptune": - await _update_neptune_nodes(graph_engine, node_ids, timestamp_ms) - else: - logger.warning(f"Unsupported graph provider: {provider}") - return + try: + # Update nodes using graph projection ( database-agnostic approach + await _update_nodes_via_projection(graph_engine, node_ids, timestamp_ms) # Find origin documents and update SQL - doc_ids = await _find_origin_documents(graph_engine, node_ids, provider) + doc_ids = await _find_origin_documents_via_projection(graph_engine, node_ids) if doc_ids: await _update_sql_records(doc_ids, timestamp_dt) - + except 
Exception as e: logger.error(f"Failed to update timestamps: {e}") raise -async def _update_kuzu_nodes(graph_engine, node_ids, timestamp_ms): - """Kuzu-specific node updates""" - for node_id in node_ids: - result = await graph_engine.query( - "MATCH (n:Node {id: $id}) RETURN n.properties", - {"id": node_id} - ) - - if result and result[0]: - props = json.loads(result[0][0]) if result[0][0] else {} - props["last_accessed_at"] = timestamp_ms - - await graph_engine.query( - "MATCH (n:Node {id: $id}) SET n.properties = $props", - {"id": node_id, "props": json.dumps(props)} - ) - -async def _update_neo4j_nodes(graph_engine, node_ids, timestamp_ms): - """Neo4j-specific node updates""" - for node_id in node_ids: - await graph_engine.query( - "MATCH (n:__Node__ {id: $id}) SET n.last_accessed_at = $timestamp", - {"id": node_id, "timestamp": timestamp_ms} - ) - -async def _update_neptune_nodes(graph_engine, node_ids, timestamp_ms): - """Neptune-specific node updates""" - for node_id in node_ids: - await graph_engine.query( - "MATCH (n:Node {id: $id}) SET n.last_accessed_at = $timestamp", - {"id": node_id, "timestamp": timestamp_ms} - ) - -async def _find_origin_documents(graph_engine, node_ids, provider): - """Find origin documents with provider-specific queries""" - if provider == "kuzu": - query = """ - UNWIND $node_ids AS node_id - MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node) - WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document'] - RETURN DISTINCT doc.id - """ - elif provider == "neo4j": - query = """ - UNWIND $node_ids AS node_id - MATCH (chunk:__Node__ {id: node_id})-[e:EDGE]-(doc:__Node__) - WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document'] - RETURN DISTINCT doc.id - """ - elif provider == "neptune": - query = """ - UNWIND $node_ids AS node_id - MATCH (chunk:Node {id: node_id})-[e:EDGE]-(doc:Node) - WHERE chunk.type = 'DocumentChunk' AND doc.type IN ['TextDocument', 'Document'] - RETURN DISTINCT doc.id - 
""" +async def _update_nodes_via_projection(graph_engine, node_ids, timestamp_ms): + """Update nodes using graph projection - works with any graph database""" + # Project the graph with necessary properties + memory_fragment = CogneeGraph() + await memory_fragment.project_graph_from_db( + graph_engine, + node_properties_to_project=["id"], + edge_properties_to_project=[] + ) - result = await graph_engine.query(query, {"node_ids": node_ids}) - return list(set([row[0] for row in result if row and row[0]])) if result else [] + # Update each node's last_accessed_at property + for node_id in node_ids: + node = memory_fragment.get_node(node_id) + if node: + # Update the node in the database + provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower() + + if provider == "kuzu": + # Kuzu stores properties as JSON + result = await graph_engine.query( + "MATCH (n:Node {id: $id}) RETURN n.properties", + {"id": node_id} + ) + + if result and result[0]: + props = json.loads(result[0][0]) if result[0][0] else {} + props["last_accessed_at"] = timestamp_ms + + await graph_engine.query( + "MATCH (n:Node {id: $id}) SET n.properties = $props", + {"id": node_id, "props": json.dumps(props)} + ) + elif provider == "neo4j": + await graph_engine.query( + "MATCH (n:__Node__ {id: $id}) SET n.last_accessed_at = $timestamp", + {"id": node_id, "timestamp": timestamp_ms} + ) + elif provider == "neptune": + await graph_engine.query( + "MATCH (n:Node {id: $id}) SET n.last_accessed_at = $timestamp", + {"id": node_id, "timestamp": timestamp_ms} + ) + +async def _find_origin_documents_via_projection(graph_engine, node_ids): + """Find origin documents using graph projection instead of DB queries""" + # Project the entire graph with necessary properties + memory_fragment = CogneeGraph() + await memory_fragment.project_graph_from_db( + graph_engine, + node_properties_to_project=["id", "type"], + edge_properties_to_project=["relationship_name"] + ) + + # Find origin documents by traversing the 
in-memory graph + doc_ids = set() + for node_id in node_ids: + node = memory_fragment.get_node(node_id) + if node and node.get_attribute("type") == "DocumentChunk": + # Traverse edges to find connected documents + for edge in node.get_skeleton_edges(): + # Get the neighbor node + neighbor = edge.get_destination_node() if edge.get_source_node().id == node_id else edge.get_source_node() + if neighbor and neighbor.get_attribute("type") in ["TextDocument", "Document"]: + doc_ids.add(neighbor.id) + + return list(doc_ids) async def _update_sql_records(doc_ids, timestamp_dt): """Update SQL Data table (same for all providers)""" @@ -124,6 +122,6 @@ async def _update_sql_records(doc_ids, timestamp_dt): stmt = update(Data).where( Data.id.in_([UUID(doc_id) for doc_id in doc_ids]) ).values(last_accessed=timestamp_dt) - + await session.execute(stmt) await session.commit() diff --git a/cognee/tasks/cleanup/cleanup_unused_data.py b/cognee/tasks/cleanup/cleanup_unused_data.py index b89c939a8..c70b97a00 100644 --- a/cognee/tasks/cleanup/cleanup_unused_data.py +++ b/cognee/tasks/cleanup/cleanup_unused_data.py @@ -1,448 +1,371 @@ -""" -Task for automatically deleting unused data from the memify pipeline. - -This task identifies and removes data (chunks, entities, summaries)) that hasn't -been accessed by retrievers for a specified period, helping maintain system -efficiency and storage optimization. 
-""" - -import json -from datetime import datetime, timezone, timedelta -from typing import Optional, Dict, Any -from uuid import UUID -import os -from cognee.infrastructure.databases.graph import get_graph_engine -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.infrastructure.databases.relational import get_relational_engine -from cognee.modules.data.models import Data, DatasetData -from cognee.shared.logging_utils import get_logger -from sqlalchemy import select, or_ -import cognee -import sqlalchemy as sa - -logger = get_logger(__name__) +""" +Task for automatically deleting unused data from the memify pipeline. + +This task identifies and removes data (chunks, entities, summaries)) that hasn't +been accessed by retrievers for a specified period, helping maintain system +efficiency and storage optimization. +""" + +import json +from datetime import datetime, timezone, timedelta +from typing import Optional, Dict, Any +from uuid import UUID +import os +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.data.models import Data, DatasetData +from cognee.shared.logging_utils import get_logger +from sqlalchemy import select, or_ +import cognee +import sqlalchemy as sa +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph + +logger = get_logger(__name__) + + +async def cleanup_unused_data( + minutes_threshold: Optional[int], + dry_run: bool = True, + user_id: Optional[UUID] = None, + text_doc: bool = False +) -> Dict[str, Any]: + """ + Identify and remove unused data from the memify pipeline. 
+ + Parameters + ---------- + minutes_threshold : int + days since last access to consider data unused + dry_run : bool + If True, only report what would be delete without actually deleting (default: True) + user_id : UUID, optional + Limit cleanup to specific user's data (default: None) + text_doc : bool + If True, use SQL-based filtering to find unused TextDocuments and call cognee.delete() + for proper whole-document deletion (default: False) + + Returns + ------- + Dict[str, Any] + Cleanup results with status, counts, and timestamp + """ + # Check 1: Environment variable must be enabled + if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true": + logger.warning( + "Cleanup skipped: ENABLE_LAST_ACCESSED is not enabled." + ) + return { + "status": "skipped", + "reason": "ENABLE_LAST_ACCESSED not enabled", + "unused_count": 0, + "deleted_count": {}, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } + + # Check 2: Verify tracking has actually been running + db_engine = get_relational_engine() + async with db_engine.get_async_session() as session: + # Count records with non-NULL last_accessed + tracked_count = await session.execute( + select(sa.func.count(Data.id)).where(Data.last_accessed.isnot(None)) + ) + tracked_records = tracked_count.scalar() + + if tracked_records == 0: + logger.warning( + "Cleanup skipped: No records have been tracked yet. " + "ENABLE_LAST_ACCESSED may have been recently enabled. " + "Wait for retrievers to update timestamps before running cleanup." 
+ ) + return { + "status": "skipped", + "reason": "No tracked records found - tracking may be newly enabled", + "unused_count": 0, + "deleted_count": {}, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } + + logger.info( + "Starting cleanup task", + minutes_threshold=minutes_threshold, + dry_run=dry_run, + user_id=str(user_id) if user_id else None, + text_doc=text_doc + ) + + # Calculate cutoff timestamp + cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold) + + if text_doc: + # SQL-based approach: Find unused TextDocuments and use cognee.delete() + return await _cleanup_via_sql(cutoff_date, dry_run, user_id) + else: + # Graph-based approach: Find unused nodes using projection (database-agnostic) + cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000) + logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)") + + # Find unused nodes using graph projection + unused_nodes = await _find_unused_nodes_via_projection(cutoff_timestamp_ms) + + total_unused = sum(len(nodes) for nodes in unused_nodes.values()) + logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()}) + + if dry_run: + return { + "status": "dry_run", + "unused_count": total_unused, + "deleted_count": { + "data_items": 0, + "chunks": 0, + "entities": 0, + "summaries": 0, + "associations": 0 + }, + "cleanup_date": datetime.now(timezone.utc).isoformat(), + "preview": { + "chunks": len(unused_nodes["DocumentChunk"]), + "entities": len(unused_nodes["Entity"]), + "summaries": len(unused_nodes["TextSummary"]) + } + } + + # Delete unused nodes (provider-agnostic deletion) + deleted_counts = await _delete_unused_nodes(unused_nodes) + + logger.info("Cleanup completed", deleted_counts=deleted_counts) + + return { + "status": "completed", + "unused_count": total_unused, + "deleted_count": { + "data_items": 0, + "chunks": deleted_counts["DocumentChunk"], + "entities": deleted_counts["Entity"], + 
"summaries": deleted_counts["TextSummary"], + "associations": deleted_counts["associations"] + }, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } -async def cleanup_unused_data( - minutes_threshold: Optional[int], - dry_run: bool = True, - user_id: Optional[UUID] = None, - text_doc: bool = False +async def _cleanup_via_sql( + cutoff_date: datetime, + dry_run: bool, + user_id: Optional[UUID] = None ) -> Dict[str, Any]: """ - Identify and remove unused data from the memify pipeline. + SQL-based cleanup: Query Data table for unused documents and use cognee.delete(). Parameters ---------- - minutes_threshold : int - days since last access to consider data unused + cutoff_date : datetime + Cutoff date for last_accessed filtering dry_run : bool - If True, only report what would be delete without actually deleting (default: True) + If True, only report what would be deleted user_id : UUID, optional - Limit cleanup to specific user's data (default: None) - text_doc : bool - If True, use SQL-based filtering to find unused TextDocuments and call cognee.delete() - for proper whole-document deletion (default: False) + Filter by user ID if provided Returns ------- Dict[str, Any] - Cleanup results with status, counts, and timestamp - """ - # Check 1: Environment variable must be enabled - if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true": - logger.warning( - "Cleanup skipped: ENABLE_LAST_ACCESSED is not enabled." 
- ) - return { - "status": "skipped", - "reason": "ENABLE_LAST_ACCESSED not enabled", - "unused_count": 0, - "deleted_count": {}, - "cleanup_date": datetime.now(timezone.utc).isoformat() - } - - # Check 2: Verify tracking has actually been running - db_engine = get_relational_engine() - async with db_engine.get_async_session() as session: - # Count records with non-NULL last_accessed - tracked_count = await session.execute( - select(sa.func.count(Data.id)).where(Data.last_accessed.isnot(None)) - ) - tracked_records = tracked_count.scalar() - - if tracked_records == 0: - logger.warning( - "Cleanup skipped: No records have been tracked yet. " - "ENABLE_LAST_ACCESSED may have been recently enabled. " - "Wait for retrievers to update timestamps before running cleanup." - ) - return { - "status": "skipped", - "reason": "No tracked records found - tracking may be newly enabled", - "unused_count": 0, - "deleted_count": {}, - "cleanup_date": datetime.now(timezone.utc).isoformat() - } - - logger.info( - "Starting cleanup task", - minutes_threshold=minutes_threshold, - dry_run=dry_run, - user_id=str(user_id) if user_id else None, - text_doc=text_doc - ) + Cleanup results + """ + db_engine = get_relational_engine() - # Calculate cutoff timestamp - cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold) - - if text_doc: - # SQL-based approach: Find unused TextDocuments and use cognee.delete() - return await _cleanup_via_sql(cutoff_date, dry_run, user_id) - else: - # Graph-based approach: Find unused nodes directly from graph - cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000) - logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)") + async with db_engine.get_async_session() as session: + # Query for Data records with old last_accessed timestamps + query = select(Data, DatasetData).join( + DatasetData, Data.id == DatasetData.data_id + ).where( + or_( + Data.last_accessed < cutoff_date, + 
Data.last_accessed.is_(None) + ) + ) - # Detect database provider and find unused nodes - provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower() - unused_nodes = await _find_unused_nodes(cutoff_timestamp_ms, user_id, provider) - - total_unused = sum(len(nodes) for nodes in unused_nodes.values()) - logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()}) - - if dry_run: - return { - "status": "dry_run", - "unused_count": total_unused, - "deleted_count": { - "data_items": 0, - "chunks": 0, - "entities": 0, - "summaries": 0, - "associations": 0 - }, - "cleanup_date": datetime.now(timezone.utc).isoformat(), - "preview": { - "chunks": len(unused_nodes["DocumentChunk"]), - "entities": len(unused_nodes["Entity"]), - "summaries": len(unused_nodes["TextSummary"]) - } - } - - # Delete unused nodes with provider-specific logic - deleted_counts = await _delete_unused_nodes(unused_nodes, provider) - - logger.info("Cleanup completed", deleted_counts=deleted_counts) + if user_id: + from cognee.modules.data.models import Dataset + query = query.join(Dataset, DatasetData.dataset_id == Dataset.id).where( + Dataset.owner_id == user_id + ) + result = await session.execute(query) + unused_data = result.all() + + logger.info(f"Found {len(unused_data)} unused documents in SQL") + + if dry_run: return { - "status": "completed", - "unused_count": total_unused, + "status": "dry_run", + "unused_count": len(unused_data), "deleted_count": { "data_items": 0, - "chunks": deleted_counts["DocumentChunk"], - "entities": deleted_counts["Entity"], - "summaries": deleted_counts["TextSummary"], - "associations": deleted_counts["associations"] + "documents": 0 }, - "cleanup_date": datetime.now(timezone.utc).isoformat() - } - - -async def _cleanup_via_sql( - cutoff_date: datetime, - dry_run: bool, - user_id: Optional[UUID] = None -) -> Dict[str, Any]: - """ - SQL-based cleanup: Query Data table for unused documents and use cognee.delete(). 
+ "cleanup_date": datetime.now(timezone.utc).isoformat(), + "preview": { + "documents": len(unused_data) + } + } + + # Delete each document using cognee.delete() + deleted_count = 0 + from cognee.modules.users.methods import get_default_user + user = await get_default_user() if user_id is None else None + + for data, dataset_data in unused_data: + try: + await cognee.delete( + data_id=data.id, + dataset_id=dataset_data.dataset_id, + mode="hard", # Use hard mode to also remove orphaned entities + user=user + ) + deleted_count += 1 + logger.info(f"Deleted document {data.id} from dataset {dataset_data.dataset_id}") + except Exception as e: + logger.error(f"Failed to delete document {data.id}: {e}") + + logger.info("Cleanup completed", deleted_count=deleted_count) + + return { + "status": "completed", + "unused_count": len(unused_data), + "deleted_count": { + "data_items": deleted_count, + "documents": deleted_count + }, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } - Parameters - ---------- - cutoff_date : datetime - Cutoff date for last_accessed filtering - dry_run : bool - If True, only report what would be deleted - user_id : UUID, optional - Filter by user ID if provided - Returns - ------- - Dict[str, Any] - Cleanup results - """ - db_engine = get_relational_engine() +async def _find_unused_nodes_via_projection(cutoff_timestamp_ms: int) -> Dict[str, list]: + """ + Find unused nodes using graph projection - database-agnostic approach. 
+ + Parameters + ---------- + cutoff_timestamp_ms : int + Cutoff timestamp in milliseconds since epoch + + Returns + ------- + Dict[str, list] + Dictionary mapping node types to lists of unused node IDs + """ + graph_engine = await get_graph_engine() + + # Project the entire graph with necessary properties + memory_fragment = CogneeGraph() + await memory_fragment.project_graph_from_db( + graph_engine, + node_properties_to_project=["id", "type", "last_accessed_at"], + edge_properties_to_project=[] + ) - async with db_engine.get_async_session() as session: - # Query for Data records with old last_accessed timestamps - query = select(Data, DatasetData).join( - DatasetData, Data.id == DatasetData.data_id - ).where( - or_( - Data.last_accessed < cutoff_date, - Data.last_accessed.is_(None) + unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []} + + # Get all nodes from the projected graph + all_nodes = memory_fragment.get_nodes() + + for node in all_nodes: + node_type = node.get_attribute("type") + if node_type not in unused_nodes: + continue + + # Check last_accessed_at property + last_accessed = node.get_attribute("last_accessed_at") + + if last_accessed is None or last_accessed < cutoff_timestamp_ms: + unused_nodes[node_type].append(node.id) + logger.debug( + f"Found unused {node_type}", + node_id=node.id, + last_accessed=last_accessed ) + + return unused_nodes + + +async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]: + """ + Delete unused nodes from graph and vector databases. 
+ + Parameters + ---------- + unused_nodes : Dict[str, list] + Dictionary mapping node types to lists of node IDs to delete + + Returns + ------- + Dict[str, int] + Count of deleted items by type + """ + graph_engine = await get_graph_engine() + vector_engine = get_vector_engine() + + deleted_counts = { + "DocumentChunk": 0, + "Entity": 0, + "TextSummary": 0, + "associations": 0 + } + + # Count associations before deletion (using graph projection for consistency) + if any(unused_nodes.values()): + memory_fragment = CogneeGraph() + await memory_fragment.project_graph_from_db( + graph_engine, + node_properties_to_project=["id"], + edge_properties_to_project=[] ) - if user_id: - from cognee.modules.data.models import Dataset - query = query.join(Dataset, DatasetData.dataset_id == Dataset.id).where( - Dataset.owner_id == user_id - ) + for node_type, node_ids in unused_nodes.items(): + if not node_ids: + continue + + # Count edges connected to these nodes + for node_id in node_ids: + node = memory_fragment.get_node(node_id) + if node: + # Count edges from the in-memory graph + edge_count = len(node.get_skeleton_edges()) + deleted_counts["associations"] += edge_count + + # Delete from graph database (uses DETACH DELETE, so edges are automatically removed) + for node_type, node_ids in unused_nodes.items(): + if not node_ids: + continue + + logger.info(f"Deleting {len(node_ids)} {node_type} nodes from graph database") + + # Delete nodes in batches (database-agnostic) + await graph_engine.delete_nodes(node_ids) + deleted_counts[node_type] = len(node_ids) - result = await session.execute(query) - unused_data = result.all() - - logger.info(f"Found {len(unused_data)} unused documents in SQL") - - if dry_run: - return { - "status": "dry_run", - "unused_count": len(unused_data), - "deleted_count": { - "data_items": 0, - "documents": 0 - }, - "cleanup_date": datetime.now(timezone.utc).isoformat(), - "preview": { - "documents": len(unused_data) - } - } - - # Delete each document 
using cognee.delete() - deleted_count = 0 - from cognee.modules.users.methods import get_default_user - user = await get_default_user() if user_id is None else None - - for data, dataset_data in unused_data: + # Delete from vector database + vector_collections = { + "DocumentChunk": "DocumentChunk_text", + "Entity": "Entity_name", + "TextSummary": "TextSummary_text" + } + + + for node_type, collection_name in vector_collections.items(): + node_ids = unused_nodes[node_type] + if not node_ids: + continue + + logger.info(f"Deleting {len(node_ids)} {node_type} embeddings from vector database") + try: - await cognee.delete( - data_id=data.id, - dataset_id=dataset_data.dataset_id, - mode="hard", # Use hard mode to also remove orphaned entities - user=user - ) - deleted_count += 1 - logger.info(f"Deleted document {data.id} from dataset {dataset_data.dataset_id}") + if await vector_engine.has_collection(collection_name): + await vector_engine.delete_data_points( + collection_name, + [str(node_id) for node_id in node_ids] + ) except Exception as e: - logger.error(f"Failed to delete document {data.id}: {e}") - - logger.info("Cleanup completed", deleted_count=deleted_count) - - return { - "status": "completed", - "unused_count": len(unused_data), - "deleted_count": { - "data_items": deleted_count, - "documents": deleted_count - }, - "cleanup_date": datetime.now(timezone.utc).isoformat() - } - - -async def _find_unused_nodes( - cutoff_timestamp_ms: int, - user_id: Optional[UUID] = None, - provider: str = "kuzu" -) -> Dict[str, list]: - """ - Find unused nodes with provider-specific queries. 
- - Parameters - ---------- - cutoff_timestamp_ms : int - Cutoff timestamp in milliseconds since epoch - user_id : UUID, optional - Filter by user ID if provided - provider : str - Graph database provider (kuzu, neo4j, neptune) - - Returns - ------- - Dict[str, list] - Dictionary mapping node types to lists of unused node IDs - """ - graph_engine = await get_graph_engine() - - if provider == "kuzu": - return await _find_unused_nodes_kuzu(graph_engine, cutoff_timestamp_ms) - elif provider == "neo4j": - return await _find_unused_nodes_neo4j(graph_engine, cutoff_timestamp_ms) - elif provider == "neptune": - return await _find_unused_nodes_neptune(graph_engine, cutoff_timestamp_ms) - else: - logger.warning(f"Unsupported graph provider: {provider}") - return {"DocumentChunk": [], "Entity": [], "TextSummary": []} - - -async def _find_unused_nodes_kuzu(graph_engine, cutoff_timestamp_ms): - """Kuzu-specific unused node detection""" - query = "MATCH (n:Node) RETURN n.id, n.type, n.properties" - results = await graph_engine.query(query) - - unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []} - - for node_id, node_type, props_json in results: - if node_type not in unused_nodes: - continue - - if props_json: - try: - props = json.loads(props_json) - last_accessed = props.get("last_accessed_at") - - if last_accessed is None or last_accessed < cutoff_timestamp_ms: - unused_nodes[node_type].append(node_id) - logger.debug( - f"Found unused {node_type}", - node_id=node_id, - last_accessed=last_accessed - ) - except json.JSONDecodeError: - logger.warning(f"Failed to parse properties for node {node_id}") - continue - - return unused_nodes - - -async def _find_unused_nodes_neo4j(graph_engine, cutoff_timestamp_ms): - """Neo4j-specific unused node detection""" - query = "MATCH (n:__Node__) RETURN n.id, n.type, n.last_accessed_at" - results = await graph_engine.query(query) - - unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []} - - for row in results: 
- node_id = row["n"]["id"] - node_type = row["n"]["type"] - last_accessed = row["n"].get("last_accessed_at") - - if node_type not in unused_nodes: - continue - - if last_accessed is None or last_accessed < cutoff_timestamp_ms: - unused_nodes[node_type].append(node_id) - logger.debug( - f"Found unused {node_type}", - node_id=node_id, - last_accessed=last_accessed - ) - - return unused_nodes - - -async def _find_unused_nodes_neptune(graph_engine, cutoff_timestamp_ms): - """Neptune-specific unused node detection""" - query = "MATCH (n:Node) RETURN n.id, n.type, n.last_accessed_at" - results = await graph_engine.query(query) - - unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []} - - for row in results: - node_id = row["n"]["id"] - node_type = row["n"]["type"] - last_accessed = row["n"].get("last_accessed_at") - - if node_type not in unused_nodes: - continue - - if last_accessed is None or last_accessed < cutoff_timestamp_ms: - unused_nodes[node_type].append(node_id) - logger.debug( - f"Found unused {node_type}", - node_id=node_id, - last_accessed=last_accessed - ) - - return unused_nodes - - -async def _delete_unused_nodes(unused_nodes: Dict[str, list], provider: str) -> Dict[str, int]: - """ - Delete unused nodes from graph and vector databases. 
- - Parameters - ---------- - unused_nodes : Dict[str, list] - Dictionary mapping node types to lists of node IDs to delete - provider : str - Graph database provider (kuzu, neo4j, neptune) - - Returns - ------- - Dict[str, int] - Count of deleted items by type - """ - graph_engine = await get_graph_engine() - vector_engine = get_vector_engine() - - deleted_counts = { - "DocumentChunk": 0, - "Entity": 0, - "TextSummary": 0, - "associations": 0 - } - - # Count associations before deletion - for node_type, node_ids in unused_nodes.items(): - if not node_ids: - continue - - # Count edges connected to these nodes - for node_id in node_ids: - if provider == "kuzu": - result = await graph_engine.query( - "MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)", - {"id": node_id} - ) - elif provider == "neo4j": - result = await graph_engine.query( - "MATCH (n:__Node__ {id: $id})-[r:EDGE]-() RETURN count(r)", - {"id": node_id} - ) - elif provider == "neptune": - result = await graph_engine.query( - "MATCH (n:Node {id: $id})-[r:EDGE]-() RETURN count(r)", - {"id": node_id} - ) - - if result and len(result) > 0: - count = result[0][0] if provider == "kuzu" else result[0]["count_count(r)"] - deleted_counts["associations"] += count - - # Delete from graph database (uses DETACH DELETE, so edges are automatically removed) - for node_type, node_ids in unused_nodes.items(): - if not node_ids: - continue - - logger.info(f"Deleting {len(node_ids)} {node_type} nodes from graph database") - - # Delete nodes in batches - await graph_engine.delete_nodes(node_ids) - deleted_counts[node_type] = len(node_ids) - - # Delete from vector database - vector_collections = { - "DocumentChunk": "DocumentChunk_text", - "Entity": "Entity_name", - "TextSummary": "TextSummary_text" - } - - - for node_type, collection_name in vector_collections.items(): - node_ids = unused_nodes[node_type] - if not node_ids: - continue - - logger.info(f"Deleting {len(node_ids)} {node_type} embeddings from vector database") 
- - try: - if await vector_engine.has_collection(collection_name): - await vector_engine.delete_data_points( - collection_name, - [str(node_id) for node_id in node_ids] - ) - except Exception as e: - logger.error(f"Error deleting from vector collection {collection_name}: {e}") - + logger.error(f"Error deleting from vector collection {collection_name}: {e}") + return deleted_counts From f9b16e508d3e99dd34ee9e5f3b3ca893303f6faa Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Tue, 2 Dec 2025 20:23:42 +0530 Subject: [PATCH 034/176] feat(database): add connect_args support to SqlAlchemyAdapter - Add optional connect_args parameter to __init__ method - Support DATABASE_CONNECT_ARGS environment variable for JSON-based configuration - Enable custom connection parameters for all database engines (SQLite and PostgreSQL) - Maintain backward compatibility with existing code - Add proper error handling and validation for environment variable parsing Signed-off-by: ketanjain7981 --- .env.template | 9 ++++++++ .../sqlalchemy/SqlAlchemyAdapter.py | 21 +++++++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/.env.template b/.env.template index 61853b983..b93117c90 100644 --- a/.env.template +++ b/.env.template @@ -91,6 +91,15 @@ DB_NAME=cognee_db #DB_USERNAME=cognee #DB_PASSWORD=cognee +# -- Advanced: Custom database connection arguments (optional) --------------- +# Pass additional connection parameters as JSON. Useful for SSL, timeouts, etc. 
+# Examples: +# For PostgreSQL with SSL: +# DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}' +# For SQLite with custom timeout: +# DATABASE_CONNECT_ARGS='{"timeout": 60}' +#DATABASE_CONNECT_ARGS='{}' + ################################################################################ # 🕸️ Graph Database settings ################################################################################ diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 380ce9917..a23c1b297 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -3,6 +3,7 @@ import asyncio from os import path import tempfile from uuid import UUID +import json from typing import Optional from typing import AsyncGenerator, List from contextlib import asynccontextmanager @@ -29,10 +30,25 @@ class SQLAlchemyAdapter: functions. 
""" - def __init__(self, connection_string: str): + def __init__(self, connection_string: str, connect_args: Optional[dict] = None): self.db_path: str = None self.db_uri: str = connection_string + env_connect_args = os.getenv("DATABASE_CONNECT_ARGS") + if env_connect_args: + try: + env_connect_args = json.loads(env_connect_args) + if isinstance(env_connect_args, dict): + if connect_args is None: + connect_args = {} + connect_args.update(env_connect_args) + else: + logger.warning( + f"DATABASE_CONNECT_ARGS is not a valid JSON dictionary: {env_connect_args}" + ) + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse DATABASE_CONNECT_ARGS as JSON: {e}") + if "sqlite" in connection_string: [prefix, db_path] = connection_string.split("///") self.db_path = db_path @@ -53,7 +69,7 @@ class SQLAlchemyAdapter: self.engine = create_async_engine( connection_string, poolclass=NullPool, - connect_args={"timeout": 30}, + connect_args={**(connect_args or {}), **{"timeout": 30}}, ) else: self.engine = create_async_engine( @@ -63,6 +79,7 @@ class SQLAlchemyAdapter: pool_recycle=280, pool_pre_ping=True, pool_timeout=280, + connect_args=connect_args or {}, ) self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) From c892265644beaf61b778c10241b4770e00119530 Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Tue, 2 Dec 2025 21:01:43 +0530 Subject: [PATCH 035/176] fix(database): address CodeRabbit review feedback - Add comprehensive docstring for __init__ method to meet 80% coverage requirement - Fix security issue: remove sensitive data from log messages - Fix merge precedence: programmatic args now correctly override env vars - Fix SQLite timeout order: user-specified timeout now overrides default 30s - Clarify precedence in docstring documentation Signed-off-by: ketanjain7981 --- .../sqlalchemy/SqlAlchemyAdapter.py | 38 +++++++++++++------ 1 file changed, 27 insertions(+), 11 deletions(-) diff --git 
a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index a23c1b297..483228bfb 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -31,23 +31,39 @@ class SQLAlchemyAdapter: """ def __init__(self, connection_string: str, connect_args: Optional[dict] = None): + """ + Initialize the SQLAlchemy adapter with connection settings. + + Parameters: + ----------- + connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db' + or 'postgresql://user:pass@host:port/db'). + connect_args (Optional[dict]): Optional dictionary of connection arguments to pass to + the database engine. These are driver-specific parameters such as SSL settings, + timeouts, or connection pool options. If DATABASE_CONNECT_ARGS environment variable + is set, those values will be merged with this parameter (programmatic values take + precedence over environment variables). Defaults to None. + + Environment Variables: + ---------------------- + DATABASE_CONNECT_ARGS: Optional JSON string containing connection arguments. 
+ Example: '{"sslmode": "require", "connect_timeout": 10}' + """ self.db_path: str = None self.db_uri: str = connection_string env_connect_args = os.getenv("DATABASE_CONNECT_ARGS") if env_connect_args: try: - env_connect_args = json.loads(env_connect_args) - if isinstance(env_connect_args, dict): - if connect_args is None: - connect_args = {} - connect_args.update(env_connect_args) + parsed_env_args = json.loads(env_connect_args) + if isinstance(parsed_env_args, dict): + # Merge: env vars as base, programmatic args override + merged_args = {**parsed_env_args, **(connect_args or {})} + connect_args = merged_args else: - logger.warning( - f"DATABASE_CONNECT_ARGS is not a valid JSON dictionary: {env_connect_args}" - ) - except json.JSONDecodeError as e: - logger.warning(f"Failed to parse DATABASE_CONNECT_ARGS as JSON: {e}") + logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring") + except json.JSONDecodeError: + logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring") if "sqlite" in connection_string: [prefix, db_path] = connection_string.split("///") @@ -69,7 +85,7 @@ class SQLAlchemyAdapter: self.engine = create_async_engine( connection_string, poolclass=NullPool, - connect_args={**(connect_args or {}), **{"timeout": 30}}, + connect_args={**{"timeout": 30}, **(connect_args or {})}, ) else: self.engine = create_async_engine( From 3f53534c992cff1b9893f7be08d5658f4dddd630 Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Tue, 2 Dec 2025 21:12:07 +0530 Subject: [PATCH 036/176] refactor(database): simplify to env var only for connect_args - Remove unused connect_args parameter from __init__ - Programmatic parameter was dead code (never called by users) - Users call get_relational_engine() which doesn't expose connect_args - Keep DATABASE_CONNECT_ARGS env var support (actually used in production) - Simplify implementation and reduce complexity - Update docstring to reflect env-var-only approach - Add production examples to 
docstring Signed-off-by: ketanjain7981 --- .../sqlalchemy/SqlAlchemyAdapter.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 483228bfb..3e800102a 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -30,7 +30,7 @@ class SQLAlchemyAdapter: functions. """ - def __init__(self, connection_string: str, connect_args: Optional[dict] = None): + def __init__(self, connection_string: str): """ Initialize the SQLAlchemy adapter with connection settings. @@ -38,32 +38,38 @@ class SQLAlchemyAdapter: ----------- connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db' or 'postgresql://user:pass@host:port/db'). - connect_args (Optional[dict]): Optional dictionary of connection arguments to pass to - the database engine. These are driver-specific parameters such as SSL settings, - timeouts, or connection pool options. If DATABASE_CONNECT_ARGS environment variable - is set, those values will be merged with this parameter (programmatic values take - precedence over environment variables). Defaults to None. Environment Variables: ---------------------- DATABASE_CONNECT_ARGS: Optional JSON string containing connection arguments. - Example: '{"sslmode": "require", "connect_timeout": 10}' + Allows configuration of driver-specific parameters such as SSL settings, + timeouts, or connection pool options without code changes. + + Examples: + PostgreSQL with SSL: + DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}' + + SQLite with custom timeout: + DATABASE_CONNECT_ARGS='{"timeout": 60}' + + Note: This follows cognee's environment-based configuration pattern and is + the recommended approach for production deployments. 
""" self.db_path: str = None self.db_uri: str = connection_string + # Parse optional connection arguments from environment variable + connect_args = None env_connect_args = os.getenv("DATABASE_CONNECT_ARGS") if env_connect_args: try: - parsed_env_args = json.loads(env_connect_args) - if isinstance(parsed_env_args, dict): - # Merge: env vars as base, programmatic args override - merged_args = {**parsed_env_args, **(connect_args or {})} - connect_args = merged_args - else: + connect_args = json.loads(env_connect_args) + if not isinstance(connect_args, dict): logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring") + connect_args = None except json.JSONDecodeError: logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring") + connect_args = None if "sqlite" in connection_string: [prefix, db_path] = connection_string.split("///") From 5f00abf3e4f3b913ae67391d487104ea3b9ae872 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Tue, 2 Dec 2025 22:25:03 +0530 Subject: [PATCH 037/176] fix: fallback and document deletion --- .../retrieval/utils/access_tracking.py | 73 +++++++++++-------- cognee/tasks/cleanup/cleanup_unused_data.py | 41 +++++++---- 2 files changed, 68 insertions(+), 46 deletions(-) diff --git a/cognee/modules/retrieval/utils/access_tracking.py b/cognee/modules/retrieval/utils/access_tracking.py index 935c47157..c7b06ee17 100644 --- a/cognee/modules/retrieval/utils/access_tracking.py +++ b/cognee/modules/retrieval/utils/access_tracking.py @@ -36,16 +36,22 @@ async def update_node_access_timestamps(items: List[Any]): return try: - # Update nodes using graph projection ( database-agnostic approach + # Try to update nodes in graph database (may fail for unsupported DBs) await _update_nodes_via_projection(graph_engine, node_ids, timestamp_ms) + except Exception as e: + logger.warning( + f"Failed to update node timestamps in graph database: {e}. " + "Will update document-level timestamps in SQL instead." 
+ ) - # Find origin documents and update SQL + # Always try to find origin documents and update SQL + # This ensures document-level tracking works even if graph updates fail + try: doc_ids = await _find_origin_documents_via_projection(graph_engine, node_ids) if doc_ids: await _update_sql_records(doc_ids, timestamp_dt) - except Exception as e: - logger.error(f"Failed to update timestamps: {e}") + logger.error(f"Failed to update SQL timestamps: {e}") raise async def _update_nodes_via_projection(graph_engine, node_ids, timestamp_ms): @@ -59,37 +65,42 @@ async def _update_nodes_via_projection(graph_engine, node_ids, timestamp_ms): ) # Update each node's last_accessed_at property + provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower() + for node_id in node_ids: node = memory_fragment.get_node(node_id) if node: - # Update the node in the database - provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower() - - if provider == "kuzu": - # Kuzu stores properties as JSON - result = await graph_engine.query( - "MATCH (n:Node {id: $id}) RETURN n.properties", - {"id": node_id} - ) - - if result and result[0]: - props = json.loads(result[0][0]) if result[0][0] else {} - props["last_accessed_at"] = timestamp_ms - - await graph_engine.query( - "MATCH (n:Node {id: $id}) SET n.properties = $props", - {"id": node_id, "props": json.dumps(props)} + try: + # Update the node in the database + if provider == "kuzu": + # Kuzu stores properties as JSON + result = await graph_engine.query( + "MATCH (n:Node {id: $id}) RETURN n.properties", + {"id": node_id} ) - elif provider == "neo4j": - await graph_engine.query( - "MATCH (n:__Node__ {id: $id}) SET n.last_accessed_at = $timestamp", - {"id": node_id, "timestamp": timestamp_ms} - ) - elif provider == "neptune": - await graph_engine.query( - "MATCH (n:Node {id: $id}) SET n.last_accessed_at = $timestamp", - {"id": node_id, "timestamp": timestamp_ms} - ) + + if result and result[0]: + props = json.loads(result[0][0]) if 
result[0][0] else {} + props["last_accessed_at"] = timestamp_ms + + await graph_engine.query( + "MATCH (n:Node {id: $id}) SET n.properties = $props", + {"id": node_id, "props": json.dumps(props)} + ) + elif provider == "neo4j": + await graph_engine.query( + "MATCH (n:__Node__ {id: $id}) SET n.last_accessed_at = $timestamp", + {"id": node_id, "timestamp": timestamp_ms} + ) + elif provider == "neptune": + await graph_engine.query( + "MATCH (n:Node {id: $id}) SET n.last_accessed_at = $timestamp", + {"id": node_id, "timestamp": timestamp_ms} + ) + except Exception as e: + # Log but continue with other nodes + logger.debug(f"Failed to update node {node_id}: {e}") + continue async def _find_origin_documents_via_projection(graph_engine, node_ids): """Find origin documents using graph projection instead of DB queries""" diff --git a/cognee/tasks/cleanup/cleanup_unused_data.py b/cognee/tasks/cleanup/cleanup_unused_data.py index c70b97a00..3894635dd 100644 --- a/cognee/tasks/cleanup/cleanup_unused_data.py +++ b/cognee/tasks/cleanup/cleanup_unused_data.py @@ -1,9 +1,9 @@ """ Task for automatically deleting unused data from the memify pipeline. -This task identifies and removes data (chunks, entities, summaries)) that hasn't +This task identifies and removes entire documents that haven't been accessed by retrievers for a specified period, helping maintain system -efficiency and storage optimization. +efficiency and storage optimization through whole-document removal. """ import json @@ -28,22 +28,26 @@ async def cleanup_unused_data( minutes_threshold: Optional[int], dry_run: bool = True, user_id: Optional[UUID] = None, - text_doc: bool = False + text_doc: bool = True, # Changed default to True for document-level cleanup + node_level: bool = False # New parameter for explicit node-level cleanup ) -> Dict[str, Any]: """ Identify and remove unused data from the memify pipeline. 
- + Parameters ---------- minutes_threshold : int - days since last access to consider data unused + Minutes since last access to consider data unused dry_run : bool - If True, only report what would be delete without actually deleting (default: True) + If True, only report what would be deleted without actually deleting (default: True) user_id : UUID, optional Limit cleanup to specific user's data (default: None) text_doc : bool - If True, use SQL-based filtering to find unused TextDocuments and call cognee.delete() - for proper whole-document deletion (default: False) + If True (default), use SQL-based filtering to find unused TextDocuments and call cognee.delete() + for proper whole-document deletion + node_level : bool + If True, perform chaotic node-level deletion of unused chunks, entities, and summaries + (default: False - deprecated in favor of document-level cleanup) Returns ------- @@ -91,17 +95,19 @@ async def cleanup_unused_data( minutes_threshold=minutes_threshold, dry_run=dry_run, user_id=str(user_id) if user_id else None, - text_doc=text_doc + text_doc=text_doc, + node_level=node_level ) # Calculate cutoff timestamp cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold) - if text_doc: - # SQL-based approach: Find unused TextDocuments and use cognee.delete() - return await _cleanup_via_sql(cutoff_date, dry_run, user_id) - else: - # Graph-based approach: Find unused nodes using projection (database-agnostic) + if node_level: + # Deprecated: Node-level approach (chaotic) + logger.warning( + "Node-level cleanup is deprecated and may lead to fragmented knowledge graphs. " + "Consider using document-level cleanup (default) instead." 
+ ) cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000) logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)") @@ -147,6 +153,9 @@ async def cleanup_unused_data( }, "cleanup_date": datetime.now(timezone.utc).isoformat() } + else: + # Default: Document-level approach (recommended) + return await _cleanup_via_sql(cutoff_date, dry_run, user_id) async def _cleanup_via_sql( @@ -243,6 +252,7 @@ async def _cleanup_via_sql( async def _find_unused_nodes_via_projection(cutoff_timestamp_ms: int) -> Dict[str, list]: """ Find unused nodes using graph projection - database-agnostic approach. + NOTE: This function is deprecated as it leads to fragmented knowledge graphs. Parameters ---------- @@ -291,6 +301,7 @@ async def _find_unused_nodes_via_projection(cutoff_timestamp_ms: int) -> Dict[st async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]: """ Delete unused nodes from graph and vector databases. + NOTE: This function is deprecated as it leads to fragmented knowledge graphs. 
Parameters ---------- @@ -325,7 +336,7 @@ async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]: if not node_ids: continue - # Count edges connected to these nodes + # Count edges from the in-memory graph for node_id in node_ids: node = memory_fragment.get_node(node_id) if node: From 4f3a1bcf012c5c380823ef7786de48a757505c37 Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Tue, 2 Dec 2025 23:25:47 +0530 Subject: [PATCH 038/176] test: add unit tests for SQLAlchemyAdapter connection arguments Signed-off-by: ketanjain7981 --- .../sqlalchemy/test_SqlAlchemyAdapter.py | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py diff --git a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py new file mode 100644 index 000000000..bde5b9855 --- /dev/null +++ b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py @@ -0,0 +1,84 @@ +from unittest.mock import patch +from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import ( + SQLAlchemyAdapter, +) + + +class TestSqlAlchemyAdapter: + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + @patch("os.getenv") + def test_sqlite_default_timeout(self, mock_getenv, mock_create_engine): + """Test that SQLite connection uses default timeout when no env var is set.""" + mock_getenv.return_value = None + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + @patch("os.getenv") + def 
test_sqlite_with_env_var_timeout(self, mock_getenv, mock_create_engine): + """Test that SQLite connection uses timeout from env var.""" + mock_getenv.return_value = '{"timeout": 60}' + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 60} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + @patch("os.getenv") + def test_sqlite_with_other_env_var_args(self, mock_getenv, mock_create_engine): + """Test that SQLite connection merges default timeout with other args from env var.""" + mock_getenv.return_value = '{"foo": "bar"}' + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + @patch("os.getenv") + def test_sqlite_with_invalid_json_env_var(self, mock_getenv, mock_create_engine): + """Test that SQLite connection uses default timeout when env var has invalid JSON.""" + mock_getenv.return_value = '{"timeout": 60' # Invalid JSON + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + @patch("os.getenv") + def test_postgresql_with_env_var(self, mock_getenv, mock_create_engine): + """Test that PostgreSQL connection uses connect_args from env var.""" + mock_getenv.return_value = '{"sslmode": "require"}' + SQLAlchemyAdapter("postgresql://user:pass@host/db") + mock_create_engine.assert_called_once() + _, kwargs = 
mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"sslmode": "require"} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + @patch("os.getenv") + def test_postgresql_without_env_var(self, mock_getenv, mock_create_engine): + """Test that PostgreSQL connection has empty connect_args when no env var is set.""" + mock_getenv.return_value = None + SQLAlchemyAdapter("postgresql://user:pass@host/db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {} From a7da9c7d655e705beceaf376a6cc6f3587be41d7 Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Tue, 2 Dec 2025 23:35:35 +0530 Subject: [PATCH 039/176] test: verify logger warning for invalid JSON in SQLAlchemyAdapter Signed-off-by: ketanjain7981 --- .../relational/sqlalchemy/test_SqlAlchemyAdapter.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py index bde5b9855..abff77660 100644 --- a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py +++ b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py @@ -47,11 +47,17 @@ class TestSqlAlchemyAdapter: @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) + @patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger") @patch("os.getenv") - def test_sqlite_with_invalid_json_env_var(self, mock_getenv, mock_create_engine): + def test_sqlite_with_invalid_json_env_var(self, mock_getenv, mock_logger, mock_create_engine): """Test that SQLite connection uses default timeout when env var has invalid JSON.""" 
mock_getenv.return_value = '{"timeout": 60' # Invalid JSON SQLAlchemyAdapter("sqlite:///test.db") + + mock_logger.warning.assert_called_with( + "Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring" + ) + mock_create_engine.assert_called_once() _, kwargs = mock_create_engine.call_args assert "connect_args" in kwargs From f26b490a8f5b3f0bb400396a4fb71771ee325ab0 Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Wed, 3 Dec 2025 00:26:44 +0530 Subject: [PATCH 040/176] refactor: improve test isolation and add connect_args precedence Signed-off-by: ketanjain7981 --- .../sqlalchemy/SqlAlchemyAdapter.py | 21 +-- .../sqlalchemy/test_SqlAlchemyAdapter.py | 132 +++++++++++------- 2 files changed, 94 insertions(+), 59 deletions(-) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 3e800102a..c7825041e 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -30,7 +30,7 @@ class SQLAlchemyAdapter: functions. """ - def __init__(self, connection_string: str): + def __init__(self, connection_string: str, connect_args: dict = None): """ Initialize the SQLAlchemy adapter with connection settings. @@ -38,6 +38,8 @@ class SQLAlchemyAdapter: ----------- connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db' or 'postgresql://user:pass@host:port/db'). + connect_args (dict, optional): Database driver arguments. These take precedence over + DATABASE_CONNECT_ARGS environment variable. 
Environment Variables: ---------------------- @@ -59,17 +61,20 @@ class SQLAlchemyAdapter: self.db_uri: str = connection_string # Parse optional connection arguments from environment variable - connect_args = None + env_connect_args_dict = {} env_connect_args = os.getenv("DATABASE_CONNECT_ARGS") if env_connect_args: try: - connect_args = json.loads(env_connect_args) - if not isinstance(connect_args, dict): + parsed_args = json.loads(env_connect_args) + if isinstance(parsed_args, dict): + env_connect_args_dict = parsed_args + else: logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring") - connect_args = None except json.JSONDecodeError: logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring") - connect_args = None + + # Merge environment args with explicit args (explicit args take precedence) + final_connect_args = {**env_connect_args_dict, **(connect_args or {})} if "sqlite" in connection_string: [prefix, db_path] = connection_string.split("///") @@ -91,7 +96,7 @@ class SQLAlchemyAdapter: self.engine = create_async_engine( connection_string, poolclass=NullPool, - connect_args={**{"timeout": 30}, **(connect_args or {})}, + connect_args={**{"timeout": 30}, **final_connect_args}, ) else: self.engine = create_async_engine( @@ -101,7 +106,7 @@ class SQLAlchemyAdapter: pool_recycle=280, pool_pre_ping=True, pool_timeout=280, - connect_args=connect_args or {}, + connect_args=final_connect_args, ) self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) diff --git a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py index abff77660..57c80ddcf 100644 --- a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py +++ b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py @@ -1,3 +1,4 @@ +import os from 
unittest.mock import patch from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import ( SQLAlchemyAdapter, @@ -5,86 +6,115 @@ from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter imp class TestSqlAlchemyAdapter: + """Test suite for SqlAlchemyAdapter environment variable handling and connection arguments.""" + @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_sqlite_default_timeout(self, mock_getenv, mock_create_engine): + def test_sqlite_default_timeout(self, mock_create_engine): """Test that SQLite connection uses default timeout when no env var is set.""" - mock_getenv.return_value = None - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30} + with patch.dict(os.environ, {}, clear=True): + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_sqlite_with_env_var_timeout(self, mock_getenv, mock_create_engine): + def test_sqlite_with_env_var_timeout(self, mock_create_engine): """Test that SQLite connection uses timeout from env var.""" - mock_getenv.return_value = '{"timeout": 60}' - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 60} + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + 
assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 60} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_sqlite_with_other_env_var_args(self, mock_getenv, mock_create_engine): + def test_sqlite_with_other_env_var_args(self, mock_create_engine): """Test that SQLite connection merges default timeout with other args from env var.""" - mock_getenv.return_value = '{"foo": "bar"}' - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"} + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"foo": "bar"}'}): + SQLAlchemyAdapter("sqlite:///test.db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) @patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger") - @patch("os.getenv") - def test_sqlite_with_invalid_json_env_var(self, mock_getenv, mock_logger, mock_create_engine): + def test_sqlite_with_invalid_json_env_var(self, mock_logger, mock_create_engine): """Test that SQLite connection uses default timeout when env var has invalid JSON.""" - mock_getenv.return_value = '{"timeout": 60' # Invalid JSON - SQLAlchemyAdapter("sqlite:///test.db") + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON + SQLAlchemyAdapter("sqlite:///test.db") - mock_logger.warning.assert_called_with( - "Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring" - ) + mock_logger.warning.assert_called_with( + "Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring" + ) - mock_create_engine.assert_called_once() - _, 
kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30} + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_postgresql_with_env_var(self, mock_getenv, mock_create_engine): + @patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger") + def test_sqlite_with_non_dict_json_env_var(self, mock_logger, mock_create_engine): + """Test that SQLite connection uses default timeout when env var is valid JSON but not a dict.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}): + SQLAlchemyAdapter("sqlite:///test.db") + + mock_logger.warning.assert_called_with( + "DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring" + ) + + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {"timeout": 30} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + def test_postgresql_with_env_var(self, mock_create_engine): """Test that PostgreSQL connection uses connect_args from env var.""" - mock_getenv.return_value = '{"sslmode": "require"}' - SQLAlchemyAdapter("postgresql://user:pass@host/db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"sslmode": "require"} + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"sslmode": "require"}'}): + SQLAlchemyAdapter("postgresql://user:pass@host/db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert 
kwargs["connect_args"] == {"sslmode": "require"} @patch( "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" ) - @patch("os.getenv") - def test_postgresql_without_env_var(self, mock_getenv, mock_create_engine): + def test_postgresql_without_env_var(self, mock_create_engine): """Test that PostgreSQL connection has empty connect_args when no env var is set.""" - mock_getenv.return_value = None - SQLAlchemyAdapter("postgresql://user:pass@host/db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {} + with patch.dict(os.environ, {}, clear=True): + SQLAlchemyAdapter("postgresql://user:pass@host/db") + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + assert kwargs["connect_args"] == {} + + @patch( + "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" + ) + def test_connect_args_precedence(self, mock_create_engine): + """Test that explicit connect_args take precedence over env var args.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): + # Pass explicit connect_args that should override env var + SQLAlchemyAdapter("sqlite:///test.db", connect_args={"timeout": 120}) + + mock_create_engine.assert_called_once() + _, kwargs = mock_create_engine.call_args + assert "connect_args" in kwargs + # timeout should be 120 (explicit), not 60 (env var) or 30 (default) + assert kwargs["connect_args"] == {"timeout": 120} From e1d313a46b962109644facdb51a34a3023299187 Mon Sep 17 00:00:00 2001 From: ketanjain7981 Date: Tue, 9 Dec 2025 10:14:46 +0530 Subject: [PATCH 041/176] move DATABASE_CONNECT_ARGS parsing to RelationalConfig Signed-off-by: ketanjain7981 --- .../databases/relational/config.py | 17 ++- .../relational/create_relational_engine.py | 4 +- .../sqlalchemy/SqlAlchemyAdapter.py | 32 +---- 
.../sqlalchemy/test_SqlAlchemyAdapter.py | 120 ------------------ .../relational/test_RelationalConfig.py | 69 ++++++++++ 5 files changed, 93 insertions(+), 149 deletions(-) delete mode 100644 cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py create mode 100644 cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py diff --git a/cognee/infrastructure/databases/relational/config.py b/cognee/infrastructure/databases/relational/config.py index ff7c410a1..fae7ca329 100644 --- a/cognee/infrastructure/databases/relational/config.py +++ b/cognee/infrastructure/databases/relational/config.py @@ -1,4 +1,5 @@ import os +import json import pydantic from typing import Union from functools import lru_cache @@ -19,6 +20,7 @@ class RelationalConfig(BaseSettings): db_username: Union[str, None] = None # "cognee" db_password: Union[str, None] = None # "cognee" db_provider: str = "sqlite" + database_connect_args: Union[str, None] = None model_config = SettingsConfigDict(env_file=".env", extra="allow") @@ -30,6 +32,17 @@ class RelationalConfig(BaseSettings): databases_directory_path = os.path.join(base_config.system_root_directory, "databases") self.db_path = databases_directory_path + # Parse database_connect_args if provided as JSON string + if self.database_connect_args and isinstance(self.database_connect_args, str): + try: + parsed_args = json.loads(self.database_connect_args) + if isinstance(parsed_args, dict): + self.database_connect_args = parsed_args + else: + self.database_connect_args = {} + except json.JSONDecodeError: + self.database_connect_args = {} + return self def to_dict(self) -> dict: @@ -40,7 +53,8 @@ class RelationalConfig(BaseSettings): -------- - dict: A dictionary containing database configuration settings including db_path, - db_name, db_host, db_port, db_username, db_password, and db_provider. + db_name, db_host, db_port, db_username, db_password, db_provider, and + database_connect_args. 
""" return { "db_path": self.db_path, @@ -50,6 +64,7 @@ class RelationalConfig(BaseSettings): "db_username": self.db_username, "db_password": self.db_password, "db_provider": self.db_provider, + "database_connect_args": self.database_connect_args, } diff --git a/cognee/infrastructure/databases/relational/create_relational_engine.py b/cognee/infrastructure/databases/relational/create_relational_engine.py index deaeaa2da..b2a79c818 100644 --- a/cognee/infrastructure/databases/relational/create_relational_engine.py +++ b/cognee/infrastructure/databases/relational/create_relational_engine.py @@ -11,6 +11,7 @@ def create_relational_engine( db_username: str, db_password: str, db_provider: str, + database_connect_args: dict = None, ): """ Create a relational database engine based on the specified parameters. @@ -29,6 +30,7 @@ def create_relational_engine( - db_password (str): The password for database authentication, required for PostgreSQL. - db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres'). + - database_connect_args (dict, optional): Database driver connection arguments. Returns: -------- @@ -51,4 +53,4 @@ def create_relational_engine( "PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality." 
) - return SQLAlchemyAdapter(connection_string) + return SQLAlchemyAdapter(connection_string, connect_args=database_connect_args) diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index c7825041e..37ceb170d 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -3,7 +3,6 @@ import asyncio from os import path import tempfile from uuid import UUID -import json from typing import Optional from typing import AsyncGenerator, List from contextlib import asynccontextmanager @@ -38,14 +37,9 @@ class SQLAlchemyAdapter: ----------- connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db' or 'postgresql://user:pass@host:port/db'). - connect_args (dict, optional): Database driver arguments. These take precedence over - DATABASE_CONNECT_ARGS environment variable. - - Environment Variables: - ---------------------- - DATABASE_CONNECT_ARGS: Optional JSON string containing connection arguments. - Allows configuration of driver-specific parameters such as SSL settings, - timeouts, or connection pool options without code changes. + connect_args (dict, optional): Database driver connection arguments. + Configuration is loaded from RelationalConfig.database_connect_args, which reads + from the DATABASE_CONNECT_ARGS environment variable. Examples: PostgreSQL with SSL: @@ -53,28 +47,12 @@ class SQLAlchemyAdapter: SQLite with custom timeout: DATABASE_CONNECT_ARGS='{"timeout": 60}' - - Note: This follows cognee's environment-based configuration pattern and is - the recommended approach for production deployments. 
""" self.db_path: str = None self.db_uri: str = connection_string - # Parse optional connection arguments from environment variable - env_connect_args_dict = {} - env_connect_args = os.getenv("DATABASE_CONNECT_ARGS") - if env_connect_args: - try: - parsed_args = json.loads(env_connect_args) - if isinstance(parsed_args, dict): - env_connect_args_dict = parsed_args - else: - logger.warning("DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring") - except json.JSONDecodeError: - logger.warning("Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring") - - # Merge environment args with explicit args (explicit args take precedence) - final_connect_args = {**env_connect_args_dict, **(connect_args or {})} + # Use provided connect_args (already parsed from config) + final_connect_args = connect_args or {} if "sqlite" in connection_string: [prefix, db_path] = connection_string.split("///") diff --git a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py b/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py deleted file mode 100644 index 57c80ddcf..000000000 --- a/cognee/tests/unit/infrastructure/databases/relational/sqlalchemy/test_SqlAlchemyAdapter.py +++ /dev/null @@ -1,120 +0,0 @@ -import os -from unittest.mock import patch -from cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter import ( - SQLAlchemyAdapter, -) - - -class TestSqlAlchemyAdapter: - """Test suite for SqlAlchemyAdapter environment variable handling and connection arguments.""" - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - def test_sqlite_default_timeout(self, mock_create_engine): - """Test that SQLite connection uses default timeout when no env var is set.""" - with patch.dict(os.environ, {}, clear=True): - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - 
assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30} - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - def test_sqlite_with_env_var_timeout(self, mock_create_engine): - """Test that SQLite connection uses timeout from env var.""" - with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 60} - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - def test_sqlite_with_other_env_var_args(self, mock_create_engine): - """Test that SQLite connection merges default timeout with other args from env var.""" - with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"foo": "bar"}'}): - SQLAlchemyAdapter("sqlite:///test.db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30, "foo": "bar"} - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - @patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger") - def test_sqlite_with_invalid_json_env_var(self, mock_logger, mock_create_engine): - """Test that SQLite connection uses default timeout when env var has invalid JSON.""" - with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON - SQLAlchemyAdapter("sqlite:///test.db") - - mock_logger.warning.assert_called_with( - "Failed to parse DATABASE_CONNECT_ARGS as JSON, ignoring" - ) - - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30} - - @patch( - 
"cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - @patch("cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.logger") - def test_sqlite_with_non_dict_json_env_var(self, mock_logger, mock_create_engine): - """Test that SQLite connection uses default timeout when env var is valid JSON but not a dict.""" - with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}): - SQLAlchemyAdapter("sqlite:///test.db") - - mock_logger.warning.assert_called_with( - "DATABASE_CONNECT_ARGS is not a valid JSON dictionary, ignoring" - ) - - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"timeout": 30} - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - def test_postgresql_with_env_var(self, mock_create_engine): - """Test that PostgreSQL connection uses connect_args from env var.""" - with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"sslmode": "require"}'}): - SQLAlchemyAdapter("postgresql://user:pass@host/db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {"sslmode": "require"} - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - def test_postgresql_without_env_var(self, mock_create_engine): - """Test that PostgreSQL connection has empty connect_args when no env var is set.""" - with patch.dict(os.environ, {}, clear=True): - SQLAlchemyAdapter("postgresql://user:pass@host/db") - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - assert kwargs["connect_args"] == {} - - @patch( - "cognee.infrastructure.databases.relational.sqlalchemy.SqlAlchemyAdapter.create_async_engine" - ) - def 
test_connect_args_precedence(self, mock_create_engine): - """Test that explicit connect_args take precedence over env var args.""" - with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): - # Pass explicit connect_args that should override env var - SQLAlchemyAdapter("sqlite:///test.db", connect_args={"timeout": 120}) - - mock_create_engine.assert_called_once() - _, kwargs = mock_create_engine.call_args - assert "connect_args" in kwargs - # timeout should be 120 (explicit), not 60 (env var) or 30 (default) - assert kwargs["connect_args"] == {"timeout": 120} diff --git a/cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py b/cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py new file mode 100644 index 000000000..8bbfc2450 --- /dev/null +++ b/cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py @@ -0,0 +1,69 @@ +import os +from unittest.mock import patch +from cognee.infrastructure.databases.relational.config import RelationalConfig + + +class TestRelationalConfig: + """Test suite for RelationalConfig DATABASE_CONNECT_ARGS parsing.""" + + def test_database_connect_args_valid_json_dict(self): + """Test that DATABASE_CONNECT_ARGS is parsed correctly when it's a valid JSON dict.""" + with patch.dict( + os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require"}'} + ): + config = RelationalConfig() + assert config.database_connect_args == {"timeout": 60, "sslmode": "require"} + + def test_database_connect_args_empty_string(self): + """Test that empty DATABASE_CONNECT_ARGS is handled correctly.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": ""}): + config = RelationalConfig() + assert config.database_connect_args == "" + + def test_database_connect_args_not_set(self): + """Test that missing DATABASE_CONNECT_ARGS results in None.""" + with patch.dict(os.environ, {}, clear=True): + config = RelationalConfig() + assert config.database_connect_args 
is None + + def test_database_connect_args_invalid_json(self): + """Test that invalid JSON in DATABASE_CONNECT_ARGS results in empty dict.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60'}): # Invalid JSON + config = RelationalConfig() + assert config.database_connect_args == {} + + def test_database_connect_args_non_dict_json(self): + """Test that non-dict JSON in DATABASE_CONNECT_ARGS results in empty dict.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '["list", "instead", "of", "dict"]'}): + config = RelationalConfig() + assert config.database_connect_args == {} + + def test_database_connect_args_to_dict(self): + """Test that database_connect_args is included in to_dict() output.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"timeout": 60}'}): + config = RelationalConfig() + config_dict = config.to_dict() + assert "database_connect_args" in config_dict + assert config_dict["database_connect_args"] == {"timeout": 60} + + def test_database_connect_args_integer_value(self): + """Test that DATABASE_CONNECT_ARGS with integer values is parsed correctly.""" + with patch.dict(os.environ, {"DATABASE_CONNECT_ARGS": '{"connect_timeout": 10}'}): + config = RelationalConfig() + assert config.database_connect_args == {"connect_timeout": 10} + + def test_database_connect_args_mixed_types(self): + """Test that DATABASE_CONNECT_ARGS with mixed value types is parsed correctly.""" + with patch.dict( + os.environ, + { + "DATABASE_CONNECT_ARGS": '{"timeout": 60, "sslmode": "require", "retries": 3, "keepalive": true}' + }, + ): + config = RelationalConfig() + assert config.database_connect_args == { + "timeout": 60, + "sslmode": "require", + "retries": 3, + "keepalive": True, + } From 28faf7ce04cecba0a522bdf3db1b43e907535523 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 9 Dec 2025 17:53:18 +0100 Subject: [PATCH 042/176] test: Add permission example test with running s3 file system --- .github/workflows/examples_tests.yml | 
60 +++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/.github/workflows/examples_tests.yml b/.github/workflows/examples_tests.yml index f7cc278cb..ab0138d62 100644 --- a/.github/workflows/examples_tests.yml +++ b/.github/workflows/examples_tests.yml @@ -256,7 +256,7 @@ jobs: with: python-version: '3.11.x' - - name: Run Memify Tests + - name: Run Permissions Example env: ENV: 'dev' OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -269,6 +269,64 @@ jobs: EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./examples/python/permissions_example.py + + test-s3-permissions-example: # Make sure permission and multi-user mode work with S3 file system + name: Run Permissions Example + runs-on: ubuntu-22.04 + defaults: + run: + shell: bash + services: + postgres: # Using postgres to avoid storing and using SQLite from S3 + image: pgvector/pgvector:pg17 + env: + POSTGRES_USER: cognee + POSTGRES_PASSWORD: cognee + POSTGRES_DB: cognee_db + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + extra-dependencies: "aws" + + - name: Run S3 Permissions Example + env: + ENV: 'dev' + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + STORAGE_BACKEND: 's3' + AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com + AWS_REGION: eu-west-1 + 
DATA_ROOT_DIRECTORY: "s3://cognee-temp/cognee/data" + SYSTEM_ROOT_DIRECTORY: "s3://cognee-temp/cognee/system" + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + DB_PROVIDER: 'postgres' + DB_NAME: 'cognee_db' + DB_HOST: '127.0.0.1' + DB_PORT: 5432 + DB_USERNAME: cognee + DB_PASSWORD: cognee + run: uv run python ./examples/python/permissions_example.py + test_docling_add: name: Run Add with Docling Test runs-on: macos-15 From 032a74a409b6f6245a369680353c5d7ba4472035 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 9 Dec 2025 17:56:34 +0100 Subject: [PATCH 043/176] chore: add postgres dependency for cicd test --- .github/workflows/examples_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/examples_tests.yml b/.github/workflows/examples_tests.yml index ab0138d62..14b065356 100644 --- a/.github/workflows/examples_tests.yml +++ b/.github/workflows/examples_tests.yml @@ -298,7 +298,7 @@ jobs: uses: ./.github/actions/cognee_setup with: python-version: '3.11.x' - extra-dependencies: "aws" + extra-dependencies: "postgres aws" - name: Run S3 Permissions Example env: From d57d1884599257a1305395212935fc352c1caf84 Mon Sep 17 00:00:00 2001 From: rajeevrajeshuni Date: Wed, 10 Dec 2025 10:52:10 +0530 Subject: [PATCH 044/176] resolving merge conflicts --- .../litellm_instructor/llm/ollama/adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py index ec7addcaf..abcd21f86 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py @@ -69,7 +69,7 @@ class OllamaAPIAdapter(LLMInterface): @retry( 
stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), + wait=wait_exponential_jitter(8, 128), retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, @@ -117,7 +117,7 @@ class OllamaAPIAdapter(LLMInterface): @retry( stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), + wait=wait_exponential_jitter(8, 128), retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, From ab20443330bd047415698f4a2d81e3427dbc272c Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Wed, 10 Dec 2025 12:35:58 +0100 Subject: [PATCH 045/176] chore: Change s3 bucket for permission example --- .github/workflows/examples_tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/examples_tests.yml b/.github/workflows/examples_tests.yml index 14b065356..3fd48523c 100644 --- a/.github/workflows/examples_tests.yml +++ b/.github/workflows/examples_tests.yml @@ -315,8 +315,8 @@ jobs: STORAGE_BACKEND: 's3' AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com AWS_REGION: eu-west-1 - DATA_ROOT_DIRECTORY: "s3://cognee-temp/cognee/data" - SYSTEM_ROOT_DIRECTORY: "s3://cognee-temp/cognee/system" + DATA_ROOT_DIRECTORY: "s3://github-runner-cognee-tests/cognee/data" + SYSTEM_ROOT_DIRECTORY: "s3://github-runner-cognee-tests/cognee/system" AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} DB_PROVIDER: 'postgres' From 4d0f1328225f024059f52aaa00957743fc82aa26 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Wed, 10 Dec 2025 16:22:40 +0100 Subject: [PATCH 046/176] chore: Remove AWS url --- .github/workflows/examples_tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/examples_tests.yml b/.github/workflows/examples_tests.yml index 3fd48523c..1a3d868c4 100644 --- a/.github/workflows/examples_tests.yml +++ 
b/.github/workflows/examples_tests.yml @@ -313,7 +313,6 @@ jobs: EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} STORAGE_BACKEND: 's3' - AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com AWS_REGION: eu-west-1 DATA_ROOT_DIRECTORY: "s3://github-runner-cognee-tests/cognee/data" SYSTEM_ROOT_DIRECTORY: "s3://github-runner-cognee-tests/cognee/system" From 829a6f0d04bcfec6e9c9f94219a29d6ab5cd909d Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Wed, 10 Dec 2025 22:41:01 +0530 Subject: [PATCH 047/176] fix: only document level deletion --- .../retrieval/utils/access_tracking.py | 80 +-- cognee/tasks/cleanup/cleanup_unused_data.py | 521 ++++++------------ cognee/tests/test_cleanup_unused_data.py | 388 ++++++------- 3 files changed, 333 insertions(+), 656 deletions(-) diff --git a/cognee/modules/retrieval/utils/access_tracking.py b/cognee/modules/retrieval/utils/access_tracking.py index c7b06ee17..54fd043b9 100644 --- a/cognee/modules/retrieval/utils/access_tracking.py +++ b/cognee/modules/retrieval/utils/access_tracking.py @@ -4,7 +4,7 @@ import json from datetime import datetime, timezone from typing import List, Any from uuid import UUID -import os +import os from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.relational import get_relational_engine from cognee.modules.data.models import Data @@ -14,38 +14,28 @@ from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph logger = get_logger(__name__) + async def update_node_access_timestamps(items: List[Any]): if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true": return - + if not items: return - + graph_engine = await get_graph_engine() - timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000) timestamp_dt = datetime.now(timezone.utc) - + # Extract node IDs node_ids = [] for item in items: item_id = item.payload.get("id") if hasattr(item, 'payload') else item.get("id") if 
item_id: node_ids.append(str(item_id)) - + if not node_ids: return - - try: - # Try to update nodes in graph database (may fail for unsupported DBs) - await _update_nodes_via_projection(graph_engine, node_ids, timestamp_ms) - except Exception as e: - logger.warning( - f"Failed to update node timestamps in graph database: {e}. " - "Will update document-level timestamps in SQL instead." - ) - - # Always try to find origin documents and update SQL - # This ensures document-level tracking works even if graph updates fail + + # Focus on document-level tracking via projection try: doc_ids = await _find_origin_documents_via_projection(graph_engine, node_ids) if doc_ids: @@ -54,53 +44,6 @@ async def update_node_access_timestamps(items: List[Any]): logger.error(f"Failed to update SQL timestamps: {e}") raise -async def _update_nodes_via_projection(graph_engine, node_ids, timestamp_ms): - """Update nodes using graph projection - works with any graph database""" - # Project the graph with necessary properties - memory_fragment = CogneeGraph() - await memory_fragment.project_graph_from_db( - graph_engine, - node_properties_to_project=["id"], - edge_properties_to_project=[] - ) - - # Update each node's last_accessed_at property - provider = os.getenv("GRAPH_DATABASE_PROVIDER", "kuzu").lower() - - for node_id in node_ids: - node = memory_fragment.get_node(node_id) - if node: - try: - # Update the node in the database - if provider == "kuzu": - # Kuzu stores properties as JSON - result = await graph_engine.query( - "MATCH (n:Node {id: $id}) RETURN n.properties", - {"id": node_id} - ) - - if result and result[0]: - props = json.loads(result[0][0]) if result[0][0] else {} - props["last_accessed_at"] = timestamp_ms - - await graph_engine.query( - "MATCH (n:Node {id: $id}) SET n.properties = $props", - {"id": node_id, "props": json.dumps(props)} - ) - elif provider == "neo4j": - await graph_engine.query( - "MATCH (n:__Node__ {id: $id}) SET n.last_accessed_at = $timestamp", - {"id": 
node_id, "timestamp": timestamp_ms} - ) - elif provider == "neptune": - await graph_engine.query( - "MATCH (n:Node {id: $id}) SET n.last_accessed_at = $timestamp", - {"id": node_id, "timestamp": timestamp_ms} - ) - except Exception as e: - # Log but continue with other nodes - logger.debug(f"Failed to update node {node_id}: {e}") - continue async def _find_origin_documents_via_projection(graph_engine, node_ids): """Find origin documents using graph projection instead of DB queries""" @@ -111,7 +54,7 @@ async def _find_origin_documents_via_projection(graph_engine, node_ids): node_properties_to_project=["id", "type"], edge_properties_to_project=["relationship_name"] ) - + # Find origin documents by traversing the in-memory graph doc_ids = set() for node_id in node_ids: @@ -123,9 +66,10 @@ async def _find_origin_documents_via_projection(graph_engine, node_ids): neighbor = edge.get_destination_node() if edge.get_source_node().id == node_id else edge.get_source_node() if neighbor and neighbor.get_attribute("type") in ["TextDocument", "Document"]: doc_ids.add(neighbor.id) - + return list(doc_ids) + async def _update_sql_records(doc_ids, timestamp_dt): """Update SQL Data table (same for all providers)""" db_engine = get_relational_engine() @@ -133,6 +77,6 @@ async def _update_sql_records(doc_ids, timestamp_dt): stmt = update(Data).where( Data.id.in_([UUID(doc_id) for doc_id in doc_ids]) ).values(last_accessed=timestamp_dt) - + await session.execute(stmt) await session.commit() diff --git a/cognee/tasks/cleanup/cleanup_unused_data.py b/cognee/tasks/cleanup/cleanup_unused_data.py index 3894635dd..34cde1b6f 100644 --- a/cognee/tasks/cleanup/cleanup_unused_data.py +++ b/cognee/tasks/cleanup/cleanup_unused_data.py @@ -1,382 +1,187 @@ -""" -Task for automatically deleting unused data from the memify pipeline. +""" +Task for automatically deleting unused data from the memify pipeline. 
+ +This task identifies and removes entire documents that haven't +been accessed by retrievers for a specified period, helping maintain system +efficiency and storage optimization through whole-document removal. +""" + +import json +from datetime import datetime, timezone, timedelta +from typing import Optional, Dict, Any +from uuid import UUID +import os +from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.data.models import Data, DatasetData +from cognee.shared.logging_utils import get_logger +from sqlalchemy import select, or_ +import cognee +import sqlalchemy as sa +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph + +logger = get_logger(__name__) + + +async def cleanup_unused_data( + minutes_threshold: Optional[int], + dry_run: bool = True, + user_id: Optional[UUID] = None +) -> Dict[str, Any]: + """ + Identify and remove unused data from the memify pipeline. + + Parameters + ---------- + minutes_threshold : int + Minutes since last access to consider data unused + dry_run : bool + If True, only report what would be deleted without actually deleting (default: True) + user_id : UUID, optional + Limit cleanup to specific user's data (default: None) + + Returns + ------- + Dict[str, Any] + Cleanup results with status, counts, and timestamp + """ + # Check 1: Environment variable must be enabled + if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true": + logger.warning( + "Cleanup skipped: ENABLE_LAST_ACCESSED is not enabled." 
+ ) + return { + "status": "skipped", + "reason": "ENABLE_LAST_ACCESSED not enabled", + "unused_count": 0, + "deleted_count": {}, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } + + # Check 2: Verify tracking has actually been running + db_engine = get_relational_engine() + async with db_engine.get_async_session() as session: + # Count records with non-NULL last_accessed + tracked_count = await session.execute( + select(sa.func.count(Data.id)).where(Data.last_accessed.isnot(None)) + ) + tracked_records = tracked_count.scalar() + + if tracked_records == 0: + logger.warning( + "Cleanup skipped: No records have been tracked yet. " + "ENABLE_LAST_ACCESSED may have been recently enabled. " + "Wait for retrievers to update timestamps before running cleanup." + ) + return { + "status": "skipped", + "reason": "No tracked records found - tracking may be newly enabled", + "unused_count": 0, + "deleted_count": {}, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } + + logger.info( + "Starting cleanup task", + minutes_threshold=minutes_threshold, + dry_run=dry_run, + user_id=str(user_id) if user_id else None + ) + + # Calculate cutoff timestamp + cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold) -This task identifies and removes entire documents that haven't -been accessed by retrievers for a specified period, helping maintain system -efficiency and storage optimization through whole-document removal. 
-""" - -import json -from datetime import datetime, timezone, timedelta -from typing import Optional, Dict, Any -from uuid import UUID -import os -from cognee.infrastructure.databases.graph import get_graph_engine -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.infrastructure.databases.relational import get_relational_engine -from cognee.modules.data.models import Data, DatasetData -from cognee.shared.logging_utils import get_logger -from sqlalchemy import select, or_ -import cognee -import sqlalchemy as sa -from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph - -logger = get_logger(__name__) + # Document-level approach (recommended) + return await _cleanup_via_sql(cutoff_date, dry_run, user_id) -async def cleanup_unused_data( - minutes_threshold: Optional[int], - dry_run: bool = True, - user_id: Optional[UUID] = None, - text_doc: bool = True, # Changed default to True for document-level cleanup - node_level: bool = False # New parameter for explicit node-level cleanup +async def _cleanup_via_sql( + cutoff_date: datetime, + dry_run: bool, + user_id: Optional[UUID] = None ) -> Dict[str, Any]: """ - Identify and remove unused data from the memify pipeline. - + SQL-based cleanup: Query Data table for unused documents and use cognee.delete(). 
+ Parameters ---------- - minutes_threshold : int - Minutes since last access to consider data unused + cutoff_date : datetime + Cutoff date for last_accessed filtering dry_run : bool - If True, only report what would be deleted without actually deleting (default: True) + If True, only report what would be deleted user_id : UUID, optional - Limit cleanup to specific user's data (default: None) - text_doc : bool - If True (default), use SQL-based filtering to find unused TextDocuments and call cognee.delete() - for proper whole-document deletion - node_level : bool - If True, perform chaotic node-level deletion of unused chunks, entities, and summaries - (default: False - deprecated in favor of document-level cleanup) + Filter by user ID if provided Returns ------- Dict[str, Any] - Cleanup results with status, counts, and timestamp - """ - # Check 1: Environment variable must be enabled - if os.getenv("ENABLE_LAST_ACCESSED", "false").lower() != "true": - logger.warning( - "Cleanup skipped: ENABLE_LAST_ACCESSED is not enabled." - ) - return { - "status": "skipped", - "reason": "ENABLE_LAST_ACCESSED not enabled", - "unused_count": 0, - "deleted_count": {}, - "cleanup_date": datetime.now(timezone.utc).isoformat() - } - - # Check 2: Verify tracking has actually been running - db_engine = get_relational_engine() - async with db_engine.get_async_session() as session: - # Count records with non-NULL last_accessed - tracked_count = await session.execute( - select(sa.func.count(Data.id)).where(Data.last_accessed.isnot(None)) - ) - tracked_records = tracked_count.scalar() - - if tracked_records == 0: - logger.warning( - "Cleanup skipped: No records have been tracked yet. " - "ENABLE_LAST_ACCESSED may have been recently enabled. " - "Wait for retrievers to update timestamps before running cleanup." 
- ) - return { - "status": "skipped", - "reason": "No tracked records found - tracking may be newly enabled", - "unused_count": 0, - "deleted_count": {}, - "cleanup_date": datetime.now(timezone.utc).isoformat() - } - - logger.info( - "Starting cleanup task", - minutes_threshold=minutes_threshold, - dry_run=dry_run, - user_id=str(user_id) if user_id else None, - text_doc=text_doc, - node_level=node_level - ) + Cleanup results + """ + db_engine = get_relational_engine() - # Calculate cutoff timestamp - cutoff_date = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold) - - if node_level: - # Deprecated: Node-level approach (chaotic) - logger.warning( - "Node-level cleanup is deprecated and may lead to fragmented knowledge graphs. " - "Consider using document-level cleanup (default) instead." - ) - cutoff_timestamp_ms = int(cutoff_date.timestamp() * 1000) - logger.debug(f"Cutoff timestamp: {cutoff_date.isoformat()} ({cutoff_timestamp_ms}ms)") + async with db_engine.get_async_session() as session: + # Query for Data records with old last_accessed timestamps + query = select(Data, DatasetData).join( + DatasetData, Data.id == DatasetData.data_id + ).where( + or_( + Data.last_accessed < cutoff_date, + Data.last_accessed.is_(None) + ) + ) - # Find unused nodes using graph projection - unused_nodes = await _find_unused_nodes_via_projection(cutoff_timestamp_ms) - - total_unused = sum(len(nodes) for nodes in unused_nodes.values()) - logger.info(f"Found {total_unused} unused nodes", unused_nodes={k: len(v) for k, v in unused_nodes.items()}) - - if dry_run: - return { - "status": "dry_run", - "unused_count": total_unused, - "deleted_count": { - "data_items": 0, - "chunks": 0, - "entities": 0, - "summaries": 0, - "associations": 0 - }, - "cleanup_date": datetime.now(timezone.utc).isoformat(), - "preview": { - "chunks": len(unused_nodes["DocumentChunk"]), - "entities": len(unused_nodes["Entity"]), - "summaries": len(unused_nodes["TextSummary"]) - } - } - - # Delete 
unused nodes (provider-agnostic deletion) - deleted_counts = await _delete_unused_nodes(unused_nodes) - - logger.info("Cleanup completed", deleted_counts=deleted_counts) + if user_id: + from cognee.modules.data.models import Dataset + query = query.join(Dataset, DatasetData.dataset_id == Dataset.id).where( + Dataset.owner_id == user_id + ) + result = await session.execute(query) + unused_data = result.all() + + logger.info(f"Found {len(unused_data)} unused documents in SQL") + + if dry_run: return { - "status": "completed", - "unused_count": total_unused, + "status": "dry_run", + "unused_count": len(unused_data), "deleted_count": { "data_items": 0, - "chunks": deleted_counts["DocumentChunk"], - "entities": deleted_counts["Entity"], - "summaries": deleted_counts["TextSummary"], - "associations": deleted_counts["associations"] + "documents": 0 }, - "cleanup_date": datetime.now(timezone.utc).isoformat() - } - else: - # Default: Document-level approach (recommended) - return await _cleanup_via_sql(cutoff_date, dry_run, user_id) - - -async def _cleanup_via_sql( - cutoff_date: datetime, - dry_run: bool, - user_id: Optional[UUID] = None -) -> Dict[str, Any]: - """ - SQL-based cleanup: Query Data table for unused documents and use cognee.delete(). 
- - Parameters - ---------- - cutoff_date : datetime - Cutoff date for last_accessed filtering - dry_run : bool - If True, only report what would be deleted - user_id : UUID, optional - Filter by user ID if provided - - Returns - ------- - Dict[str, Any] - Cleanup results - """ - db_engine = get_relational_engine() - - async with db_engine.get_async_session() as session: - # Query for Data records with old last_accessed timestamps - query = select(Data, DatasetData).join( - DatasetData, Data.id == DatasetData.data_id - ).where( - or_( - Data.last_accessed < cutoff_date, - Data.last_accessed.is_(None) - ) - ) - - if user_id: - from cognee.modules.data.models import Dataset - query = query.join(Dataset, DatasetData.dataset_id == Dataset.id).where( - Dataset.owner_id == user_id - ) - - result = await session.execute(query) - unused_data = result.all() - - logger.info(f"Found {len(unused_data)} unused documents in SQL") - - if dry_run: - return { - "status": "dry_run", - "unused_count": len(unused_data), - "deleted_count": { - "data_items": 0, - "documents": 0 - }, - "cleanup_date": datetime.now(timezone.utc).isoformat(), - "preview": { - "documents": len(unused_data) - } - } - - # Delete each document using cognee.delete() - deleted_count = 0 - from cognee.modules.users.methods import get_default_user - user = await get_default_user() if user_id is None else None - - for data, dataset_data in unused_data: - try: - await cognee.delete( - data_id=data.id, - dataset_id=dataset_data.dataset_id, - mode="hard", # Use hard mode to also remove orphaned entities - user=user - ) - deleted_count += 1 - logger.info(f"Deleted document {data.id} from dataset {dataset_data.dataset_id}") - except Exception as e: - logger.error(f"Failed to delete document {data.id}: {e}") - - logger.info("Cleanup completed", deleted_count=deleted_count) - - return { - "status": "completed", - "unused_count": len(unused_data), - "deleted_count": { - "data_items": deleted_count, - "documents": 
deleted_count - }, - "cleanup_date": datetime.now(timezone.utc).isoformat() - } - - -async def _find_unused_nodes_via_projection(cutoff_timestamp_ms: int) -> Dict[str, list]: - """ - Find unused nodes using graph projection - database-agnostic approach. - NOTE: This function is deprecated as it leads to fragmented knowledge graphs. + "cleanup_date": datetime.now(timezone.utc).isoformat(), + "preview": { + "documents": len(unused_data) + } + } - Parameters - ---------- - cutoff_timestamp_ms : int - Cutoff timestamp in milliseconds since epoch + # Delete each document using cognee.delete() + deleted_count = 0 + from cognee.modules.users.methods import get_default_user + user = await get_default_user() if user_id is None else None - Returns - ------- - Dict[str, list] - Dictionary mapping node types to lists of unused node IDs - """ - graph_engine = await get_graph_engine() + for data, dataset_data in unused_data: + try: + await cognee.delete( + data_id=data.id, + dataset_id=dataset_data.dataset_id, + mode="hard", # Use hard mode to also remove orphaned entities + user=user + ) + deleted_count += 1 + logger.info(f"Deleted document {data.id} from dataset {dataset_data.dataset_id}") + except Exception as e: + logger.error(f"Failed to delete document {data.id}: {e}") - # Project the entire graph with necessary properties - memory_fragment = CogneeGraph() - await memory_fragment.project_graph_from_db( - graph_engine, - node_properties_to_project=["id", "type", "last_accessed_at"], - edge_properties_to_project=[] - ) - - unused_nodes = {"DocumentChunk": [], "Entity": [], "TextSummary": []} - - # Get all nodes from the projected graph - all_nodes = memory_fragment.get_nodes() - - for node in all_nodes: - node_type = node.get_attribute("type") - if node_type not in unused_nodes: - continue - - # Check last_accessed_at property - last_accessed = node.get_attribute("last_accessed_at") + logger.info("Cleanup completed", deleted_count=deleted_count) - if last_accessed is None or 
last_accessed < cutoff_timestamp_ms: - unused_nodes[node_type].append(node.id) - logger.debug( - f"Found unused {node_type}", - node_id=node.id, - last_accessed=last_accessed - ) - - return unused_nodes - - -async def _delete_unused_nodes(unused_nodes: Dict[str, list]) -> Dict[str, int]: - """ - Delete unused nodes from graph and vector databases. - NOTE: This function is deprecated as it leads to fragmented knowledge graphs. - - Parameters - ---------- - unused_nodes : Dict[str, list] - Dictionary mapping node types to lists of node IDs to delete - - Returns - ------- - Dict[str, int] - Count of deleted items by type - """ - graph_engine = await get_graph_engine() - vector_engine = get_vector_engine() - - deleted_counts = { - "DocumentChunk": 0, - "Entity": 0, - "TextSummary": 0, - "associations": 0 - } - - # Count associations before deletion (using graph projection for consistency) - if any(unused_nodes.values()): - memory_fragment = CogneeGraph() - await memory_fragment.project_graph_from_db( - graph_engine, - node_properties_to_project=["id"], - edge_properties_to_project=[] - ) - - for node_type, node_ids in unused_nodes.items(): - if not node_ids: - continue - - # Count edges from the in-memory graph - for node_id in node_ids: - node = memory_fragment.get_node(node_id) - if node: - # Count edges from the in-memory graph - edge_count = len(node.get_skeleton_edges()) - deleted_counts["associations"] += edge_count - - # Delete from graph database (uses DETACH DELETE, so edges are automatically removed) - for node_type, node_ids in unused_nodes.items(): - if not node_ids: - continue - - logger.info(f"Deleting {len(node_ids)} {node_type} nodes from graph database") - - # Delete nodes in batches (database-agnostic) - await graph_engine.delete_nodes(node_ids) - deleted_counts[node_type] = len(node_ids) - - # Delete from vector database - vector_collections = { - "DocumentChunk": "DocumentChunk_text", - "Entity": "Entity_name", - "TextSummary": "TextSummary_text" - 
} - - - for node_type, collection_name in vector_collections.items(): - node_ids = unused_nodes[node_type] - if not node_ids: - continue - - logger.info(f"Deleting {len(node_ids)} {node_type} embeddings from vector database") - - try: - if await vector_engine.has_collection(collection_name): - await vector_engine.delete_data_points( - collection_name, - [str(node_id) for node_id in node_ids] - ) - except Exception as e: - logger.error(f"Error deleting from vector collection {collection_name}: {e}") - - return deleted_counts + return { + "status": "completed", + "unused_count": len(unused_data), + "deleted_count": { + "data_items": deleted_count, + "documents": deleted_count + }, + "cleanup_date": datetime.now(timezone.utc).isoformat() + } diff --git a/cognee/tests/test_cleanup_unused_data.py b/cognee/tests/test_cleanup_unused_data.py index c21b9f5ea..e738dcba0 100644 --- a/cognee/tests/test_cleanup_unused_data.py +++ b/cognee/tests/test_cleanup_unused_data.py @@ -1,244 +1,172 @@ -import os -import pathlib -import cognee -from datetime import datetime, timezone, timedelta -from uuid import UUID -from sqlalchemy import select, update -from cognee.modules.data.models import Data, DatasetData -from cognee.infrastructure.databases.relational import get_relational_engine -from cognee.modules.users.methods import get_default_user -from cognee.shared.logging_utils import get_logger -from cognee.modules.search.types import SearchType - -logger = get_logger() - - -async def test_textdocument_cleanup_with_sql(): - """ - End-to-end test for TextDocument cleanup based on last_accessed timestamps. 
+import os +import pathlib +import cognee +from datetime import datetime, timezone, timedelta +from uuid import UUID +from sqlalchemy import select, update +from cognee.modules.data.models import Data, DatasetData +from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.users.methods import get_default_user +from cognee.shared.logging_utils import get_logger +from cognee.modules.search.types import SearchType - Tests: - 1. Add and cognify a document - 2. Perform search to populate last_accessed timestamp - 3. Verify last_accessed is set in SQL Data table - 4. Manually age the timestamp beyond cleanup threshold - 5. Run cleanup with text_doc=True - 6. Verify document was deleted from all databases (relational, graph, and vector) - """ - # Setup test directories - data_directory_path = str( - pathlib.Path( - os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_cleanup") - ).resolve() - ) - cognee_directory_path = str( - pathlib.Path( - os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_cleanup") - ).resolve() - ) +logger = get_logger() - cognee.config.data_root_directory(data_directory_path) - cognee.config.system_root_directory(cognee_directory_path) - # Initialize database - from cognee.modules.engine.operations.setup import setup +async def test_textdocument_cleanup_with_sql(): + """ + End-to-end test for TextDocument cleanup based on last_accessed timestamps. + """ + # Enable last accessed tracking BEFORE any cognee operations + os.environ["ENABLE_LAST_ACCESSED"] = "true" - # Clean slate - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - logger.info("🧪 Testing TextDocument cleanup based on last_accessed") - - # Step 1: Add and cognify a test document - dataset_name = "test_cleanup_dataset" - test_text = """ - Machine learning is a subset of artificial intelligence that enables systems to learn - and improve from experience without being explicitly programmed. 
Deep learning uses - neural networks with multiple layers to process data. - """ - - await setup() - user = await get_default_user() - await cognee.add([test_text], dataset_name=dataset_name, user=user) - - cognify_result = await cognee.cognify([dataset_name], user=user) - - # Extract dataset_id from cognify result (ds_id is already a UUID) - dataset_id = None - for ds_id, pipeline_result in cognify_result.items(): - dataset_id = ds_id # Don't wrap in UUID() - it's already a UUID object - break - - assert dataset_id is not None, "Failed to get dataset_id from cognify result" - logger.info(f"✅ Document added and cognified. Dataset ID: {dataset_id}") - - # Step 2: Perform search to trigger last_accessed update - logger.info("Triggering search to update last_accessed...") - search_results = await cognee.search( - query_type=SearchType.CHUNKS, - query_text="machine learning", - datasets=[dataset_name], - user=user - ) - logger.info(f"✅ Search completed, found {len(search_results)} results") - - # Step 3: Verify last_accessed was set in SQL Data table - db_engine = get_relational_engine() - async with db_engine.get_async_session() as session: - # Get the Data record for this dataset - result = await session.execute( - select(Data, DatasetData) - .join(DatasetData, Data.id == DatasetData.data_id) - .where(DatasetData.dataset_id == dataset_id) - ) - data_records = result.all() - assert len(data_records) > 0, "No Data records found for the dataset" - data_record = data_records[0][0] - data_id = data_record.id + # Setup test directories + data_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_cleanup") + ).resolve() + ) + cognee_directory_path = str( + pathlib.Path( + os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_cleanup") + ).resolve() + ) - # Verify last_accessed is set (should be set by search operation) - assert data_record.last_accessed is not None, ( - "last_accessed should be set after search 
operation" - ) + cognee.config.data_root_directory(data_directory_path) + cognee.config.system_root_directory(cognee_directory_path) - original_last_accessed = data_record.last_accessed - logger.info(f"✅ last_accessed verified: {original_last_accessed}") - - # Step 4: Manually age the timestamp to be older than cleanup threshold - days_threshold = 30 - aged_timestamp = datetime.now(timezone.utc) - timedelta(days=days_threshold + 10) - - async with db_engine.get_async_session() as session: - stmt = update(Data).where(Data.id == data_id).values(last_accessed=aged_timestamp) - await session.execute(stmt) - await session.commit() - - # Query in a NEW session to avoid cached values - async with db_engine.get_async_session() as session: - result = await session.execute(select(Data).where(Data.id == data_id)) - updated_data = result.scalar_one_or_none() + # Initialize database + from cognee.modules.engine.operations.setup import setup - # Make both timezone-aware for comparison - retrieved_timestamp = updated_data.last_accessed - if retrieved_timestamp.tzinfo is None: - # If database returned naive datetime, make it UTC-aware - retrieved_timestamp = retrieved_timestamp.replace(tzinfo=timezone.utc) + # Clean slate + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) - assert retrieved_timestamp == aged_timestamp, ( - f"Timestamp should be updated to aged value. 
" - f"Expected: {aged_timestamp}, Got: {retrieved_timestamp}" - ) - - # Step 5: Test cleanup with text_doc=True - from cognee.tasks.cleanup.cleanup_unused_data import cleanup_unused_data - - # First do a dry run - logger.info("Testing dry run with text_doc=True...") - dry_run_result = await cleanup_unused_data( - days_threshold=30, - dry_run=True, - user_id=user.id, - text_doc=True - ) - - assert dry_run_result['status'] == 'dry_run', "Status should be 'dry_run'" - assert dry_run_result['unused_count'] > 0, ( - "Should find at least one unused document" - ) - logger.info(f"✅ Dry run found {dry_run_result['unused_count']} unused documents") - - # Now run actual cleanup - logger.info("Executing cleanup with text_doc=True...") - cleanup_result = await cleanup_unused_data( - days_threshold=30, - dry_run=False, - user_id=user.id, - text_doc=True - ) - - assert cleanup_result["status"] == "completed", "Cleanup should complete successfully" - assert cleanup_result["deleted_count"]["documents"] > 0, ( - "At least one document should be deleted" - ) - logger.info(f"✅ Cleanup completed. 
Deleted {cleanup_result['deleted_count']['documents']} documents") - - # Step 6: Verify the document was actually deleted from SQL - async with db_engine.get_async_session() as session: - deleted_data = ( - await session.execute(select(Data).where(Data.id == data_id)) - ).scalar_one_or_none() + logger.info("🧪 Testing TextDocument cleanup based on last_accessed") - assert deleted_data is None, ( - "Data record should be deleted after cleanup" - ) - logger.info("✅ Confirmed: Data record was deleted from SQL database") - - # Verify the dataset-data link was also removed - async with db_engine.get_async_session() as session: - dataset_data_link = ( - await session.execute( - select(DatasetData).where( - DatasetData.data_id == data_id, - DatasetData.dataset_id == dataset_id - ) - ) - ).scalar_one_or_none() + # Step 1: Add and cognify a test document + dataset_name = "test_cleanup_dataset" + test_text = """ + Machine learning is a subset of artificial intelligence that enables systems to learn + and improve from experience without being explicitly programmed. Deep learning uses + neural networks with multiple layers to process data. + """ - assert dataset_data_link is None, ( - "DatasetData link should be deleted after cleanup" - ) - logger.info("✅ Confirmed: DatasetData link was deleted") + await setup() + user = await get_default_user() + await cognee.add([test_text], dataset_name=dataset_name, user=user) + + cognify_result = await cognee.cognify([dataset_name], user=user) + + # Extract dataset_id from cognify result + dataset_id = None + for ds_id, pipeline_result in cognify_result.items(): + dataset_id = ds_id + break + + assert dataset_id is not None, "Failed to get dataset_id from cognify result" + logger.info(f"✅ Document added and cognified. 
Dataset ID: {dataset_id}") + + # Step 2: Perform search to trigger last_accessed update + logger.info("Triggering search to update last_accessed...") + search_results = await cognee.search( + query_type=SearchType.CHUNKS, + query_text="machine learning", + datasets=[dataset_name], + user=user + ) + logger.info(f"✅ Search completed, found {len(search_results)} results") + assert len(search_results) > 0, "Search should return results" + + # Step 3: Verify last_accessed was set and get data_id + db_engine = get_relational_engine() + async with db_engine.get_async_session() as session: + result = await session.execute( + select(Data, DatasetData) + .join(DatasetData, Data.id == DatasetData.data_id) + .where(DatasetData.dataset_id == dataset_id) + ) + data_records = result.all() + assert len(data_records) > 0, "No Data records found for the dataset" + data_record = data_records[0][0] + data_id = data_record.id + + # Verify last_accessed is set + assert data_record.last_accessed is not None, ( + "last_accessed should be set after search operation" + ) + + original_last_accessed = data_record.last_accessed + logger.info(f"✅ last_accessed verified: {original_last_accessed}") + + # Step 4: Manually age the timestamp + minutes_threshold = 30 + aged_timestamp = datetime.now(timezone.utc) - timedelta(minutes=minutes_threshold + 10) + + async with db_engine.get_async_session() as session: + stmt = update(Data).where(Data.id == data_id).values(last_accessed=aged_timestamp) + await session.execute(stmt) + await session.commit() + + # Verify timestamp was updated + async with db_engine.get_async_session() as session: + result = await session.execute(select(Data).where(Data.id == data_id)) + updated_data = result.scalar_one_or_none() + assert updated_data is not None, "Data record should exist" + retrieved_timestamp = updated_data.last_accessed + if retrieved_timestamp.tzinfo is None: + retrieved_timestamp = retrieved_timestamp.replace(tzinfo=timezone.utc) + assert 
retrieved_timestamp == aged_timestamp, ( + f"Timestamp should be updated to aged value" + ) + + # Step 5: Test cleanup (document-level is now the default) + from cognee.tasks.cleanup.cleanup_unused_data import cleanup_unused_data + + # First do a dry run + logger.info("Testing dry run...") + dry_run_result = await cleanup_unused_data( + minutes_threshold=10, + dry_run=True, + user_id=user.id + ) + + # Debug: Print the actual result + logger.info(f"Dry run result: {dry_run_result}") - # Verify graph nodes were cleaned up - from cognee.infrastructure.databases.graph import get_graph_engine + assert dry_run_result['status'] == 'dry_run', f"Status should be 'dry_run', got: {dry_run_result['status']}" + assert dry_run_result['unused_count'] > 0, ( + "Should find at least one unused document" + ) + logger.info(f"✅ Dry run found {dry_run_result['unused_count']} unused documents") + + # Now run actual cleanup + logger.info("Executing cleanup...") + cleanup_result = await cleanup_unused_data( + minutes_threshold=30, + dry_run=False, + user_id=user.id + ) + + assert cleanup_result["status"] == "completed", "Cleanup should complete successfully" + assert cleanup_result["deleted_count"]["documents"] > 0, ( + "At least one document should be deleted" + ) + logger.info(f"✅ Cleanup completed. 
Deleted {cleanup_result['deleted_count']['documents']} documents") + + # Step 6: Verify deletion + async with db_engine.get_async_session() as session: + deleted_data = ( + await session.execute(select(Data).where(Data.id == data_id)) + ).scalar_one_or_none() + assert deleted_data is None, "Data record should be deleted" + logger.info("✅ Confirmed: Data record was deleted") + + logger.info("🎉 All cleanup tests passed!") + return True - graph_engine = await get_graph_engine() - # Try to find the TextDocument node - it should not exist - result = await graph_engine.query( - "MATCH (n:Node {id: $id}) RETURN n", - {"id": str(data_id)} - ) - - assert len(result) == 0, ( - "TextDocument node should be deleted from graph database" - ) - logger.info("✅ Confirmed: TextDocument node was deleted from graph database") - - # Verify vector database was cleaned up - from cognee.infrastructure.databases.vector import get_vector_engine - - vector_engine = get_vector_engine() - - # Check each collection that should have been cleaned up - vector_collections = [ - "DocumentChunk_text", - "Entity_name", - "TextSummary_text" - ] - - for collection_name in vector_collections: - if await vector_engine.has_collection(collection_name): - # Try to retrieve the deleted data points - try: - results = await vector_engine.retrieve(collection_name, [str(data_id)]) - assert len(results) == 0, ( - f"Data points should be deleted from {collection_name} collection" - ) - logger.info(f"✅ Confirmed: {collection_name} collection is clean") - except Exception as e: - # Collection might be empty or not exist, which is fine - logger.info(f"✅ Confirmed: {collection_name} collection is empty or doesn't exist") - pass - - logger.info("✅ Confirmed: Vector database entries were deleted") - - logger.info("🎉 All cleanup tests passed!") - - return True - - -if __name__ == "__main__": - import asyncio - success = asyncio.run(test_textdocument_cleanup_with_sql()) +if __name__ == "__main__": + import asyncio + 
success = asyncio.run(test_textdocument_cleanup_with_sql()) exit(0 if success else 1) From 6260f9eb82c9078eded523a80f035f4054d7091c Mon Sep 17 00:00:00 2001 From: rajeevrajeshuni Date: Thu, 11 Dec 2025 06:53:36 +0530 Subject: [PATCH 048/176] strandardizing return type for transcription and some CR changes --- .../llm/anthropic/adapter.py | 2 +- .../litellm_instructor/llm/gemini/adapter.py | 1 - .../llm/generic_llm_api/adapter.py | 11 ++++++++-- .../litellm_instructor/llm/get_llm_client.py | 3 ++- .../litellm_instructor/llm/mistral/adapter.py | 22 +++++++++---------- .../litellm_instructor/llm/openai/adapter.py | 11 ++++++---- .../litellm_instructor/llm/types.py | 9 ++++++++ 7 files changed, 39 insertions(+), 20 deletions(-) create mode 100644 cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py index 4d75c886a..49b13fcaa 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py @@ -44,7 +44,7 @@ class AnthropicAdapter(GenericAPIAdapter): self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.patch( - create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create, + create=anthropic.AsyncAnthropic(api_key=self.api_key).messages.create, mode=instructor.Mode(self.instructor_mode), ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py index ffb7bf77b..99dfd6179 100644 --- 
a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py @@ -10,7 +10,6 @@ from instructor.core import InstructorRetryException import logging from cognee.shared.rate_limiting import llm_rate_limiter_context_manager -from cognee.shared.logging_utils import get_logger from tenacity import ( retry, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index 408058d3e..7905c25bf 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -27,6 +27,8 @@ from tenacity import ( before_sleep_log, ) +from ..types import TranscriptionReturnType + logger = get_logger() observe = get_observe() @@ -191,7 +193,7 @@ class GenericAPIAdapter(LLMInterface): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) - async def create_transcript(self, input) -> Optional[BaseModel]: + async def create_transcript(self, input) -> Optional[TranscriptionReturnType]: """ Generate an audio transcript from a user query. @@ -214,7 +216,7 @@ class GenericAPIAdapter(LLMInterface): raise ValueError( f"Could not determine MIME type for audio file: {input}. Is the extension correct?" 
) - return litellm.completion( + response = litellm.completion( model=self.transcription_model, messages=[ { @@ -234,6 +236,11 @@ class GenericAPIAdapter(LLMInterface): api_base=self.endpoint, max_retries=self.MAX_RETRIES, ) + if response and response.choices and len(response.choices) > 0: + return TranscriptionReturnType(response.choices[0].message.content,response) + else: + return None + @observe(as_type="transcribe_image") @retry( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index de6cfaf19..e5f4bd1b1 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -97,10 +97,11 @@ def get_llm_client(raise_api_key_error: bool = True): ) return OllamaAPIAdapter( + llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, + "Ollama", max_completion_tokens, - llm_config.llm_endpoint, instructor_mode=llm_config.llm_instructor_mode.lower(), ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py index 954510a25..b141f7585 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -1,9 +1,9 @@ import litellm import instructor from pydantic import BaseModel -from typing import Type -from litellm import JSONSchemaValidationError, transcription - +from typing import Type, Optional +from litellm import JSONSchemaValidationError +from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.shared.logging_utils import get_logger 
from cognee.modules.observability.get_observe import get_observe from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( @@ -20,6 +20,7 @@ from tenacity import ( retry_if_not_exception_type, before_sleep_log, ) +from ..types import TranscriptionReturnType from mistralai import Mistral logger = get_logger() @@ -47,8 +48,6 @@ class MistralAdapter(GenericAPIAdapter): image_transcribe_model: str = None, instructor_mode: str = None, ): - from mistralai import Mistral - super().__init__( api_key=api_key, model=model, @@ -66,6 +65,7 @@ class MistralAdapter(GenericAPIAdapter): mode=instructor.Mode(self.instructor_mode), api_key=get_llm_config().llm_api_key, ) + self.mistral_client = Mistral(api_key=self.api_key) @observe(as_type="generation") @retry( @@ -135,7 +135,7 @@ class MistralAdapter(GenericAPIAdapter): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) - async def create_transcript(self, input): + async def create_transcript(self, input) -> Optional[TranscriptionReturnType]: """ Generate an audio transcript from a user query. @@ -154,14 +154,14 @@ class MistralAdapter(GenericAPIAdapter): if self.transcription_model.startswith("mistral"): transcription_model = self.transcription_model.split("/")[-1] file_name = input.split("/")[-1] - client = Mistral(api_key=self.api_key) - with open(input, "rb") as f: - transcription_response = client.audio.transcriptions.complete( + async with open_data_file(input, mode="rb") as f: + transcription_response = self.mistral_client.audio.transcriptions.complete( model=transcription_model, file={ "content": f, "file_name": file_name, }, ) - # TODO: We need to standardize return type of create_transcript across different models. 
- return transcription_response + if transcription_response: + return TranscriptionReturnType(transcription_response.text, transcription_response) + return None diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index 57b6d339a..94c6aed6d 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -1,6 +1,6 @@ import litellm import instructor -from typing import Type +from typing import Type, Optional from pydantic import BaseModel from openai import ContentFilterFinishReasonError from litellm.exceptions import ContentPolicyViolationError @@ -25,6 +25,7 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.modules.observability.get_observe import get_observe from cognee.shared.logging_utils import get_logger +from ..types import TranscriptionReturnType logger = get_logger() @@ -200,7 +201,7 @@ class OpenAIAdapter(GenericAPIAdapter): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) - async def create_transcript(self, input): + async def create_transcript(self, input) -> Optional[TranscriptionReturnType]: """ Generate an audio transcript from a user query. 
@@ -228,7 +229,9 @@ class OpenAIAdapter(GenericAPIAdapter): api_version=self.api_version, max_retries=self.MAX_RETRIES, ) + if transcription: + return TranscriptionReturnType(transcription.text, transcription) - return transcription + return None - # transcribe image inherited from GenericAdapter + # transcribe_image is inherited from GenericAPIAdapter diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py new file mode 100644 index 000000000..887cdd88d --- /dev/null +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + +class TranscriptionReturnType: + text: str + payload: BaseModel + + def __init__(self, text:str, payload: BaseModel): + self.text = text + self.payload = payload \ No newline at end of file From f48df27fc85b0df53701f0bf813563dc20494f74 Mon Sep 17 00:00:00 2001 From: hiyan Date: Thu, 11 Dec 2025 10:32:45 +0530 Subject: [PATCH 049/176] fix(db): url-encode postgres credentials to handle special characters --- .../databases/relational/create_relational_engine.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cognee/infrastructure/databases/relational/create_relational_engine.py b/cognee/infrastructure/databases/relational/create_relational_engine.py index deaeaa2da..8813dfcb2 100644 --- a/cognee/infrastructure/databases/relational/create_relational_engine.py +++ b/cognee/infrastructure/databases/relational/create_relational_engine.py @@ -1,5 +1,6 @@ from .sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter from functools import lru_cache +from urllib.parse import quote_plus @lru_cache @@ -43,9 +44,10 @@ def create_relational_engine( # Test if asyncpg is available import asyncpg - connection_string = ( - f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}" - ) + encoded_username = 
quote_plus(db_username) + encoded_password = quote_plus(db_password) + + connection_string = f"postgresql+asyncpg://{encoded_username}:{encoded_password}@{db_host}:{db_port}/{db_name}" except ImportError: raise ImportError( "PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality." From 2485c3f5f0c2b25572213fe7638467859679c8d2 Mon Sep 17 00:00:00 2001 From: chinu0609 Date: Thu, 11 Dec 2025 12:48:06 +0530 Subject: [PATCH 050/176] fix: only document level deletion --- cognee/infrastructure/engine/models/DataPoint.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index 3178713c8..812380eaa 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -43,9 +43,6 @@ class DataPoint(BaseModel): updated_at: int = Field( default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) ) - last_accessed_at: int = Field( - default_factory=lambda: int(datetime.now(timezone.utc).timestamp() * 1000) - ) ontology_valid: bool = False version: int = 1 # Default version topological_rank: Optional[int] = 0 From cd60ae31740acc9444f5aaf61fd7720deb2a5c51 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 11 Dec 2025 15:25:44 +0100 Subject: [PATCH 051/176] test: remove docs tests. 
add trigger to docs repo --- .github/workflows/docs_tests.yml | 280 ------------------ .github/workflows/release_test.yml | 23 +- .../tests/docs/guides/custom_data_models.py | 38 --- cognee/tests/docs/guides/custom_prompts.py | 30 -- .../docs/guides/custom_tasks_and_pipelines.py | 53 ---- .../tests/docs/guides/graph_visualization.py | 13 - cognee/tests/docs/guides/low_level_llm.py | 31 -- cognee/tests/docs/guides/memify_quickstart.py | 29 -- .../tests/docs/guides/ontology_quickstart.py | 30 -- cognee/tests/docs/guides/s3_storage.py | 25 -- cognee/tests/docs/guides/search_basics.py | 58 ---- cognee/tests/docs/guides/temporal_cognify.py | 57 ---- 12 files changed, 16 insertions(+), 651 deletions(-) delete mode 100644 .github/workflows/docs_tests.yml delete mode 100644 cognee/tests/docs/guides/custom_data_models.py delete mode 100644 cognee/tests/docs/guides/custom_prompts.py delete mode 100644 cognee/tests/docs/guides/custom_tasks_and_pipelines.py delete mode 100644 cognee/tests/docs/guides/graph_visualization.py delete mode 100644 cognee/tests/docs/guides/low_level_llm.py delete mode 100644 cognee/tests/docs/guides/memify_quickstart.py delete mode 100644 cognee/tests/docs/guides/ontology_quickstart.py delete mode 100644 cognee/tests/docs/guides/s3_storage.py delete mode 100644 cognee/tests/docs/guides/search_basics.py delete mode 100644 cognee/tests/docs/guides/temporal_cognify.py diff --git a/.github/workflows/docs_tests.yml b/.github/workflows/docs_tests.yml deleted file mode 100644 index 7f7282bb2..000000000 --- a/.github/workflows/docs_tests.yml +++ /dev/null @@ -1,280 +0,0 @@ -name: Docs Tests - -permissions: - contents: read - -on: - workflow_dispatch: - workflow_call: - secrets: - LLM_PROVIDER: - required: true - LLM_MODEL: - required: true - LLM_ENDPOINT: - required: true - LLM_API_KEY: - required: true - LLM_API_VERSION: - required: true - EMBEDDING_PROVIDER: - required: true - EMBEDDING_MODEL: - required: true - EMBEDDING_ENDPOINT: - required: true - 
EMBEDDING_API_KEY: - required: true - EMBEDDING_API_VERSION: - required: true - -env: - ENV: 'dev' - -jobs: - test-search-basics: - name: Test Search Basics - runs-on: ubuntu-22.04 - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Cognee Setup - uses: ./.github/actions/cognee_setup - with: - python-version: '3.11.x' - - - name: Run Search Basics Test - env: - LLM_MODEL: ${{ secrets.LLM_MODEL }} - LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} - LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} - EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} - EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} - EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} - EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} - run: uv run python ./cognee/tests/docs/guides/search_basics.py - - test-temporal-cognify: - name: Test Temporal Cognify - runs-on: ubuntu-22.04 - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Cognee Setup - uses: ./.github/actions/cognee_setup - with: - python-version: '3.11.x' - - - name: Run Temporal Cognify Test - env: - LLM_MODEL: ${{ secrets.LLM_MODEL }} - LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} - LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} - EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} - EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} - EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} - EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} - run: uv run python ./cognee/tests/docs/guides/temporal_cognify.py - - test-ontology-quickstart: - name: Test Temporal Cognify - runs-on: ubuntu-22.04 - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Cognee Setup - uses: ./.github/actions/cognee_setup - with: - python-version: '3.11.x' - - - name: Run Temporal Cognify Test - env: - LLM_MODEL: ${{ secrets.LLM_MODEL }} - LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} - 
LLM_API_KEY: ${{ secrets.LLM_API_KEY }} - LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} - EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} - EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} - EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} - EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} - run: uv run python ./cognee/tests/docs/guides/temporal_cognify.py - - test-s3-storage: - name: Test S3 Docs Guide - runs-on: ubuntu-22.04 - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Cognee Setup - uses: ./.github/actions/cognee_setup - with: - python-version: '3.11.x' - extra-dependencies: "aws" - - - name: Run S3 Docs Guide Test - env: - ENABLE_BACKEND_ACCESS_CONTROL: True - LLM_MODEL: ${{ secrets.LLM_MODEL }} - LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} - LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} - EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} - EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} - EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} - EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} - STORAGE_BACKEND: s3 - AWS_REGION: eu-west-1 - AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_DEV_USER_SECRET_KEY }} - run: uv run python ./cognee/tests/docs/guides/s3_storage.py - - test-graph-visualization: - name: Test Graph Visualization - runs-on: ubuntu-22.04 - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Cognee Setup - uses: ./.github/actions/cognee_setup - with: - python-version: '3.11.x' - - - name: Run Graph Visualization Test - env: - LLM_MODEL: ${{ secrets.LLM_MODEL }} - LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} - LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} - EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} - EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} - 
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} - EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} - run: uv run python ./cognee/tests/docs/guides/graph_visualization.py - - test-low-level-llm: - name: Test Low Level LLM - runs-on: ubuntu-22.04 - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Cognee Setup - uses: ./.github/actions/cognee_setup - with: - python-version: '3.11.x' - - - name: Run Low Level LLM Test - env: - LLM_MODEL: ${{ secrets.LLM_MODEL }} - LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} - LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} - EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} - EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} - EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} - EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} - run: uv run python ./cognee/tests/docs/guides/low_level_llm.py - - test-memify-quickstart: - name: Test Memify Quickstart - runs-on: ubuntu-22.04 - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Cognee Setup - uses: ./.github/actions/cognee_setup - with: - python-version: '3.11.x' - - - name: Run Memify Quickstart Test - env: - LLM_MODEL: ${{ secrets.LLM_MODEL }} - LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} - LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} - EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} - EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} - EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} - EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} - run: uv run python ./cognee/tests/docs/guides/memify_quickstart.py - - test-custom-data-models: - name: Test Custom Data Models - runs-on: ubuntu-22.04 - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Cognee Setup - uses: ./.github/actions/cognee_setup - with: - python-version: '3.11.x' - - - name: Run Custom Data Models Test - env: - 
LLM_MODEL: ${{ secrets.LLM_MODEL }} - LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} - LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} - EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} - EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} - EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} - EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} - run: uv run python ./cognee/tests/docs/guides/custom_data_models.py - - test-custom-tasks-and-pipelines: - name: Test Custom Tasks and Pipelines - runs-on: ubuntu-22.04 - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Cognee Setup - uses: ./.github/actions/cognee_setup - with: - python-version: '3.11.x' - - - name: Run Custom Tasks and Pipelines Test - env: - LLM_MODEL: ${{ secrets.LLM_MODEL }} - LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} - LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} - EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} - EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} - EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} - EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} - run: uv run python ./cognee/tests/docs/guides/custom_tasks_and_pipelines.py - - test-custom-prompts: - name: Test Custom Prompts - runs-on: ubuntu-22.04 - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Cognee Setup - uses: ./.github/actions/cognee_setup - with: - python-version: '3.11.x' - - - name: Run Custom Prompts Test - env: - LLM_MODEL: ${{ secrets.LLM_MODEL }} - LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} - LLM_API_KEY: ${{ secrets.LLM_API_KEY }} - LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} - EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} - EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} - EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} - EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} - run: uv run python 
./cognee/tests/docs/guides/custom_prompts.py \ No newline at end of file diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml index 89540fcfb..c6dd68484 100644 --- a/.github/workflows/release_test.yml +++ b/.github/workflows/release_test.yml @@ -5,18 +5,27 @@ permissions: contents: read on: + push: + branches: + - feature/cog-3213-docs-set-up-guide-script-tests workflow_dispatch: pull_request: branches: - main jobs: - load-tests: - name: Load Tests - uses: ./.github/workflows/load_tests.yml - secrets: inherit +# load-tests: +# name: Load Tests +# uses: ./.github/workflows/load_tests.yml +# secrets: inherit docs-tests: - name: Docs Tests - uses: ./.github/workflows/docs_tests.yml - secrets: inherit \ No newline at end of file + runs-on: ubuntu-22.04 + steps: + - name: Trigger docs tests + run: | + curl -sS -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.DOCS_REPO_PAT_TOKEN }}" \ + https://api.github.com/repos/your-org/repo-b/dispatches \ + -d '{"event_type":"new-main-release","client_payload":{"caller_repo":"'"${GITHUB_REPOSITORY}"'"}}' diff --git a/cognee/tests/docs/guides/custom_data_models.py b/cognee/tests/docs/guides/custom_data_models.py deleted file mode 100644 index 0eb314227..000000000 --- a/cognee/tests/docs/guides/custom_data_models.py +++ /dev/null @@ -1,38 +0,0 @@ -import asyncio -from typing import Any -from pydantic import SkipValidation - -import cognee -from cognee.infrastructure.engine import DataPoint -from cognee.infrastructure.engine.models.Edge import Edge -from cognee.tasks.storage import add_data_points - - -class Person(DataPoint): - name: str - # Keep it simple for forward refs / mixed values - knows: SkipValidation[Any] = None # single Person or list[Person] - # Recommended: specify which fields to index for search - metadata: dict = {"index_fields": ["name"]} - - -async def main(): - # Start clean (optional in your app) - await cognee.prune.prune_data() - await 
cognee.prune.prune_system(metadata=True) - - alice = Person(name="Alice") - bob = Person(name="Bob") - charlie = Person(name="Charlie") - - # Create relationships - field name becomes edge label - alice.knows = bob - # You can also do lists: alice.knows = [bob, charlie] - - # Optional: add weights and custom relationship types - bob.knows = (Edge(weight=0.9, relationship_type="friend_of"), charlie) - - await add_data_points([alice, bob, charlie]) - - -asyncio.run(main()) diff --git a/cognee/tests/docs/guides/custom_prompts.py b/cognee/tests/docs/guides/custom_prompts.py deleted file mode 100644 index 0d0a55a80..000000000 --- a/cognee/tests/docs/guides/custom_prompts.py +++ /dev/null @@ -1,30 +0,0 @@ -import asyncio -import cognee -from cognee.api.v1.search import SearchType - -custom_prompt = """ -Extract only people and cities as entities. -Connect people to cities with the relationship "lives_in". -Ignore all other entities. -""" - - -async def main(): - await cognee.add( - [ - "Alice moved to Paris in 2010, while Bob has always lived in New York.", - "Andreas was born in Venice, but later settled in Lisbon.", - "Diana and Tom were born and raised in Helsingy. 
Diana currently resides in Berlin, while Tom never moved.", - ] - ) - await cognee.cognify(custom_prompt=custom_prompt) - - res = await cognee.search( - query_type=SearchType.GRAPH_COMPLETION, - query_text="Where does Alice live?", - ) - print(res) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/cognee/tests/docs/guides/custom_tasks_and_pipelines.py b/cognee/tests/docs/guides/custom_tasks_and_pipelines.py deleted file mode 100644 index 202bb128a..000000000 --- a/cognee/tests/docs/guides/custom_tasks_and_pipelines.py +++ /dev/null @@ -1,53 +0,0 @@ -import asyncio -from typing import Any, Dict, List -from pydantic import BaseModel, SkipValidation - -import cognee -from cognee.modules.engine.operations.setup import setup -from cognee.infrastructure.llm.LLMGateway import LLMGateway -from cognee.infrastructure.engine import DataPoint -from cognee.tasks.storage import add_data_points -from cognee.modules.pipelines import Task, run_pipeline - - -class Person(DataPoint): - name: str - # Optional relationships (we'll let the LLM populate this) - knows: List["Person"] = [] - # Make names searchable in the vector store - metadata: Dict[str, Any] = {"index_fields": ["name"]} - - -class People(BaseModel): - persons: List[Person] - - -async def extract_people(text: str) -> List[Person]: - system_prompt = ( - "Extract people mentioned in the text. " - "Return as `persons: Person[]` with each Person having `name` and optional `knows` relations. " - "If the text says someone knows someone set `knows` accordingly. " - "Only include facts explicitly stated." - ) - people = await LLMGateway.acreate_structured_output(text, system_prompt, People) - return people.persons - - -async def main(): - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - text = "Alice knows Bob." 
- - tasks = [ - Task(extract_people), # input: text -> output: list[Person] - Task(add_data_points), # input: list[Person] -> output: list[Person] - ] - - async for _ in run_pipeline(tasks=tasks, data=text, datasets=["people_demo"]): - pass - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/cognee/tests/docs/guides/graph_visualization.py b/cognee/tests/docs/guides/graph_visualization.py deleted file mode 100644 index d463cbb56..000000000 --- a/cognee/tests/docs/guides/graph_visualization.py +++ /dev/null @@ -1,13 +0,0 @@ -import asyncio -import cognee -from cognee.api.v1.visualize.visualize import visualize_graph - - -async def main(): - await cognee.add(["Alice knows Bob.", "NLP is a subfield of CS."]) - await cognee.cognify() - - await visualize_graph("./graph_after_cognify.html") - - -asyncio.run(main()) diff --git a/cognee/tests/docs/guides/low_level_llm.py b/cognee/tests/docs/guides/low_level_llm.py deleted file mode 100644 index 454f53f44..000000000 --- a/cognee/tests/docs/guides/low_level_llm.py +++ /dev/null @@ -1,31 +0,0 @@ -import asyncio - -from pydantic import BaseModel -from typing import List -from cognee.infrastructure.llm.LLMGateway import LLMGateway - - -class MiniEntity(BaseModel): - name: str - type: str - - -class MiniGraph(BaseModel): - nodes: List[MiniEntity] - - -async def main(): - system_prompt = ( - "Extract entities as nodes with name and type. " - "Use concise, literal values present in the text." - ) - - text = "Apple develops iPhone; Audi produces the R8." 
- - result = await LLMGateway.acreate_structured_output(text, system_prompt, MiniGraph) - print(result) - # MiniGraph(nodes=[MiniEntity(name='Apple', type='Organization'), ...]) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/cognee/tests/docs/guides/memify_quickstart.py b/cognee/tests/docs/guides/memify_quickstart.py deleted file mode 100644 index 040654350..000000000 --- a/cognee/tests/docs/guides/memify_quickstart.py +++ /dev/null @@ -1,29 +0,0 @@ -import asyncio -import cognee -from cognee import SearchType - - -async def main(): - # 1) Add two short chats and build a graph - await cognee.add( - [ - "We follow PEP8. Add type hints and docstrings.", - "Releases should not be on Friday. Susan must review PRs.", - ], - dataset_name="rules_demo", - ) - await cognee.cognify(datasets=["rules_demo"]) # builds graph - - # 2) Enrich the graph (uses default memify tasks) - await cognee.memify(dataset="rules_demo") - - # 3) Query the new coding rules - rules = await cognee.search( - query_type=SearchType.CODING_RULES, - query_text="List coding rules", - node_name=["coding_agent_rules"], - ) - print("Rules:", rules) - - -asyncio.run(main()) diff --git a/cognee/tests/docs/guides/ontology_quickstart.py b/cognee/tests/docs/guides/ontology_quickstart.py deleted file mode 100644 index 2784dab19..000000000 --- a/cognee/tests/docs/guides/ontology_quickstart.py +++ /dev/null @@ -1,30 +0,0 @@ -import asyncio -import cognee - - -async def main(): - texts = ["Audi produces the R8 and e-tron.", "Apple develops iPhone and MacBook."] - - await cognee.add(texts) - # or: await cognee.add("/path/to/folder/of/files") - - import os - from cognee.modules.ontology.ontology_config import Config - from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver - - ontology_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "ontology_input_example/basic_ontology.owl" - ) - - # Create full config structure manually - config: Config = { - 
"ontology_config": { - "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_path) - } - } - - await cognee.cognify(config=config) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/cognee/tests/docs/guides/s3_storage.py b/cognee/tests/docs/guides/s3_storage.py deleted file mode 100644 index 1044e05b4..000000000 --- a/cognee/tests/docs/guides/s3_storage.py +++ /dev/null @@ -1,25 +0,0 @@ -import asyncio -import cognee - - -async def main(): - # Single file - await cognee.add("s3://cognee-temp/2024-11-04.md") - - # Folder/prefix (recursively expands) - await cognee.add("s3://cognee-temp") - - # Mixed list - await cognee.add( - [ - "s3://cognee-temp/2024-11-04.md", - "Some inline text to ingest", - ] - ) - - # Process the data - await cognee.cognify() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/cognee/tests/docs/guides/search_basics.py b/cognee/tests/docs/guides/search_basics.py deleted file mode 100644 index f1847ad4b..000000000 --- a/cognee/tests/docs/guides/search_basics.py +++ /dev/null @@ -1,58 +0,0 @@ -import asyncio -import cognee - -from cognee.modules.search.types import SearchType, CombinedSearchResult - - -async def main(): - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - text = """ - Natural language processing (NLP) is an interdisciplinary - subfield of computer science and information retrieval. - First rule of coding: Do not talk about coding. - """ - - text2 = """ - Sandwiches are best served toasted with cheese, ham, mayo, - lettuce, mustard, and salt & pepper. - """ - - await cognee.add(text, dataset_name="NLP_coding") - await cognee.add(text2, dataset_name="Sandwiches") - await cognee.add(text2) - - await cognee.cognify() - - # Make sure you've already run cognee.cognify(...) 
so the graph has content - answers = await cognee.search(query_text="What are the main themes in my data?") - assert len(answers) > 0 - - answers = await cognee.search( - query_text="List coding guidelines", - query_type=SearchType.CODING_RULES, - ) - assert len(answers) > 0 - - answers = await cognee.search( - query_text="Give me a confident answer: What is NLP?", - system_prompt="Answer succinctly and state confidence at the end.", - ) - assert len(answers) > 0 - - answers = await cognee.search( - query_text="Tell me about NLP", - only_context=True, - ) - assert len(answers) > 0 - - answers = await cognee.search( - query_text="Quarterly financial highlights", - datasets=["NLP_coding", "Sandwiches"], - use_combined_context=True, - ) - assert isinstance(answers, CombinedSearchResult) - - -asyncio.run(main()) diff --git a/cognee/tests/docs/guides/temporal_cognify.py b/cognee/tests/docs/guides/temporal_cognify.py deleted file mode 100644 index 34c1ee33c..000000000 --- a/cognee/tests/docs/guides/temporal_cognify.py +++ /dev/null @@ -1,57 +0,0 @@ -import asyncio -import cognee - - -async def main(): - text = """ - In 1998 the project launched. In 2001 version 1.0 shipped. In 2004 the team merged - with another group. In 2010 support for v1 ended. 
- """ - - await cognee.add(text, dataset_name="timeline_demo") - - await cognee.cognify(datasets=["timeline_demo"], temporal_cognify=True) - - from cognee.api.v1.search import SearchType - - # Before / after queries - result = await cognee.search( - query_type=SearchType.TEMPORAL, query_text="What happened before 2000?", top_k=10 - ) - - assert result != [] - - result = await cognee.search( - query_type=SearchType.TEMPORAL, query_text="What happened after 2010?", top_k=10 - ) - - assert result != [] - - # Between queries - result = await cognee.search( - query_type=SearchType.TEMPORAL, query_text="Events between 2001 and 2004", top_k=10 - ) - - assert result != [] - - # Scoped descriptions - result = await cognee.search( - query_type=SearchType.TEMPORAL, - query_text="Key project milestones between 1998 and 2010", - top_k=10, - ) - - assert result != [] - - result = await cognee.search( - query_type=SearchType.TEMPORAL, - query_text="What happened after 2004?", - datasets=["timeline_demo"], - top_k=10, - ) - - assert result != [] - - -if __name__ == "__main__": - asyncio.run(main()) From 41edeb0cf890e0d0b733bcd4befb03b870e70cbc Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 11 Dec 2025 16:01:26 +0100 Subject: [PATCH 052/176] test: change target repo name --- .github/workflows/release_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml index c6dd68484..3fef0732a 100644 --- a/.github/workflows/release_test.yml +++ b/.github/workflows/release_test.yml @@ -27,5 +27,5 @@ jobs: curl -sS -X POST \ -H "Accept: application/vnd.github+json" \ -H "Authorization: Bearer ${{ secrets.DOCS_REPO_PAT_TOKEN }}" \ - https://api.github.com/repos/your-org/repo-b/dispatches \ + https://api.github.com/repos/topoteretes/cognee-docs/dispatches \ -d '{"event_type":"new-main-release","client_payload":{"caller_repo":"'"${GITHUB_REPOSITORY}"'"}}' From 
0f4cf15d588e5dfa672d680e5258de284d308367 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Thu, 11 Dec 2025 16:24:47 +0100 Subject: [PATCH 053/176] test: fix docs test trigger --- .github/workflows/release_test.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml index 3fef0732a..76ce3b09d 100644 --- a/.github/workflows/release_test.yml +++ b/.github/workflows/release_test.yml @@ -24,8 +24,9 @@ jobs: steps: - name: Trigger docs tests run: | - curl -sS -X POST \ + curl -L -X POST \ -H "Accept: application/vnd.github+json" \ -H "Authorization: Bearer ${{ secrets.DOCS_REPO_PAT_TOKEN }}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ https://api.github.com/repos/topoteretes/cognee-docs/dispatches \ -d '{"event_type":"new-main-release","client_payload":{"caller_repo":"'"${GITHUB_REPOSITORY}"'"}}' From fd23c75c09d1cc22a406b72fc723a7975a6a4cca Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Fri, 12 Dec 2025 14:44:41 +0100 Subject: [PATCH 054/176] chore: adds new Unit tests for retrievers --- .../retrieval/chunks_retriever_test.py | 322 ++++--- .../retrieval/conversation_history_test.py | 338 ++++++++ ...letion_retriever_context_extension_test.py | 552 +++++++++--- .../graph_completion_retriever_cot_test.py | 794 +++++++++++++++--- .../graph_completion_retriever_test.py | 793 +++++++++++++---- .../rag_completion_retriever_test.py | 462 ++++++---- .../retrieval/structured_output_test.py | 204 ----- .../retrieval/summaries_retriever_test.py | 312 ++++--- .../retrieval/temporal_retriever_test.py | 597 +++++++++++-- .../test_brute_force_triplet_search.py | 227 ++++- .../unit/modules/retrieval/test_completion.py | 343 ++++++++ ...test_graph_summary_completion_retriever.py | 157 ++++ .../retrieval/test_user_qa_feedback.py | 312 +++++++ .../retrieval/triplet_retriever_test.py | 246 ++++++ 14 files changed, 4454 insertions(+), 1205 deletions(-) 
delete mode 100644 cognee/tests/unit/modules/retrieval/structured_output_test.py create mode 100644 cognee/tests/unit/modules/retrieval/test_completion.py create mode 100644 cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py create mode 100644 cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py index 44786f79d..98bfd48fe 100644 --- a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py @@ -1,201 +1,183 @@ -import os import pytest -import pathlib -from typing import List -import cognee -from cognee.low_level import setup -from cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from unittest.mock import AsyncMock, patch, MagicMock + from cognee.modules.retrieval.chunks_retriever import ChunksRetriever -from cognee.infrastructure.engine import DataPoint -from cognee.modules.data.processing.document_types import Document -from cognee.modules.engine.models import Entity +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError -class DocumentChunkWithEntities(DataPoint): - text: str - chunk_size: int - chunk_index: int - cut_type: str - is_part_of: Document - contains: List[Entity] = None - - metadata: dict = {"index_fields": ["text"]} +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine -class TestChunksRetriever: - @pytest.mark.asyncio - async def 
test_chunk_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval of chunk context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "Steve Rodger", "chunk_index": 0} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "Mike Broski", "chunk_index": 1} - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() + mock_vector_engine.search.return_value = [mock_result1, mock_result2] - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + retriever = ChunksRetriever(top_k=5) - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - entities = [chunk1, chunk2, chunk3] + assert len(context) == 2 + assert context[0]["text"] == "Steve Rodger" + assert context[1]["text"] == "Mike Broski" + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5) - await add_data_points(entities) - retriever = 
ChunksRetriever() +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") - context = await retriever.get_context("Mike") + retriever = ChunksRetriever() - assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski" + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") - @pytest.mark.asyncio - async def test_chunk_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_complex" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty list is returned when no chunks are found.""" + mock_vector_engine.search.return_value = [] - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + retriever = ChunksRetriever() - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - 
contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) + assert context == [] - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] +@pytest.mark.asyncio +async def test_get_context_top_k_limit(mock_vector_engine): + """Test that top_k parameter limits the number of results.""" + mock_results = [MagicMock() for _ in range(3)] + for i, result in enumerate(mock_results): + result.payload = {"text": f"Chunk {i}"} - await add_data_points(entities) + mock_vector_engine.search.return_value = mock_results - retriever = ChunksRetriever(top_k=20) + retriever = ChunksRetriever(top_k=3) - context = await retriever.get_context("Christina") + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer" + assert len(context) == 3 + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3) - @pytest.mark.asyncio - async def test_chunk_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph" - ) - 
cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph" - ) - cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) +@pytest.mark.asyncio +async def test_get_completion_with_context(mock_vector_engine): + """Test get_completion returns provided context.""" + retriever = ChunksRetriever() - retriever = ChunksRetriever() + provided_context = [{"text": "Steve Rodger"}, {"text": "Mike Broski"}] + completion = await retriever.get_completion("test query", context=provided_context) - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") + assert completion == provided_context - vector_engine = get_vector_engine() - await vector_engine.create_collection( - "DocumentChunk_text", payload_schema=DocumentChunkWithEntities - ) - context = await retriever.get_context("Christina Mayer") - assert len(context) == 0, "Found chunks when none should exist" +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Steve Rodger"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query") + + assert len(completion) == 1 + assert completion[0]["text"] == "Steve Rodger" + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test ChunksRetriever initialization with defaults.""" + retriever = ChunksRetriever() + + assert retriever.top_k == 5 + + +@pytest.mark.asyncio +async def test_init_custom_top_k(): + """Test ChunksRetriever initialization with custom top_k.""" + retriever = 
ChunksRetriever(top_k=10) + + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test ChunksRetriever initialization with None top_k.""" + retriever = ChunksRetriever(top_k=None) + + assert retriever.top_k is None + + +@pytest.mark.asyncio +async def test_get_context_empty_payload(mock_vector_engine): + """Test get_context handles empty payload.""" + mock_result = MagicMock() + mock_result.payload = {} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 1 + assert context[0] == {} + + +@pytest.mark.asyncio +async def test_get_completion_with_session_id(mock_vector_engine): + """Test get_completion with session_id parameter.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Steve Rodger"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query", session_id="test_session") + + assert len(completion) == 1 + assert completion[0]["text"] == "Steve Rodger" diff --git a/cognee/tests/unit/modules/retrieval/conversation_history_test.py b/cognee/tests/unit/modules/retrieval/conversation_history_test.py index d464a99d8..f1ce9b370 100644 --- a/cognee/tests/unit/modules/retrieval/conversation_history_test.py +++ b/cognee/tests/unit/modules/retrieval/conversation_history_test.py @@ -152,3 +152,341 @@ class TestConversationHistoryUtils: assert result is True call_kwargs = mock_cache.add_qa.call_args.kwargs assert call_kwargs["session_id"] == "default_session" + + @pytest.mark.asyncio + async def test_save_conversation_history_no_user_id(self): + """Test 
save_conversation_history returns False when user_id is None.""" + session_user.set(None) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_caching_disabled(self): + """Test save_conversation_history returns False when caching is disabled.""" + user = create_mock_user() + session_user.set(user) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = False + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_cache_engine_none(self): + """Test save_conversation_history returns False when cache_engine is None.""" + user = create_mock_user() + session_user.set(user) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=None): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + 
context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_cache_connection_error(self): + """Test save_conversation_history handles CacheConnectionError gracefully.""" + user = create_mock_user() + session_user.set(user) + + from cognee.infrastructure.databases.exceptions import CacheConnectionError + + mock_cache = create_mock_cache_engine([]) + mock_cache.add_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_generic_exception(self): + """Test save_conversation_history handles generic exceptions gracefully.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + mock_cache.add_qa = AsyncMock(side_effect=ValueError("Unexpected error")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + 
save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_get_conversation_history_no_user_id(self): + """Test get_conversation_history returns empty string when user_id is None.""" + session_user.set(None) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_caching_disabled(self): + """Test get_conversation_history returns empty string when caching is disabled.""" + user = create_mock_user() + session_user.set(user) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = False + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_default_session(self): + """Test get_conversation_history uses 'default_session' when session_id is None.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching 
= True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + await get_conversation_history(session_id=None) + + mock_cache.get_latest_qa.assert_called_once_with(str(user.id), "default_session") + + @pytest.mark.asyncio + async def test_get_conversation_history_cache_engine_none(self): + """Test get_conversation_history returns empty string when cache_engine is None.""" + user = create_mock_user() + session_user.set(user) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=None): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_cache_connection_error(self): + """Test get_conversation_history handles CacheConnectionError gracefully.""" + user = create_mock_user() + session_user.set(user) + + from cognee.infrastructure.databases.exceptions import CacheConnectionError + + mock_cache = create_mock_cache_engine([]) + mock_cache.get_latest_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + 
get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_generic_exception(self): + """Test get_conversation_history handles generic exceptions gracefully.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + mock_cache.get_latest_qa = AsyncMock(side_effect=ValueError("Unexpected error")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_missing_keys(self): + """Test get_conversation_history handles missing keys in history entries.""" + user = create_mock_user() + session_user.set(user) + + mock_history = [ + { + "time": "2024-01-15 10:30:45", + "question": "What is AI?", + }, + { + "context": "AI is artificial intelligence", + "answer": "AI stands for Artificial Intelligence", + }, + {}, + ] + mock_cache = create_mock_cache_engine(mock_history) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + 
get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert "Previous conversation:" in result + assert "[2024-01-15 10:30:45]" in result + assert "QUESTION: What is AI?" in result + assert "Unknown time" in result + assert "CONTEXT: AI is artificial intelligence" in result + assert "ANSWER: AI stands for Artificial Intelligence" in result diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 0e21fe351..6a9b07d38 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -1,177 +1,469 @@ -import os import pytest -import pathlib -from typing import Optional, Union +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID -import cognee -from cognee.low_level import setup, DataPoint -from cognee.tasks.storage import add_data_points -from cognee.modules.graph.utils import resolve_edges_to_text from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( GraphCompletionContextExtensionRetriever, ) +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge -class TestGraphCompletionWithContextExtensionRetriever: - @pytest.mark.asyncio - async def test_graph_completion_extension_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_extension_context_simple", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_graph_completion_extension_context_simple", - ) - cognee.config.data_root_directory(data_directory_path) +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = 
MagicMock(spec=Edge) + return edge - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - class Company(DataPoint): - name: str +@pytest.mark.asyncio +async def test_get_triplets_inherited(mock_edge): + """Test that get_triplets is inherited from parent class.""" + retriever = GraphCompletionContextExtensionRetriever() - class Person(DataPoint): - name: str - works_for: Company + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ): + triplets = await retriever.get_triplets("test query") - company1 = Company(name="Figma") - company2 = Company(name="Canva") - person1 = Person(name="Steve Rodger", works_for=company1) - person2 = Person(name="Ike Loma", works_for=company1) - person3 = Person(name="Jason Statham", works_for=company1) - person4 = Person(name="Mike Broski", works_for=company2) - person5 = Person(name="Christina Mayer", works_for=company2) + assert len(triplets) == 1 + assert triplets[0] == mock_edge - entities = [company1, company2, person1, person2, person3, person4, person5] - await add_data_points(entities) +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionContextExtensionRetriever initialization with defaults.""" + retriever = GraphCompletionContextExtensionRetriever() - retriever = GraphCompletionContextExtensionRetriever() + assert retriever.top_k == 5 + assert retriever.user_prompt_path == "graph_context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" - context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) - assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" - assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test GraphCompletionContextExtensionRetriever initialization with custom 
parameters.""" + retriever = GraphCompletionContextExtensionRetriever( + top_k=10, + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + node_type=str, + node_name=["node1"], + save_interaction=True, + wide_search_top_k=200, + triplet_distance_penalty=5.0, + ) - answer = await retriever.get_completion("Who works at Canva?") + assert retriever.top_k == 10 + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.node_type is str + assert retriever.node_name == ["node1"] + assert retriever.save_interaction is True + assert retriever.wide_search_top_k == 200 + assert retriever.triplet_distance_penalty == 5.0 - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + 
mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context_extension_rounds=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", context=[mock_edge], context_extension_rounds=1 ) - @pytest.mark.asyncio - async def test_graph_completion_extension_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_extension_context_complex", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_graph_completion_extension_context_complex", - ) - cognee.config.data_root_directory(data_directory_path) + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} +@pytest.mark.asyncio +async def test_get_completion_context_extension_rounds(mock_edge): + 
"""Test get_completion with multiple context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - class Car(DataPoint): - brand: str - model: str - year: int + retriever = GraphCompletionContextExtensionRetriever() - class Location(DataPoint): - country: str - city: str + # Create a second edge for extension rounds + mock_edge2 = MagicMock(spec=Edge) - class Home(DataPoint): - location: Location - rooms: int - sqm: int + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[mock_edge], [mock_edge2]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + side_effect=["Resolved context", "Extended context"], # Different contexts + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Query for extension, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None + completion = await retriever.get_completion("test query", context_extension_rounds=1) - company1 = Company(name="Figma") - company2 = Company(name="Canva") + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - 
Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] +@pytest.mark.asyncio +async def test_get_completion_context_extension_stops_early(mock_edge): + """Test get_completion stops early when no new triplets found.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - person3 = Person(name="Jason Statham", works_for=company1) + retriever = GraphCompletionContextExtensionRetriever() - person4 = Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + with ( + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionContextExtensionRetriever(top_k=20) - - context = await resolve_edges_to_text( - await retriever.get_context("Who works at Figma and drives Tesla?") + # When get_context returns same triplets, the loop should stop early + completion = await retriever.get_completion( + "test query", context=[mock_edge], context_extension_rounds=4 ) - print(context) + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" - assert "Mike Rodger --[works_for]--> Figma" in 
context, "Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" - answer = await retriever.get_completion("Who works at Figma?") +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" + retriever = GraphCompletionContextExtensionRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + patch( + 
"cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "test query", session_id="test_session", context_extension_rounds=1 ) - @pytest.mark.asyncio - async def test_get_graph_completion_extension_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph", + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction(mock_edge): + """Test get_completion with save_interaction enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionContextExtensionRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + 
UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", context=[mock_edge], context_extension_rounds=1 ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_extension_context_on_empty_graph", + + assert isinstance(completion, list) + assert len(completion) == 1 + mock_add_data.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + TestModel(answer="Test answer"), + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = 
mock_config + + completion = await retriever.get_completion( + "test query", response_model=TestModel, context_extension_rounds=1 ) - cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) - retriever = GraphCompletionContextExtensionRetriever() - await setup() +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_edge): + """Test get_completion with session config but no user ID.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" + retriever = GraphCompletionContextExtensionRetriever() - answer = await retriever.get_completion("Who works at Figma?") + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user - assert isinstance(answer, list), f"Expected list, got 
{type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) + completion = await retriever.get_completion("test query", context_extension_rounds=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_zero_extension_rounds(mock_edge): + """Test get_completion with zero context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context_extension_rounds=0) + + assert isinstance(completion, list) + assert len(completion) == 1 diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 206cfaf84..9f3147512 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -1,170 +1,688 @@ -import os import pytest -import pathlib -from typing import Optional, Union +from unittest.mock import 
AsyncMock, patch, MagicMock +from uuid import UUID -import cognee -from cognee.low_level import setup, DataPoint -from cognee.modules.graph.utils import resolve_edges_to_text -from cognee.tasks.storage import add_data_points -from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +from cognee.modules.retrieval.graph_completion_cot_retriever import ( + GraphCompletionCotRetriever, + _as_answer_text, +) +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.infrastructure.llm.LLMGateway import LLMGateway -class TestGraphCompletionCoTRetriever: - @pytest.mark.asyncio - async def test_graph_completion_cot_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_cot_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - class Company(DataPoint): - name: str +@pytest.mark.asyncio +async def test_get_triplets_inherited(mock_edge): + """Test that get_triplets is inherited from parent class.""" + retriever = GraphCompletionCotRetriever() - class Person(DataPoint): - name: str - works_for: Company + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ): + triplets = await retriever.get_triplets("test query") - company1 = Company(name="Figma") - company2 = Company(name="Canva") - person1 = Person(name="Steve Rodger", works_for=company1) - person2 = Person(name="Ike Loma", works_for=company1) - person3 = Person(name="Jason Statham", 
works_for=company1) - person4 = Person(name="Mike Broski", works_for=company2) - person5 = Person(name="Christina Mayer", works_for=company2) + assert len(triplets) == 1 + assert triplets[0] == mock_edge - entities = [company1, company2, person1, person2, person3, person4, person5] - await add_data_points(entities) +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test GraphCompletionCotRetriever initialization with custom parameters.""" + retriever = GraphCompletionCotRetriever( + top_k=10, + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + validation_user_prompt_path="custom_validation_user.txt", + validation_system_prompt_path="custom_validation_system.txt", + followup_system_prompt_path="custom_followup_system.txt", + followup_user_prompt_path="custom_followup_user.txt", + ) - retriever = GraphCompletionCotRetriever() + assert retriever.top_k == 10 + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.validation_user_prompt_path == "custom_validation_user.txt" + assert retriever.validation_system_prompt_path == "custom_validation_system.txt" + assert retriever.followup_system_prompt_path == "custom_followup_system.txt" + assert retriever.followup_user_prompt_path == "custom_followup_user.txt" - context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) - assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" - assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionCotRetriever initialization with defaults.""" + retriever = GraphCompletionCotRetriever() - answer = await retriever.get_completion("Who works at Canva?") + assert retriever.validation_user_prompt_path == "cot_validation_user_prompt.txt" + assert retriever.validation_system_prompt_path == 
"cot_validation_system_prompt.txt" + assert retriever.followup_system_prompt_path == "cot_followup_system_prompt.txt" + assert retriever.followup_user_prompt_path == "cot_followup_user_prompt.txt" - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" + +@pytest.mark.asyncio +async def test_run_cot_completion_round_zero_with_context(mock_edge): + """Test _run_cot_completion round 0 with provided context.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + max_iter=1, ) - @pytest.mark.asyncio - async def test_graph_completion_cot_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_cot_context_complex", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, 
".data_storage/test_graph_completion_cot_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} +@pytest.mark.asyncio +async def test_run_cot_completion_round_zero_without_context(mock_edge): + """Test _run_cot_completion round 0 without provided context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - class Car(DataPoint): - brand: str - model: str - year: int + retriever = GraphCompletionCotRetriever() - class Location(DataPoint): - country: str - city: str - - class Home(DataPoint): - location: Location - rooms: int - sqm: int - - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] - - person3 = Person(name="Jason Statham", works_for=company1) - - person4 = Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] - - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionCotRetriever(top_k=20) - - context = await resolve_edges_to_text(await retriever.get_context("Who works at 
Figma?")) - - print(context) - - assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=None, + max_iter=1, ) - @pytest.mark.asyncio - async def test_get_graph_completion_cot_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph", + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 + + +@pytest.mark.asyncio +async def test_run_cot_completion_multiple_rounds(mock_edge): + """Test _run_cot_completion with multiple rounds.""" + retriever = GraphCompletionCotRetriever() + + mock_edge2 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + 
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[mock_edge], [mock_edge2]], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=[ + "validation_result", + "followup_question", + "validation_result2", + "followup_question2", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + max_iter=2, ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_cot_context_on_empty_graph", + + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 + + +@pytest.mark.asyncio +async def test_run_cot_completion_with_conversation_history(mock_edge): + """Test _run_cot_completion with conversation history.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ) as mock_generate, + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + conversation_history="Previous conversation", + max_iter=1, ) - 
cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) + assert completion == "Generated answer" + call_kwargs = mock_generate.call_args[1] + assert call_kwargs.get("conversation_history") == "Previous conversation" - retriever = GraphCompletionCotRetriever() - await setup() +@pytest.mark.asyncio +async def test_run_cot_completion_with_response_model(mock_edge): + """Test _run_cot_completion with custom response model.""" + from pydantic import BaseModel - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" + class TestModel(BaseModel): + answer: str - answer = await retriever.get_completion("Who works at Figma?") + retriever = GraphCompletionCotRetriever() - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + response_model=TestModel, + max_iter=1, ) + + assert isinstance(completion, TestModel) + assert completion.answer == "Test answer" + + +@pytest.mark.asyncio +async def test_run_cot_completion_empty_conversation_history(mock_edge): + """Test _run_cot_completion with empty conversation history.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + 
"cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ) as mock_generate, + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + conversation_history="", + max_iter=1, + ) + + assert completion == "Generated answer" + # Verify conversation_history was passed as None when empty + call_kwargs = mock_generate.call_args[1] + assert call_kwargs.get("conversation_history") is None + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + 
"cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", 
+ ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "test query", session_id="test_session", max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction(mock_edge): + """Test get_completion with save_interaction enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionCotRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + 
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + # Pass context so save_interaction condition is met + completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + mock_add_data.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", response_model=TestModel, max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_edge): + """Test get_completion with session config but no user ID.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query", max_iter=1) + + assert isinstance(completion, 
list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_context(mock_edge): + """Test get_completion with save_interaction but no context provided.""" + retriever = GraphCompletionCotRetriever(save_interaction=True) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=None, max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_as_answer_text_with_typeerror(): + """Test _as_answer_text handles TypeError when json.dumps fails.""" + non_serializable = {1, 2, 3} + + result = _as_answer_text(non_serializable) + + assert isinstance(result, str) + assert result == str(non_serializable) + + +@pytest.mark.asyncio +async def test_as_answer_text_with_string(): + """Test _as_answer_text with string input.""" + result = _as_answer_text("test string") 
+ assert result == "test string" + + +@pytest.mark.asyncio +async def test_as_answer_text_with_dict(): + """Test _as_answer_text with dictionary input.""" + test_dict = {"key": "value", "number": 42} + result = _as_answer_text(test_dict) + assert isinstance(result, str) + assert "key" in result + assert "value" in result + + +@pytest.mark.asyncio +async def test_as_answer_text_with_basemodel(): + """Test _as_answer_text with Pydantic BaseModel input.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + test_model = TestModel(answer="test answer") + result = _as_answer_text(test_model) + + assert isinstance(result, str) + assert "[Structured Response]" in result + assert "test answer" in result diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py index f462baced..c22f30fd0 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -1,223 +1,648 @@ -import os import pytest -import pathlib -from typing import Optional, Union +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID -import cognee -from cognee.low_level import setup, DataPoint -from cognee.modules.graph.utils import resolve_edges_to_text -from cognee.tasks.storage import add_data_points from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge -class TestGraphCompletionRetriever: - @pytest.mark.asyncio - async def test_graph_completion_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, 
".data_storage/test_graph_completion_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - class Company(DataPoint): - name: str - description: str +@pytest.mark.asyncio +async def test_get_triplets_success(mock_edge): + """Test successful retrieval of triplets.""" + retriever = GraphCompletionRetriever(top_k=5) - class Person(DataPoint): - name: str - description: str - works_for: Company + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ) as mock_search: + triplets = await retriever.get_triplets("test query") - company1 = Company(name="Figma", description="Figma is a company") - company2 = Company(name="Canva", description="Canvas is a company") - person1 = Person( - name="Steve Rodger", - description="This is description about Steve Rodger", - works_for=company1, - ) - person2 = Person( - name="Ike Loma", description="This is description about Ike Loma", works_for=company1 - ) - person3 = Person( - name="Jason Statham", - description="This is description about Jason Statham", - works_for=company1, - ) - person4 = Person( - name="Mike Broski", - description="This is description about Mike Broski", - works_for=company2, - ) - person5 = Person( - name="Christina Mayer", - description="This is description about Christina Mayer", - works_for=company2, + assert len(triplets) == 1 + assert triplets[0] == mock_edge + mock_search.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_triplets_empty_results(): + """Test that empty list is returned when no triplets are found.""" + retriever = GraphCompletionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ): + triplets = 
await retriever.get_triplets("test query") + + assert triplets == [] + + +@pytest.mark.asyncio +async def test_get_triplets_top_k_parameter(): + """Test that top_k parameter is passed to brute_force_triplet_search.""" + retriever = GraphCompletionRetriever(top_k=10) + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ) as mock_search: + await retriever.get_triplets("test query") + + call_kwargs = mock_search.call_args[1] + assert call_kwargs["top_k"] == 10 + + +@pytest.mark.asyncio +async def test_get_context_success(mock_edge): + """Test successful retrieval of context.""" + retriever = GraphCompletionRetriever() + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + ): + context = await retriever.get_context("test query") + + assert isinstance(context, list) + assert len(context) == 1 + assert context[0] == mock_edge + + +@pytest.mark.asyncio +async def test_get_context_empty_results(): + """Test that empty list is returned when no context is found.""" + retriever = GraphCompletionRetriever() + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ), + ): + context = await retriever.get_context("test query") + + assert context == [] + + +@pytest.mark.asyncio +async def test_get_context_empty_graph(): + """Test that empty list is returned when graph is empty.""" + retriever = GraphCompletionRetriever() + + 
mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=True) + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ): + context = await retriever.get_context("test query") + + assert context == [] + + +@pytest.mark.asyncio +async def test_resolve_edges_to_text(mock_edge): + """Test resolve_edges_to_text method.""" + retriever = GraphCompletionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved text", + ) as mock_resolve: + result = await retriever.resolve_edges_to_text([mock_edge]) + + assert result == "Resolved text" + mock_resolve.assert_awaited_once_with([mock_edge]) + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionRetriever initialization with defaults.""" + retriever = GraphCompletionRetriever() + + assert retriever.top_k == 5 + assert retriever.user_prompt_path == "graph_context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.node_type is None + assert retriever.node_name is None + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test GraphCompletionRetriever initialization with custom parameters.""" + retriever = GraphCompletionRetriever( + top_k=10, + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + node_type=str, + node_name=["node1"], + save_interaction=True, + wide_search_top_k=200, + triplet_distance_penalty=5.0, + ) + + assert retriever.top_k == 10 + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.node_type is str + assert retriever.node_name == ["node1"] + assert retriever.save_interaction is True + assert retriever.wide_search_top_k == 200 + assert 
retriever.triplet_distance_penalty == 5.0 + + +@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test GraphCompletionRetriever initialization with None top_k.""" + retriever = GraphCompletionRetriever(top_k=None) + + assert retriever.top_k == 5 # None defaults to 5 + + +@pytest.mark.asyncio +async def test_convert_retrieved_objects_to_context(mock_edge): + """Test convert_retrieved_objects_to_context method.""" + retriever = GraphCompletionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved text", + ) as mock_resolve: + result = await retriever.convert_retrieved_objects_to_context([mock_edge]) + + assert result == "Resolved text" + mock_resolve.assert_awaited_once_with([mock_edge]) + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio 
+async def test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=[mock_edge]) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], 
TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_empty_context(mock_edge): + """Test get_completion with empty context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_save_qa(mock_edge): + """Test save_qa method.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionRetriever() + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=["uuid1", "uuid2"], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[mock_edge], ) - entities = [company1, company2, person1, person2, 
person3, person4, person5] + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_awaited_once() - await add_data_points(entities) - retriever = GraphCompletionRetriever() +@pytest.mark.asyncio +async def test_save_qa_no_triplet_ids(mock_edge): + """Test save_qa when triplets have no extractable IDs.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() - context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) + retriever = GraphCompletionRetriever() - # Ensure the top-level sections are present - assert "Nodes:" in context, "Missing 'Nodes:' section in context" - assert "Connections:" in context, "Missing 'Connections:' section in context" + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 - # --- Nodes headers --- - assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger" - assert "Node: Figma" in context, "Missing node header for Figma" - assert "Node: Ike Loma" in context, "Missing node header for Ike Loma" - assert "Node: Jason Statham" in context, "Missing node header for Jason Statham" - assert "Node: Mike Broski" in context, "Missing node header for Mike Broski" - assert "Node: Canva" in context, "Missing node header for Canva" - assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer" - - # --- Node contents --- - assert ( - "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__" - in context - ), "Description block for Steve Rodger altered" - assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, ( - "Description block for Figma altered" - ) - assert ( - "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__" - in context - ), "Description block for Ike Loma altered" - assert ( - "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__" - in 
context - ), "Description block for Jason Statham altered" - assert ( - "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__" - in context - ), "Description block for Mike Broski altered" - assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, ( - "Description block for Canva altered" - ) - assert ( - "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__" - in context - ), "Description block for Christina Mayer altered" - - # --- Connections --- - assert "Steve Rodger --[works_for]--> Figma" in context, ( - "Connection Steve Rodger→Figma missing or changed" - ) - assert "Ike Loma --[works_for]--> Figma" in context, ( - "Connection Ike Loma→Figma missing or changed" - ) - assert "Jason Statham --[works_for]--> Figma" in context, ( - "Connection Jason Statham→Figma missing or changed" - ) - assert "Mike Broski --[works_for]--> Canva" in context, ( - "Connection Mike Broski→Canva missing or changed" - ) - assert "Christina Mayer --[works_for]--> Canva" in context, ( - "Connection Christina Mayer→Canva missing or changed" + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + return_value=None, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[mock_edge], ) - @pytest.mark.asyncio - async def test_graph_completion_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_complex" + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_not_called() + + +@pytest.mark.asyncio +async def test_save_qa_empty_triplets(): + """Test 
save_qa with empty triplets list.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[], ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_not_called() - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} - class Car(DataPoint): - brand: str - model: str - year: int +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_completion(mock_edge): + """Test get_completion with save_interaction but no completion.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - class Location(DataPoint): - country: str - city: str + retriever = GraphCompletionRetriever(save_interaction=True) - class Home(DataPoint): - location: Location - rooms: int - sqm: int + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value=None, # No completion + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None + completion = await retriever.get_completion("test query") - company1 = Company(name="Figma") - company2 = Company(name="Canva") + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] is None - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_context(mock_edge): + """Test get_completion with save_interaction but no context provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - person3 = Person(name="Jason Statham", works_for=company1) + retriever = GraphCompletionRetriever(save_interaction=True) - person4 = Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + 
return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + completion = await retriever.get_completion("test query", context=None) - entities = [company1, company2, person1, person2, person3, person4, person5] + assert isinstance(completion, list) + assert len(completion) == 1 - await add_data_points(entities) - retriever = GraphCompletionRetriever(top_k=20) +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge): + """Test get_completion with save_interaction when all conditions are met (line 216).""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) - context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) + retriever = GraphCompletionRetriever(save_interaction=True) - print(context) + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 - assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + 
return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config - @pytest.mark.asyncio - async def test_get_graph_completion_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) + completion = await retriever.get_completion("test query", context=[mock_edge]) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = GraphCompletionRetriever() - - await setup() - - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_add_data.assert_awaited_once() diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py index 9bfed68f3..e998d419d 100644 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -1,205 +1,321 @@ -import os -from typing import List import pytest -import pathlib 
-import cognee +from unittest.mock import AsyncMock, patch, MagicMock -from cognee.low_level import setup -from cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.completion_retriever import CompletionRetriever -from cognee.infrastructure.engine import DataPoint -from cognee.modules.data.processing.document_types import Document -from cognee.modules.engine.models import Entity +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError -class DocumentChunkWithEntities(DataPoint): - text: str - chunk_size: int - chunk_index: int - cut_type: str - is_part_of: Document - contains: List[Entity] = None - - metadata: dict = {"index_fields": ["text"]} +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine -class TestRAGCompletionRetriever: - @pytest.mark.asyncio - async def test_rag_completion_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval of context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "Steve Rodger"} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "Mike Broski"} - await cognee.prune.prune_data() - 
await cognee.prune.prune_system(metadata=True) - await setup() + mock_vector_engine.search.return_value = [mock_result1, mock_result2] - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + retriever = CompletionRetriever(top_k=2) - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - entities = [chunk1, chunk2, chunk3] + assert context == "Steve Rodger\nMike Broski" + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) - await add_data_points(entities) - retriever = CompletionRetriever() +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") - context = await retriever.get_context("Mike") + retriever = CompletionRetriever() - assert context == "Mike Broski", "Failed to get Mike Broski" + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") - @pytest.mark.asyncio - async def test_rag_completion_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, 
".cognee_system/test_rag_completion_context_complex" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty string is returned when no chunks are found.""" + mock_vector_engine.search.return_value = [] - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + retriever = CompletionRetriever() - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) + assert context == "" - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - 
contains=[], - ) - entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] +@pytest.mark.asyncio +async def test_get_context_top_k_limit(mock_vector_engine): + """Test that top_k parameter limits the number of results.""" + mock_results = [MagicMock() for _ in range(2)] + for i, result in enumerate(mock_results): + result.payload = {"text": f"Chunk {i}"} - await add_data_points(entities) + mock_vector_engine.search.return_value = mock_results - # TODO: top_k doesn't affect the output, it should be fixed. - retriever = CompletionRetriever(top_k=20) + retriever = CompletionRetriever(top_k=2) - context = await retriever.get_context("Christina") + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer" + assert context == "Chunk 0\nChunk 1" + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) - @pytest.mark.asyncio - async def test_get_rag_completion_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_rag_completion_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_rag_completion_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) +@pytest.mark.asyncio +async def test_get_context_single_chunk(mock_vector_engine): + """Test get_context with single chunk result.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Single chunk text"} + mock_vector_engine.search.return_value = [mock_result] - retriever = CompletionRetriever() + retriever = CompletionRetriever() - with 
pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - vector_engine = get_vector_engine() - await vector_engine.create_collection( - "DocumentChunk_text", payload_schema=DocumentChunkWithEntities - ) + assert context == "Single chunk text" - context = await retriever.get_context("Christina Mayer") - assert context == "", "Returned context should be empty on an empty graph" + +@pytest.mark.asyncio +async def test_get_completion_without_session(mock_vector_engine): + """Test get_completion without session caching.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_vector_engine): + """Test get_completion with provided context.""" + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + 
mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_vector_engine): + """Test get_completion with session caching enabled.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.completion_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.completion_retriever.save_conversation_history", + ) as mock_save, + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_vector_engine): + """Test get_completion with session config but no user ID.""" + mock_result = MagicMock() 
+ mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_init_defaults(): 
+ """Test CompletionRetriever initialization with defaults.""" + retriever = CompletionRetriever() + + assert retriever.user_prompt_path == "context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.top_k == 1 + assert retriever.system_prompt is None + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test CompletionRetriever initialization with custom parameters.""" + retriever = CompletionRetriever( + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + top_k=10, + ) + + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_get_context_missing_text_key(mock_vector_engine): + """Test get_context handles missing text key in payload.""" + mock_result = MagicMock() + mock_result.payload = {"other_key": "value"} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(KeyError): + await retriever.get_context("test query") diff --git a/cognee/tests/unit/modules/retrieval/structured_output_test.py b/cognee/tests/unit/modules/retrieval/structured_output_test.py deleted file mode 100644 index 4ad3019ff..000000000 --- a/cognee/tests/unit/modules/retrieval/structured_output_test.py +++ /dev/null @@ -1,204 +0,0 @@ -import asyncio - -import pytest -import cognee -import pathlib -import os - -from pydantic import BaseModel -from cognee.low_level import setup, DataPoint -from cognee.tasks.storage import add_data_points -from cognee.modules.chunking.models import DocumentChunk -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.engine.models 
import Entity, EntityType -from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor -from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider -from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever -from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever -from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( - GraphCompletionContextExtensionRetriever, -) -from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever -from cognee.modules.retrieval.temporal_retriever import TemporalRetriever -from cognee.modules.retrieval.completion_retriever import CompletionRetriever - - -class TestAnswer(BaseModel): - answer: str - explanation: str - - -def _assert_string_answer(answer: list[str]): - assert isinstance(answer, list), f"Expected str, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), "Items should be strings" - assert all(item.strip() for item in answer), "Items should not be empty" - - -def _assert_structured_answer(answer: list[TestAnswer]): - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(x, TestAnswer) for x in answer), "Items should be TestAnswer" - assert all(x.answer.strip() for x in answer), "Answer text should not be empty" - assert all(x.explanation.strip() for x in answer), "Explanation should not be empty" - - -async def _test_get_structured_graph_completion_cot(): - retriever = GraphCompletionCotRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("Who works at Figma?") - _assert_string_answer(string_answer) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Who works at Figma?", response_model=TestAnswer - ) - 
_assert_structured_answer(structured_answer) - - -async def _test_get_structured_graph_completion(): - retriever = GraphCompletionRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("Who works at Figma?") - _assert_string_answer(string_answer) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Who works at Figma?", response_model=TestAnswer - ) - _assert_structured_answer(structured_answer) - - -async def _test_get_structured_graph_completion_temporal(): - retriever = TemporalRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("When did Steve start working at Figma?") - _assert_string_answer(string_answer) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "When did Steve start working at Figma??", response_model=TestAnswer - ) - _assert_structured_answer(structured_answer) - - -async def _test_get_structured_graph_completion_rag(): - retriever = CompletionRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("Where does Steve work?") - _assert_string_answer(string_answer) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Where does Steve work?", response_model=TestAnswer - ) - _assert_structured_answer(structured_answer) - - -async def _test_get_structured_graph_completion_context_extension(): - retriever = GraphCompletionContextExtensionRetriever() - - # Test with string response model (default) - string_answer = await retriever.get_completion("Who works at Figma?") - _assert_string_answer(string_answer) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Who works at Figma?", response_model=TestAnswer - ) - _assert_structured_answer(structured_answer) - - -async def _test_get_structured_entity_completion(): - 
retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider()) - - # Test with string response model (default) - string_answer = await retriever.get_completion("Who is Albert Einstein?") - _assert_string_answer(string_answer) - - # Test with structured response model - structured_answer = await retriever.get_completion( - "Who is Albert Einstein?", response_model=TestAnswer - ) - _assert_structured_answer(structured_answer) - - -class TestStructuredOutputCompletion: - @pytest.mark.asyncio - async def test_get_structured_completion(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - works_since: int - - company1 = Company(name="Figma") - person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) - - entities = [company1, person1] - await add_data_points(entities) - - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3] - await 
add_data_points(entities) - - entity_type = EntityType(name="Person", description="A human individual") - entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") - - entities = [entity] - await add_data_points(entities) - - await _test_get_structured_graph_completion_cot() - await _test_get_structured_graph_completion() - await _test_get_structured_graph_completion_temporal() - await _test_get_structured_graph_completion_rag() - await _test_get_structured_graph_completion_context_extension() - await _test_get_structured_entity_completion() diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py index 5f4b93425..e552ac74a 100644 --- a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py @@ -1,159 +1,193 @@ -import os import pytest -import pathlib +from unittest.mock import AsyncMock, patch, MagicMock -import cognee -from cognee.low_level import setup -from cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.tasks.summarization.models import TextSummary -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.summaries_retriever import SummariesRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError -class TestSummariesRetriever: - @pytest.mark.asyncio - async def test_chunk_context(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - 
pathlib.Path(__file__).parent, ".data_storage/test_chunk_context" - ) - cognee.config.data_root_directory(data_directory_path) +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval of summary context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "S.R.", "made_from": "chunk1"} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "M.B.", "made_from": "chunk2"} - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) + mock_vector_engine.search.return_value = [mock_result1, mock_result2] - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk1_summary = TextSummary( - text="S.R.", - made_from=chunk1, - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2_summary = TextSummary( - text="M.B.", - made_from=chunk2, - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3_summary = TextSummary( - text="C.M.", - made_from=chunk3, - ) - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk4_summary = TextSummary( - text="R.R.", - made_from=chunk4, - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - 
chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk5_summary = TextSummary( - text="H.Y.", - made_from=chunk5, - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6_summary = TextSummary( - text="C.H.", - made_from=chunk6, - ) + retriever = SummariesRetriever(top_k=5) - entities = [ - chunk1_summary, - chunk2_summary, - chunk3_summary, - chunk4_summary, - chunk5_summary, - chunk6_summary, - ] + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - await add_data_points(entities) + assert len(context) == 2 + assert context[0]["text"] == "S.R." + assert context[1]["text"] == "M.B." + mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5) - retriever = SummariesRetriever(top_k=20) - context = await retriever.get_context("Christina") +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") - assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer" + retriever = SummariesRetriever() - @pytest.mark.asyncio - async def test_chunk_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph" - ) - cognee.config.data_root_directory(data_directory_path) + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with 
pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - retriever = SummariesRetriever() +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty list is returned when no summaries are found.""" + mock_vector_engine.search.return_value = [] - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") + retriever = SummariesRetriever() - vector_engine = get_vector_engine() - await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary) + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") - context = await retriever.get_context("Christina Mayer") - assert context == [], "Returned context should be empty on an empty graph" + assert context == [] + + +@pytest.mark.asyncio +async def test_get_context_top_k_limit(mock_vector_engine): + """Test that top_k parameter limits the number of results.""" + mock_results = [MagicMock() for _ in range(3)] + for i, result in enumerate(mock_results): + result.payload = {"text": f"Summary {i}"} + + mock_vector_engine.search.return_value = mock_results + + retriever = SummariesRetriever(top_k=3) + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 3 + mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3) + + +@pytest.mark.asyncio +async def test_get_completion_with_context(mock_vector_engine): + """Test get_completion returns provided context.""" + retriever = SummariesRetriever() + + provided_context = [{"text": "S.R."}, {"text": "M.B."}] + completion = await retriever.get_completion("test query", 
context=provided_context) + + assert completion == provided_context + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + mock_result = MagicMock() + mock_result.payload = {"text": "S.R."} + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query") + + assert len(completion) == 1 + assert completion[0]["text"] == "S.R." + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test SummariesRetriever initialization with defaults.""" + retriever = SummariesRetriever() + + assert retriever.top_k == 5 + + +@pytest.mark.asyncio +async def test_init_custom_top_k(): + """Test SummariesRetriever initialization with custom top_k.""" + retriever = SummariesRetriever(top_k=10) + + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_get_context_empty_payload(mock_vector_engine): + """Test get_context handles empty payload.""" + mock_result = MagicMock() + mock_result.payload = {} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 1 + assert context[0] == {} + + +@pytest.mark.asyncio +async def test_get_completion_with_session_id(mock_vector_engine): + """Test get_completion with session_id parameter.""" + mock_result = MagicMock() + mock_result.payload = {"text": "S.R."} + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + 
completion = await retriever.get_completion("test query", session_id="test_session") + + assert len(completion) == 1 + assert completion[0]["text"] == "S.R." + + +@pytest.mark.asyncio +async def test_get_completion_with_kwargs(mock_vector_engine): + """Test get_completion accepts additional kwargs.""" + mock_result = MagicMock() + mock_result.payload = {"text": "S.R."} + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query", extra_param="value") + + assert len(completion) == 1 diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py index c3c6a47f6..1d2f4c84d 100644 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -1,7 +1,12 @@ from types import SimpleNamespace import pytest +import os +from unittest.mock import AsyncMock, patch, MagicMock +from datetime import datetime from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp +from cognee.infrastructure.llm import LLMGateway # Test TemporalRetriever initialization defaults and overrides @@ -140,85 +145,561 @@ async def test_filter_top_k_events_error_handling(): await tr.filter_top_k_events([{}], []) -class _FakeRetriever(TemporalRetriever): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._calls = [] +@pytest.fixture +def mock_graph_engine(): + """Create a mock graph engine.""" + engine = AsyncMock() + engine.collect_time_ids = AsyncMock() + engine.collect_events = AsyncMock() + return engine - async def extract_time_from_query(self, query: str): - if "both" in query: + +@pytest.fixture +def 
mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.embedding_engine = AsyncMock() + engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + engine.search = AsyncMock() + return engine + + +@pytest.mark.asyncio +async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine): + """Test get_context when time range is extracted from query.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + {"id": "e2", "description": "Event 2"}, + ] + } + ] + + mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05) + mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10) + mock_vector_engine.search.return_value = [mock_result1, mock_result2] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened in 2024?") + + assert isinstance(context, str) + assert len(context) > 0 + assert "Event" in context + + +@pytest.mark.asyncio +async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine): + """Test get_context falls back to triplets when no time is extracted.""" + retriever = TemporalRetriever() + + with ( + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}] + ) as mock_get_triplets, + patch.object( + retriever, "resolve_edges_to_text", return_value="triplet text" + ) as mock_resolve, + ): + + async def 
mock_extract_time(query): + return None, None + + retriever.extract_time_from_query = mock_extract_time + + context = await retriever.get_context("test query") + + assert context == "triplet text" + mock_get_triplets.assert_awaited_once_with("test query") + mock_resolve.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_context_no_events_found(mock_graph_engine): + """Test get_context falls back to triplets when no events are found.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = [] + + with ( + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}] + ) as mock_get_triplets, + patch.object( + retriever, "resolve_edges_to_text", return_value="triplet text" + ) as mock_resolve, + ): + + async def mock_extract_time(query): return "2024-01-01", "2024-12-31" - if "from_only" in query: - return "2024-01-01", None - if "to_only" in query: - return None, "2024-12-31" - return None, None - async def get_triplets(self, query: str): - self._calls.append(("get_triplets", query)) - return [{"s": "a", "p": "b", "o": "c"}] + retriever.extract_time_from_query = mock_extract_time - async def resolve_edges_to_text(self, triplets): - self._calls.append(("resolve_edges_to_text", len(triplets))) - return "edges->text" + context = await retriever.get_context("test query") - async def _fake_graph_collect_ids(self, **kwargs): - return ["e1", "e2"] + assert context == "triplet text" + mock_get_triplets.assert_awaited_once_with("test query") + mock_resolve.assert_awaited_once() - async def _fake_graph_collect_events(self, ids): - return [ + +@pytest.mark.asyncio +async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine): + """Test get_context with only time_from.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = 
["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object(retriever, "extract_time_from_query", return_value=("2024-01-01", None)), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened after 2024?") + + assert isinstance(context, str) + assert "Event 1" in context + + +@pytest.mark.asyncio +async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine): + """Test get_context with only time_to.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object(retriever, "extract_time_from_query", return_value=(None, "2024-12-31")), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened before 2024?") + + assert isinstance(context, str) + assert "Event 1" in context + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + 
mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("What happened in 2024?") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(): + """Test get_completion uses provided context.""" + retriever = TemporalRetriever() + + with ( + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine): + """Test get_completion with session caching enabled.""" + retriever = TemporalRetriever() 
+ + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.save_conversation_history", + ) as mock_save, + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "What happened in 2024?", session_id="test_session" + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine): + """Test get_completion with session config but no user ID.""" + retriever = TemporalRetriever() + + 
mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("What happened in 2024?") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_context_retrieved_but_empty(mock_graph_engine): + """Test get_completion when get_context returns empty string.""" + retriever = TemporalRetriever() + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + ) as mock_get_vector, + patch.object(retriever, "filter_top_k_events", return_value=[]), + ): + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = 
AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + mock_get_vector.return_value = mock_vector_engine + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ { "events": [ - {"id": "e1", "description": "E1"}, - {"id": "e2", "description": "E2"}, - {"id": "e3", "description": "E3"}, + {"id": "e1", "description": ""}, ] } ] - async def _fake_vector_embed(self, texts): - assert isinstance(texts, list) and texts - return [[0.0, 1.0, 2.0]] + with pytest.raises((UnboundLocalError, NameError)): + await retriever.get_completion("test query") - async def _fake_vector_search(self, **kwargs): - return [ - SimpleNamespace(payload={"id": "e2"}, score=0.05), - SimpleNamespace(payload={"id": "e1"}, score=0.10), - ] - async def get_context(self, query: str): - time_from, time_to = await self.extract_time_from_query(query) +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel - if not (time_from or time_to): - triplets = await self.get_triplets(query) - return await self.resolve_edges_to_text(triplets) + class TestModel(BaseModel): + answer: str - ids = await self._fake_graph_collect_ids(time_from=time_from, time_to=time_to) - relevant_events = await self._fake_graph_collect_events(ids) + retriever = TemporalRetriever() - _ = await self._fake_vector_embed([query]) - vector_search_results = await self._fake_vector_search( - collection_name="Event_name", query_vector=[0.0], limit=0 + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, 
"extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "What happened in 2024?", response_model=TestModel ) - top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results) - return self.descriptions_to_string(top_k_events) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) -# Test get_context fallback to triplets when no time is extracted @pytest.mark.asyncio -async def test_fake_get_context_falls_back_to_triplets_when_no_time(): - tr = _FakeRetriever(top_k=2) - ctx = await tr.get_context("no_time") - assert ctx == "edges->text" - assert tr._calls[0][0] == "get_triplets" - assert tr._calls[1][0] == "resolve_edges_to_text" +async def test_extract_time_from_query_relative_path(): + """Test extract_time_from_query with relative prompt path.""" + retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt") + + mock_timestamp_from = Timestamp(year=2024, month=1, day=1) + mock_timestamp_to = Timestamp(year=2024, month=12, day=31) + mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False), + patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch( + 
"cognee.modules.retrieval.temporal_retriever.render_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_interval, + ), + ): + mock_datetime.now.return_value.strftime.return_value = "11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?") + + assert time_from == mock_timestamp_from + assert time_to == mock_timestamp_to -# Test get_context when time is extracted and vector ranking is applied @pytest.mark.asyncio -async def test_fake_get_context_with_time_filters_and_vector_ranking(): - tr = _FakeRetriever(top_k=2) - ctx = await tr.get_context("both time") - assert ctx.startswith("E2") - assert "#####################" in ctx - assert "E1" in ctx and "E3" not in ctx +async def test_extract_time_from_query_absolute_path(): + """Test extract_time_from_query with absolute prompt path.""" + retriever = TemporalRetriever( + time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt" + ) + + mock_timestamp_from = Timestamp(year=2024, month=1, day=1) + mock_timestamp_to = Timestamp(year=2024, month=12, day=31) + mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True), + patch( + "cognee.modules.retrieval.temporal_retriever.os.path.dirname", + return_value="/absolute/path/to", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.os.path.basename", + return_value="extract_query_time.txt", + ), + patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch( + "cognee.modules.retrieval.temporal_retriever.render_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_interval, + ), + ): + mock_datetime.now.return_value.strftime.return_value = "11-12-2024" + + 
time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?") + + assert time_from == mock_timestamp_from + assert time_to == mock_timestamp_to + + +@pytest.mark.asyncio +async def test_extract_time_from_query_with_none_values(): + """Test extract_time_from_query when interval has None values.""" + retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt") + + mock_interval = QueryInterval(starts_at=None, ends_at=None) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False), + patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch( + "cognee.modules.retrieval.temporal_retriever.render_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_interval, + ), + ): + mock_datetime.now.return_value.strftime.return_value = "11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened?") + + assert time_from is None + assert time_to is None diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py index 3dc9f38d9..b7cbe08d7 100644 --- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -1,12 +1,14 @@ import pytest -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, patch, MagicMock from cognee.modules.retrieval.utils.brute_force_triplet_search import ( brute_force_triplet_search, get_memory_fragment, + format_triplets, ) from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError +from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError class MockScoredResult: @@ -354,20 
+356,30 @@ async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation @pytest.mark.asyncio async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found(): - """Test that get_memory_fragment returns empty graph when entity not found.""" + """Test that get_memory_fragment returns empty graph when entity not found (line 85).""" mock_graph_engine = AsyncMock() - mock_graph_engine.project_graph_from_db = AsyncMock( + + # Create a mock fragment that will raise EntityNotFoundError when project_graph_from_db is called + mock_fragment = MagicMock(spec=CogneeGraph) + mock_fragment.project_graph_from_db = AsyncMock( side_effect=EntityNotFoundError("Entity not found") ) - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", - return_value=mock_graph_engine, + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph", + return_value=mock_fragment, + ), ): - fragment = await get_memory_fragment() + result = await get_memory_fragment() - assert isinstance(fragment, CogneeGraph) - assert len(fragment.nodes) == 0 + # Fragment should be returned even though EntityNotFoundError was raised (pass statement on line 85) + assert result == mock_fragment + mock_fragment.project_graph_from_db.assert_awaited_once() @pytest.mark.asyncio @@ -606,3 +618,200 @@ async def test_brute_force_triplet_search_mixed_empty_collections(): call_kwargs = mock_get_fragment_fn.call_args[1] assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +def test_format_triplets(): + """Test format_triplets function.""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = MagicMock() + + mock_node1.attributes = {"name": "Node1", "type": "Entity", "id": "n1"} + mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": "n2"} + mock_edge.attributes 
= {"relationship_name": "relates_to", "edge_text": "connects"} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert isinstance(result, str) + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + assert "connects" in result + + +def test_format_triplets_with_none_values(): + """Test format_triplets filters out None values.""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = MagicMock() + + mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"} + mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None} + mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + assert "None" not in result or result.count("None") == 0 + + +def test_format_triplets_with_nested_dict(): + """Test format_triplets handles nested dict attributes (lines 23-35).""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = MagicMock() + + mock_node1.attributes = {"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}} + mock_node2.attributes = {"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}} + mock_edge.attributes = {"relationship_name": "relates_to"} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert isinstance(result, str) + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_vector_engine_init_error(): + """Test brute_force_triplet_search handles vector engine initialization error (lines 145-147).""" + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine" + ) as mock_get_vector_engine, + ): + 
mock_get_vector_engine.side_effect = Exception("Initialization error") + + with pytest.raises(RuntimeError, match="Initialization error"): + await brute_force_triplet_search(query="test query") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_collection_not_found_error(): + """Test brute_force_triplet_search handles CollectionNotFoundError in search (lines 156-157).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_vector_engine.search = AsyncMock( + side_effect=[ + CollectionNotFoundError("Collection not found"), + [], + [], + ] + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=CogneeGraph(), + ), + ): + result = await brute_force_triplet_search( + query="test query", collections=["missing_collection", "existing_collection"] + ) + + assert result == [] + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_generic_exception(): + """Test brute_force_triplet_search handles generic exceptions (lines 209-217).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_vector_engine.search = AsyncMock(side_effect=Exception("Generic error")) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + with pytest.raises(Exception, match="Generic error"): + await brute_force_triplet_search(query="test query") + + +@pytest.mark.asyncio +async def 
test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none(): + """Test brute_force_triplet_search sets relevant_ids_to_filter to None when node_name is provided (line 191).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"}) + mock_vector_engine.search = AsyncMock(return_value=[mock_result]) + + mock_fragment = AsyncMock() + mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock() + mock_fragment.map_vector_distances_to_graph_edges = AsyncMock() + mock_fragment.calculate_top_triplet_importances = AsyncMock(return_value=[]) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + await brute_force_triplet_search(query="test query", node_name=["Node1"]) + + assert mock_get_fragment.called + call_kwargs = mock_get_fragment.call_args.kwargs if mock_get_fragment.call_args else {} + assert call_kwargs.get("relevant_ids_to_filter") is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_collection_not_found_at_top_level(): + """Test brute_force_triplet_search handles CollectionNotFoundError at top level (line 210).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"}) + mock_vector_engine.search = AsyncMock(return_value=[mock_result]) + + mock_fragment = AsyncMock() + 
mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock() + mock_fragment.map_vector_distances_to_graph_edges = AsyncMock() + mock_fragment.calculate_top_triplet_importances = AsyncMock( + side_effect=CollectionNotFoundError("Collection not found") + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ), + ): + result = await brute_force_triplet_search(query="test query") + + assert result == [] diff --git a/cognee/tests/unit/modules/retrieval/test_completion.py b/cognee/tests/unit/modules/retrieval/test_completion.py new file mode 100644 index 000000000..9a836c2cc --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_completion.py @@ -0,0 +1,343 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from typing import Type + + +class TestGenerateCompletion: + @pytest.mark.asyncio + async def test_generate_completion_with_system_prompt(self): + """Test generate_completion with provided system_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + system_prompt="Custom system prompt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt="Custom system prompt", + response_model=str, + ) + + 
@pytest.mark.asyncio + async def test_generate_completion_without_system_prompt(self): + """Test generate_completion reads system_prompt from file when not provided.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt="System prompt from file", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_with_conversation_history(self): + """Test generate_completion includes conversation_history in system_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + conversation_history="Previous 
conversation:\nQ: What is ML?\nA: ML is machine learning", + ) + + assert result == mock_llm_response + expected_system_prompt = ( + "Previous conversation:\nQ: What is ML?\nA: ML is machine learning" + + "\nTASK:" + + "System prompt from file" + ) + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt=expected_system_prompt, + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_with_conversation_history_and_custom_system_prompt(self): + """Test generate_completion includes conversation_history with custom system_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + system_prompt="Custom system prompt", + conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning", + ) + + assert result == mock_llm_response + expected_system_prompt = ( + "Previous conversation:\nQ: What is ML?\nA: ML is machine learning" + + "\nTASK:" + + "Custom system prompt" + ) + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt=expected_system_prompt, + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_with_response_model(self): + """Test generate_completion with custom response_model.""" + mock_response_model = MagicMock() + mock_llm_response = {"answer": "Generated answer"} + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt 
text", + ), + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + response_model=mock_response_model, + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt="System prompt from file", + response_model=mock_response_model, + ) + + @pytest.mark.asyncio + async def test_generate_completion_render_prompt_args(self): + """Test generate_completion passes correct args to render_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ) as mock_render, + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ), + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + ) + + mock_render.assert_called_once_with( + "user_prompt.txt", + {"question": "What is AI?", "context": "AI is artificial intelligence"}, + ) + + +class TestSummarizeText: + @pytest.mark.asyncio + async def test_summarize_text_with_system_prompt(self): + """Test summarize_text with provided system_prompt.""" + 
mock_llm_response = "Summary text" + + with patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm: + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text( + text="Long text to summarize", + system_prompt_path="summarize_search_results.txt", + system_prompt="Custom summary prompt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="Custom summary prompt", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_summarize_text_without_system_prompt(self): + """Test summarize_text reads system_prompt from file when not provided.""" + mock_llm_response = "Summary text" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text( + text="Long text to summarize", + system_prompt_path="summarize_search_results.txt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="System prompt from file", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_summarize_text_default_prompt_path(self): + """Test summarize_text uses default prompt path when not provided.""" + mock_llm_response = "Summary text" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="Default system prompt", + ) as mock_read, + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + 
return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text(text="Long text to summarize") + + assert result == mock_llm_response + mock_read.assert_called_once_with("summarize_search_results.txt") + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="Default system prompt", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_summarize_text_custom_prompt_path(self): + """Test summarize_text uses custom prompt path when provided.""" + mock_llm_response = "Summary text" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="Custom system prompt", + ) as mock_read, + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text( + text="Long text to summarize", + system_prompt_path="custom_prompt.txt", + ) + + assert result == mock_llm_response + mock_read.assert_called_once_with("custom_prompt.txt") + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="Custom system prompt", + response_model=str, + ) diff --git a/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py b/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py new file mode 100644 index 000000000..2af10da5e --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py @@ -0,0 +1,157 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.graph_summary_completion_retriever import ( + GraphSummaryCompletionRetriever, +) +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +@pytest.fixture +def 
mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge + + +class TestGraphSummaryCompletionRetriever: + @pytest.mark.asyncio + async def test_init_defaults(self): + """Test GraphSummaryCompletionRetriever initialization with defaults.""" + retriever = GraphSummaryCompletionRetriever() + + assert retriever.summarize_prompt_path == "summarize_search_results.txt" + assert retriever.user_prompt_path == "graph_context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.top_k == 5 + assert retriever.save_interaction is False + + @pytest.mark.asyncio + async def test_init_custom_params(self): + """Test GraphSummaryCompletionRetriever initialization with custom parameters.""" + retriever = GraphSummaryCompletionRetriever( + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + summarize_prompt_path="custom_summarize.txt", + system_prompt="Custom system prompt", + top_k=10, + save_interaction=True, + wide_search_top_k=200, + triplet_distance_penalty=2.5, + ) + + assert retriever.summarize_prompt_path == "custom_summarize.txt" + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.top_k == 10 + assert retriever.save_interaction is True + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_calls_super_and_summarizes(self, mock_edge): + """Test resolve_edges_to_text calls super method and then summarizes.""" + retriever = GraphSummaryCompletionRetriever( + summarize_prompt_path="custom_summarize.txt", + system_prompt="Custom system prompt", + ) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="Resolved edges text", + ) as mock_super_resolve, + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + 
return_value="Summarized text", + ) as mock_summarize, + ): + result = await retriever.resolve_edges_to_text([mock_edge]) + + assert result == "Summarized text" + mock_super_resolve.assert_awaited_once_with([mock_edge]) + mock_summarize.assert_awaited_once_with( + "Resolved edges text", + "custom_summarize.txt", + "Custom system prompt", + ) + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_with_default_system_prompt(self, mock_edge): + """Test resolve_edges_to_text uses None for system_prompt when not provided.""" + retriever = GraphSummaryCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="Resolved edges text", + ), + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Summarized text", + ) as mock_summarize, + ): + await retriever.resolve_edges_to_text([mock_edge]) + + mock_summarize.assert_awaited_once_with( + "Resolved edges text", + "summarize_search_results.txt", + None, + ) + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_with_empty_edges(self): + """Test resolve_edges_to_text handles empty edges list.""" + retriever = GraphSummaryCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="", + ), + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Empty summary", + ) as mock_summarize, + ): + result = await retriever.resolve_edges_to_text([]) + + assert result == "Empty summary" + mock_summarize.assert_awaited_once_with( + "", + "summarize_search_results.txt", + None, + ) + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_with_multiple_edges(self, mock_edge): + """Test resolve_edges_to_text handles 
multiple edges.""" + retriever = GraphSummaryCompletionRetriever() + + mock_edge2 = MagicMock(spec=Edge) + mock_edge3 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="Multiple edges resolved text", + ), + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Multiple edges summarized", + ) as mock_summarize, + ): + result = await retriever.resolve_edges_to_text([mock_edge, mock_edge2, mock_edge3]) + + assert result == "Multiple edges summarized" + mock_summarize.assert_awaited_once_with( + "Multiple edges resolved text", + "summarize_search_results.txt", + None, + ) diff --git a/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py b/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py new file mode 100644 index 000000000..a1e746bb9 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py @@ -0,0 +1,312 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID, NAMESPACE_OID, uuid5 + +from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback +from cognee.modules.retrieval.utils.models import UserFeedbackEvaluation, UserFeedbackSentiment +from cognee.modules.engine.models import NodeSet + + +@pytest.fixture +def mock_feedback_evaluation(): + """Create a mock feedback evaluation.""" + evaluation = MagicMock(spec=UserFeedbackEvaluation) + evaluation.evaluation = MagicMock() + evaluation.evaluation.value = "positive" + evaluation.score = 4.5 + return evaluation + + +@pytest.fixture +def mock_graph_engine(): + """Create a mock graph engine.""" + engine = AsyncMock() + engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + engine.add_edges = AsyncMock() + engine.apply_feedback_weight = AsyncMock() + return engine + + +class TestUserQAFeedback: + 
@pytest.mark.asyncio + async def test_init_default(self): + """Test UserQAFeedback initialization with default last_k.""" + retriever = UserQAFeedback() + assert retriever.last_k == 1 + + @pytest.mark.asyncio + async def test_init_custom_last_k(self): + """Test UserQAFeedback initialization with custom last_k.""" + retriever = UserQAFeedback(last_k=5) + assert retriever.last_k == 5 + + @pytest.mark.asyncio + async def test_add_feedback_success_with_relationships( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback successfully creates feedback with relationships.""" + interaction_id_1 = str(UUID("550e8400-e29b-41d4-a716-446655440000")) + interaction_id_2 = str(UUID("550e8400-e29b-41d4-a716-446655440001")) + mock_graph_engine.get_last_user_interaction_ids = AsyncMock( + return_value=[interaction_id_1, interaction_id_2] + ) + + feedback_text = "This answer was helpful" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ) as mock_add_data, + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ) as mock_index_edges, + ): + retriever = UserQAFeedback(last_k=2) + result = await retriever.add_feedback(feedback_text) + + assert result == [feedback_text] + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_awaited_once() + mock_index_edges.assert_awaited_once() + mock_graph_engine.apply_feedback_weight.assert_awaited_once() + + # Verify add_edges was called with correct relationships + call_args = mock_graph_engine.add_edges.call_args[0][0] + assert len(call_args) == 2 + assert call_args[0][0] == uuid5(NAMESPACE_OID, name=feedback_text) + assert 
call_args[0][1] == UUID(interaction_id_1) + assert call_args[0][2] == "gives_feedback_to" + assert call_args[0][3]["relationship_name"] == "gives_feedback_to" + assert call_args[0][3]["ontology_valid"] is False + + # Verify apply_feedback_weight was called with correct node IDs + weight_call_args = mock_graph_engine.apply_feedback_weight.call_args[1]["node_ids"] + assert len(weight_call_args) == 2 + assert interaction_id_1 in weight_call_args + assert interaction_id_2 in weight_call_args + + @pytest.mark.asyncio + async def test_add_feedback_success_no_relationships( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback successfully creates feedback without relationships.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "This answer was helpful" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ) as mock_add_data, + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ) as mock_index_edges, + ): + retriever = UserQAFeedback(last_k=1) + result = await retriever.add_feedback(feedback_text) + + assert result == [feedback_text] + mock_add_data.assert_awaited_once() + # Should not call add_edges or index_graph_edges when no relationships + mock_graph_engine.add_edges.assert_not_awaited() + mock_index_edges.assert_not_awaited() + mock_graph_engine.apply_feedback_weight.assert_not_awaited() + + @pytest.mark.asyncio + async def test_add_feedback_creates_correct_feedback_node( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback creates CogneeUserFeedback with correct attributes.""" + 
mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "This was a negative experience" + mock_feedback_evaluation.evaluation.value = "negative" + mock_feedback_evaluation.score = -3.0 + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ) as mock_add_data, + ): + retriever = UserQAFeedback() + await retriever.add_feedback(feedback_text) + + # Verify add_data_points was called with correct CogneeUserFeedback + call_args = mock_add_data.call_args[1]["data_points"] + assert len(call_args) == 1 + feedback_node = call_args[0] + assert feedback_node.id == uuid5(NAMESPACE_OID, name=feedback_text) + assert feedback_node.feedback == feedback_text + assert feedback_node.sentiment == "negative" + assert feedback_node.score == -3.0 + assert isinstance(feedback_node.belongs_to_set, NodeSet) + assert feedback_node.belongs_to_set.name == "UserQAFeedbacks" + + @pytest.mark.asyncio + async def test_add_feedback_calls_llm_with_correct_prompt( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback calls LLM with correct sentiment analysis prompt.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "Great answer!" 
+ + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ) as mock_llm, + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback() + await retriever.add_feedback(feedback_text) + + mock_llm.assert_awaited_once() + call_kwargs = mock_llm.call_args[1] + assert call_kwargs["text_input"] == feedback_text + assert "sentiment analysis assistant" in call_kwargs["system_prompt"] + assert call_kwargs["response_model"] == UserFeedbackEvaluation + + @pytest.mark.asyncio + async def test_add_feedback_uses_last_k_parameter( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback uses last_k parameter when getting interaction IDs.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "Test feedback" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback(last_k=5) + await retriever.add_feedback(feedback_text) + + mock_graph_engine.get_last_user_interaction_ids.assert_awaited_once_with(limit=5) + + @pytest.mark.asyncio + async def test_add_feedback_with_single_interaction( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback with single interaction ID.""" + interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000")) + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id]) + + 
feedback_text = "Test feedback" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback() + result = await retriever.add_feedback(feedback_text) + + assert result == [feedback_text] + # Should create relationship for the interaction + call_args = mock_graph_engine.add_edges.call_args[0][0] + assert len(call_args) == 1 + assert call_args[0][1] == UUID(interaction_id) + + @pytest.mark.asyncio + async def test_add_feedback_applies_weight_correctly( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback applies feedback weight with correct score.""" + interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000")) + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id]) + mock_feedback_evaluation.score = 4.5 + + feedback_text = "Positive feedback" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback() + await retriever.add_feedback(feedback_text) + + mock_graph_engine.apply_feedback_weight.assert_awaited_once_with( + node_ids=[interaction_id], weight=4.5 + ) diff --git 
a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py index d79aca428..83612c7aa 100644 --- a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py @@ -81,3 +81,249 @@ async def test_get_context_collection_not_found_error(mock_vector_engine): ): with pytest.raises(NoDataError, match="No data found"): await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_empty_payload_text(mock_vector_engine): + """Test get_context handles missing text in payload.""" + mock_result = MagicMock() + mock_result.payload = {} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(KeyError): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_single_triplet(mock_vector_engine): + """Test get_context with single triplet result.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Single triplet"} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "Single triplet" + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test TripletRetriever initialization with defaults.""" + retriever = TripletRetriever() + + assert retriever.user_prompt_path == "context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.top_k == 5 # Default is 5 + assert retriever.system_prompt is None + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test TripletRetriever 
initialization with custom parameters.""" + retriever = TripletRetriever( + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + top_k=10, + ) + + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_vector_engine): + """Test get_completion uses provided context.""" + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided 
context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_vector_engine): + """Test get_completion with session caching enabled.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.triplet_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.triplet_retriever.save_conversation_history", + ) as mock_save, + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_vector_engine): + """Test get_completion with session config but no user ID.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + 
mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + 
+@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test TripletRetriever initialization with None top_k.""" + retriever = TripletRetriever(top_k=None) + + assert retriever.top_k == 5 From 7cf6f082835cd45962aacdeff1bfffcbc6cdebb6 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Fri, 12 Dec 2025 15:29:21 +0100 Subject: [PATCH 055/176] chore: update test credentials --- .github/workflows/examples_tests.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/examples_tests.yml b/.github/workflows/examples_tests.yml index f8f3e5aa3..95a14c9ac 100644 --- a/.github/workflows/examples_tests.yml +++ b/.github/workflows/examples_tests.yml @@ -315,10 +315,12 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} STORAGE_BACKEND: 's3' AWS_REGION: eu-west-1 + AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_DEV_USER_SECRET_KEY }} + STORAGE_BUCKET_NAME: github-runner-cognee-tests DATA_ROOT_DIRECTORY: "s3://github-runner-cognee-tests/cognee/data" SYSTEM_ROOT_DIRECTORY: "s3://github-runner-cognee-tests/cognee/system" - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} DB_PROVIDER: 'postgres' DB_NAME: 'cognee_db' DB_HOST: '127.0.0.1' From fa035f42f40715b616e6c926b00515dbb35c80da Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Fri, 12 Dec 2025 16:47:58 +0100 Subject: [PATCH 056/176] chore: adds back accidentally deleted structured output test --- .../retrieval/structured_output_test.py | 204 ++++++++++++++++++ 1 file changed, 204 insertions(+) create mode 100644 cognee/tests/unit/modules/retrieval/structured_output_test.py diff --git a/cognee/tests/unit/modules/retrieval/structured_output_test.py b/cognee/tests/unit/modules/retrieval/structured_output_test.py new file mode 100644 index 000000000..4ad3019ff --- 
/dev/null +++ b/cognee/tests/unit/modules/retrieval/structured_output_test.py @@ -0,0 +1,204 @@ +import asyncio + +import pytest +import cognee +import pathlib +import os + +from pydantic import BaseModel +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.engine.models import Entity, EntityType +from cognee.modules.retrieval.entity_extractors.DummyEntityExtractor import DummyEntityExtractor +from cognee.modules.retrieval.context_providers.DummyContextProvider import DummyContextProvider +from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) +from cognee.modules.retrieval.EntityCompletionRetriever import EntityCompletionRetriever +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.modules.retrieval.completion_retriever import CompletionRetriever + + +class TestAnswer(BaseModel): + answer: str + explanation: str + + +def _assert_string_answer(answer: list[str]): + assert isinstance(answer, list), f"Expected str, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), "Items should be strings" + assert all(item.strip() for item in answer), "Items should not be empty" + + +def _assert_structured_answer(answer: list[TestAnswer]): + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(x, TestAnswer) for x in answer), "Items should be TestAnswer" + assert all(x.answer.strip() for x in answer), "Answer text should not be empty" + assert all(x.explanation.strip() for x in answer), 
"Explanation should not be empty" + + +async def _test_get_structured_graph_completion_cot(): + retriever = GraphCompletionCotRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion(): + retriever = GraphCompletionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_temporal(): + retriever = TemporalRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("When did Steve start working at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "When did Steve start working at Figma??", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_rag(): + retriever = CompletionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Where does Steve work?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Where does Steve work?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_graph_completion_context_extension(): + retriever = 
GraphCompletionContextExtensionRetriever() + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who works at Figma?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who works at Figma?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +async def _test_get_structured_entity_completion(): + retriever = EntityCompletionRetriever(DummyEntityExtractor(), DummyContextProvider()) + + # Test with string response model (default) + string_answer = await retriever.get_completion("Who is Albert Einstein?") + _assert_string_answer(string_answer) + + # Test with structured response model + structured_answer = await retriever.get_completion( + "Who is Albert Einstein?", response_model=TestAnswer + ) + _assert_structured_answer(structured_answer) + + +class TestStructuredOutputCompletion: + @pytest.mark.asyncio + async def test_get_structured_completion(self): + system_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion" + ) + cognee.config.system_root_directory(system_directory_path) + data_directory_path = os.path.join( + pathlib.Path(__file__).parent, ".data_storage/test_get_structured_completion" + ) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + works_since: int + + company1 = Company(name="Figma") + person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) + + entities = [company1, person1] + await add_data_points(entities) + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + 
chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + await add_data_points(entities) + + entity_type = EntityType(name="Person", description="A human individual") + entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") + + entities = [entity] + await add_data_points(entities) + + await _test_get_structured_graph_completion_cot() + await _test_get_structured_graph_completion() + await _test_get_structured_graph_completion_temporal() + await _test_get_structured_graph_completion_rag() + await _test_get_structured_graph_completion_context_extension() + await _test_get_structured_entity_completion() From 4e8845c117ecf892c3f5554c94de4f9f1171b9ff Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 11:11:29 +0100 Subject: [PATCH 057/176] chore: retriever test reorganization + adding new tests (integration) (STEP 1) (#1881) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description This PR restructures/adds integration and unit tests for the retrieval module. 
-Old unit tests were updated and moved under integration tests + fixtures added -Added missing unit tests for all core retrieval business logic -Covered 100% of the core retrievers with tests -Minor changes (dead code deletion, typo fixed) ## Type of Change - [ ] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit * **Changes** * TripletRetriever now returns up to 5 results by default (was 1), providing richer context. * **Tests** * Reorganized test coverage: many unit tests removed and replaced with comprehensive integration tests across retrieval components (graph, chunks, RAG, summaries, temporal, triplets, structured output). * **Chores** * Simplified triplet formatting logic and removed debug output. ✏️ Tip: You can customize this high-level summary in your review settings. 
--- cognee/modules/retrieval/triplet_retriever.py | 2 +- .../utils/brute_force_triplet_search.py | 18 - .../retrieval/test_chunks_retriever.py | 252 ++++++++ .../test_graph_completion_retriever.py | 268 ++++++++ ..._completion_retriever_context_extension.py | 226 +++++++ .../test_graph_completion_retriever_cot.py | 218 +++++++ .../test_rag_completion_retriever.py | 254 ++++++++ .../retrieval/test_structured_output.py} | 162 ++--- .../retrieval/test_summaries_retriever.py | 184 ++++++ .../retrieval/test_temporal_retriever.py | 306 +++++++++ .../retrieval/test_triplet_retriever.py | 35 + .../eval_framework/benchmark_adapters_test.py | 25 + .../eval_framework/corpus_builder_test.py | 37 +- .../retrieval/chunks_retriever_test.py | 201 ------ .../retrieval/conversation_history_test.py | 154 ----- ...letion_retriever_context_extension_test.py | 177 ----- .../graph_completion_retriever_cot_test.py | 170 ----- .../graph_completion_retriever_test.py | 223 ------- .../rag_completion_retriever_test.py | 205 ------ .../retrieval/summaries_retriever_test.py | 159 ----- .../retrieval/temporal_retriever_test.py | 224 ------- .../test_brute_force_triplet_search.py | 608 ------------------ .../retrieval/triplet_retriever_test.py | 83 --- 23 files changed, 1888 insertions(+), 2303 deletions(-) create mode 100644 cognee/tests/integration/retrieval/test_chunks_retriever.py create mode 100644 cognee/tests/integration/retrieval/test_graph_completion_retriever.py create mode 100644 cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py create mode 100644 cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py create mode 100644 cognee/tests/integration/retrieval/test_rag_completion_retriever.py rename cognee/tests/{unit/modules/retrieval/structured_output_test.py => integration/retrieval/test_structured_output.py} (65%) create mode 100644 cognee/tests/integration/retrieval/test_summaries_retriever.py create mode 100644 
cognee/tests/integration/retrieval/test_temporal_retriever.py delete mode 100644 cognee/tests/unit/modules/retrieval/chunks_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/conversation_history_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/summaries_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/temporal_retriever_test.py delete mode 100644 cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py delete mode 100644 cognee/tests/unit/modules/retrieval/triplet_retriever_test.py diff --git a/cognee/modules/retrieval/triplet_retriever.py b/cognee/modules/retrieval/triplet_retriever.py index d251d113a..b9d006312 100644 --- a/cognee/modules/retrieval/triplet_retriever.py +++ b/cognee/modules/retrieval/triplet_retriever.py @@ -36,7 +36,7 @@ class TripletRetriever(BaseRetriever): """Initialize retriever with optional custom prompt paths.""" self.user_prompt_path = user_prompt_path self.system_prompt_path = system_prompt_path - self.top_k = top_k if top_k is not None else 1 + self.top_k = top_k if top_k is not None else 5 self.system_prompt = system_prompt async def get_context(self, query: str) -> str: diff --git a/cognee/modules/retrieval/utils/brute_force_triplet_search.py b/cognee/modules/retrieval/utils/brute_force_triplet_search.py index bd412e0ca..a70fa661b 100644 --- a/cognee/modules/retrieval/utils/brute_force_triplet_search.py +++ b/cognee/modules/retrieval/utils/brute_force_triplet_search.py @@ -16,24 +16,6 @@ logger = get_logger(level=ERROR) def format_triplets(edges): - print("\n\n\n") - - def 
filter_attributes(obj, attributes): - """Helper function to filter out non-None properties, including nested dicts.""" - result = {} - for attr in attributes: - value = getattr(obj, attr, None) - if value is not None: - # If the value is a dict, extract relevant keys from it - if isinstance(value, dict): - nested_values = { - k: v for k, v in value.items() if k in attributes and v is not None - } - result[attr] = nested_values - else: - result[attr] = value - return result - triplets = [] for edge in edges: node1 = edge.node1 diff --git a/cognee/tests/integration/retrieval/test_chunks_retriever.py b/cognee/tests/integration/retrieval/test_chunks_retriever.py new file mode 100644 index 000000000..d2e5e6149 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_chunks_retriever.py @@ -0,0 +1,252 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import List +import cognee + +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.modules.retrieval.chunks_retriever import ChunksRetriever +from cognee.infrastructure.engine import DataPoint +from cognee.modules.data.processing.document_types import Document +from cognee.modules.engine.models import Entity + + +class DocumentChunkWithEntities(DataPoint): + text: str + chunk_size: int + chunk_index: int + cut_type: str + is_part_of: Document + contains: List[Entity] = None + + metadata: dict = {"index_fields": ["text"]} + + +@pytest_asyncio.fixture +async def setup_test_environment_with_chunks_simple(): + """Set up a clean test environment with simple chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / 
".cognee_system/test_chunks_retriever_context_simple") + data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_simple") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_with_chunks_complex(): + """Set up a clean test environment with complex chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_complex") + data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_complex") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document1 = TextDocument( + name="Employee List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + document2 = TextDocument( + name="Car List", + 
raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + + chunk4 = DocumentChunk( + text="Range Rover", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk5 = DocumentChunk( + text="Hyundai", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk6 = DocumentChunk( + text="Chrysler", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_chunks_retriever_context_empty") + data_directory_path = str(base_dir / ".data_storage/test_chunks_retriever_context_empty") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def 
test_chunks_retriever_context_multiple_chunks(setup_test_environment_with_chunks_simple): + """Integration test: verify ChunksRetriever can retrieve multiple chunks.""" + retriever = ChunksRetriever() + + context = await retriever.get_context("Steve") + + assert isinstance(context, list), "Context should be a list" + assert len(context) > 0, "Context should not be empty" + assert any(chunk["text"] == "Steve Rodger" for chunk in context), ( + "Failed to get Steve Rodger chunk" + ) + + +@pytest.mark.asyncio +async def test_chunks_retriever_top_k_limit(setup_test_environment_with_chunks_complex): + """Integration test: verify ChunksRetriever respects top_k parameter.""" + retriever = ChunksRetriever(top_k=2) + + context = await retriever.get_context("Employee") + + assert isinstance(context, list), "Context should be a list" + assert len(context) <= 2, "Should respect top_k limit" + + +@pytest.mark.asyncio +async def test_chunks_retriever_context_complex(setup_test_environment_with_chunks_complex): + """Integration test: verify ChunksRetriever can retrieve chunk context (complex).""" + retriever = ChunksRetriever(top_k=20) + + context = await retriever.get_context("Christina") + + assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer" + + +@pytest.mark.asyncio +async def test_chunks_retriever_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify ChunksRetriever handles empty graph correctly.""" + retriever = ChunksRetriever() + + with pytest.raises(NoDataError): + await retriever.get_context("Christina Mayer") + + vector_engine = get_vector_engine() + await vector_engine.create_collection( + "DocumentChunk_text", payload_schema=DocumentChunkWithEntities + ) + + context = await retriever.get_context("Christina Mayer") + assert len(context) == 0, "Found chunks when none should exist" diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever.py 
b/cognee/tests/integration/retrieval/test_graph_completion_retriever.py new file mode 100644 index 000000000..7367b353b --- /dev/null +++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever.py @@ -0,0 +1,268 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import Optional, Union +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.modules.graph.utils import resolve_edges_to_text +from cognee.tasks.storage import add_data_points +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever + + +@pytest_asyncio.fixture +async def setup_test_environment_simple(): + """Set up a clean test environment with simple graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_simple") + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_simple") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + description: str + + class Person(DataPoint): + name: str + description: str + works_for: Company + + company1 = Company(name="Figma", description="Figma is a company") + company2 = Company(name="Canva", description="Canvas is a company") + person1 = Person( + name="Steve Rodger", + description="This is description about Steve Rodger", + works_for=company1, + ) + person2 = Person( + name="Ike Loma", description="This is description about Ike Loma", works_for=company1 + ) + person3 = Person( + name="Jason Statham", + description="This is description about Jason Statham", + works_for=company1, + ) + person4 = Person( + name="Mike Broski", + description="This is description about Mike Broski", + works_for=company2, + ) + person5 = Person( + 
name="Christina Mayer", + description="This is description about Christina Mayer", + works_for=company2, + ) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_complex(): + """Set up a clean test environment with complex graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_graph_completion_context_complex") + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_context_complex") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina 
Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_get_graph_completion_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / ".data_storage/test_get_graph_completion_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_graph_completion_context_simple(setup_test_environment_simple): + """Integration test: verify GraphCompletionRetriever can retrieve context (simple).""" + retriever = GraphCompletionRetriever() + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) + + # Ensure the top-level sections are present + assert "Nodes:" in context, "Missing 'Nodes:' section in context" + assert "Connections:" in context, "Missing 'Connections:' section in context" + + # --- Nodes headers --- + assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger" + assert "Node: Figma" in context, "Missing node header for Figma" + assert "Node: Ike Loma" in context, "Missing node header for Ike Loma" + assert "Node: Jason Statham" in context, "Missing node header for Jason Statham" + assert "Node: Mike Broski" in 
context, "Missing node header for Mike Broski" + assert "Node: Canva" in context, "Missing node header for Canva" + assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer" + + # --- Node contents --- + assert ( + "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__" + in context + ), "Description block for Steve Rodger altered" + assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, ( + "Description block for Figma altered" + ) + assert ( + "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__" + in context + ), "Description block for Ike Loma altered" + assert ( + "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__" + in context + ), "Description block for Jason Statham altered" + assert ( + "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__" + in context + ), "Description block for Mike Broski altered" + assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, ( + "Description block for Canva altered" + ) + assert ( + "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__" + in context + ), "Description block for Christina Mayer altered" + + # --- Connections --- + assert "Steve Rodger --[works_for]--> Figma" in context, ( + "Connection Steve Rodger→Figma missing or changed" + ) + assert "Ike Loma --[works_for]--> Figma" in context, ( + "Connection Ike Loma→Figma missing or changed" + ) + assert "Jason Statham --[works_for]--> Figma" in context, ( + "Connection Jason Statham→Figma missing or changed" + ) + assert "Mike Broski --[works_for]--> Canva" in context, ( + "Connection Mike Broski→Canva missing or changed" + ) + assert "Christina Mayer --[works_for]--> Canva" in context, ( + "Connection Christina Mayer→Canva missing or changed" + ) + + +@pytest.mark.asyncio +async def 
test_graph_completion_context_complex(setup_test_environment_complex): + """Integration test: verify GraphCompletionRetriever can retrieve context (complex).""" + retriever = GraphCompletionRetriever(top_k=20) + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + +@pytest.mark.asyncio +async def test_get_graph_completion_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify GraphCompletionRetriever handles empty graph correctly.""" + retriever = GraphCompletionRetriever() + + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" + + +@pytest.mark.asyncio +async def test_graph_completion_get_triplets_empty(setup_test_environment_empty): + """Integration test: verify GraphCompletionRetriever get_triplets handles empty graph.""" + retriever = GraphCompletionRetriever() + + triplets = await retriever.get_triplets("Who works at Figma?") + + assert isinstance(triplets, list), "Triplets should be a list" + assert len(triplets) == 0, "Should return empty list on empty graph" diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py new file mode 100644 index 000000000..c87de16ef --- /dev/null +++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py @@ -0,0 +1,226 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import Optional, Union +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from 
cognee.modules.graph.utils import resolve_edges_to_text +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) + + +@pytest_asyncio.fixture +async def setup_test_environment_simple(): + """Set up a clean test environment with simple graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_extension_context_simple" + ) + data_directory_path = str( + base_dir / ".data_storage/test_graph_completion_extension_context_simple" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + person1 = Person(name="Steve Rodger", works_for=company1) + person2 = Person(name="Ike Loma", works_for=company1) + person3 = Person(name="Jason Statham", works_for=company1) + person4 = Person(name="Mike Broski", works_for=company2) + person5 = Person(name="Christina Mayer", works_for=company2) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_complex(): + """Set up a clean test environment with complex graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_extension_context_complex" + ) + data_directory_path = str( + base_dir / ".data_storage/test_graph_completion_extension_context_complex" + ) + + 
cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / 
".data_storage/test_get_graph_completion_extension_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_graph_completion_extension_context_simple(setup_test_environment_simple): + """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (simple).""" + retriever = GraphCompletionContextExtensionRetriever() + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) + + assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" + assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" + + answer = await retriever.get_completion("Who works at Canva?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_extension_context_complex(setup_test_environment_complex): + """Integration test: verify GraphCompletionContextExtensionRetriever can retrieve context (complex).""" + retriever = GraphCompletionContextExtensionRetriever(top_k=20) + + context = await resolve_edges_to_text( + await retriever.get_context("Who works at Figma and drives Tesla?") + ) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + answer = await retriever.get_completion("Who works at Figma?") + + 
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_get_graph_completion_extension_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify GraphCompletionContextExtensionRetriever handles empty graph correctly.""" + retriever = GraphCompletionContextExtensionRetriever() + + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_extension_get_triplets_empty(setup_test_environment_empty): + """Integration test: verify GraphCompletionContextExtensionRetriever get_triplets handles empty graph.""" + retriever = GraphCompletionContextExtensionRetriever() + + triplets = await retriever.get_triplets("Who works at Figma?") + + assert isinstance(triplets, list), "Triplets should be a list" + assert len(triplets) == 0, "Should return empty list on empty graph" diff --git a/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py new file mode 100644 index 000000000..0db035e03 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py @@ -0,0 +1,218 @@ +import os +import pytest +import pathlib +import pytest_asyncio +from typing import Optional, Union +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.modules.graph.utils import resolve_edges_to_text +from cognee.tasks.storage import add_data_points +from 
cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever + + +@pytest_asyncio.fixture +async def setup_test_environment_simple(): + """Set up a clean test environment with simple graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_cot_context_simple" + ) + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_simple") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + person1 = Person(name="Steve Rodger", works_for=company1) + person2 = Person(name="Ike Loma", works_for=company1) + person3 = Person(name="Jason Statham", works_for=company1) + person4 = Person(name="Mike Broski", works_for=company2) + person5 = Person(name="Christina Mayer", works_for=company2) + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_complex(): + """Set up a clean test environment with complex graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_graph_completion_cot_context_complex" + ) + data_directory_path = str(base_dir / ".data_storage/test_graph_completion_cot_context_complex") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await 
cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + metadata: dict = {"index_fields": ["name"]} + + class Car(DataPoint): + brand: str + model: str + year: int + + class Location(DataPoint): + country: str + city: str + + class Home(DataPoint): + location: Location + rooms: int + sqm: int + + class Person(DataPoint): + name: str + works_for: Company + owns: Optional[list[Union[Car, Home]]] = None + + company1 = Company(name="Figma") + company2 = Company(name="Canva") + + person1 = Person(name="Mike Rodger", works_for=company1) + person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] + + person2 = Person(name="Ike Loma", works_for=company1) + person2.owns = [ + Car(brand="Tesla", model="Model S", year=2021), + Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), + ] + + person3 = Person(name="Jason Statham", works_for=company1) + + person4 = Person(name="Mike Broski", works_for=company2) + person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] + + person5 = Person(name="Christina Mayer", works_for=company2) + person5.owns = [Car(brand="Honda", model="Civic", year=2023)] + + entities = [company1, company2, person1, person2, person3, person4, person5] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without graph data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str( + base_dir / ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / ".data_storage/test_get_graph_completion_cot_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + 
cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_graph_completion_cot_context_simple(setup_test_environment_simple): + """Integration test: verify GraphCompletionCotRetriever can retrieve context (simple).""" + retriever = GraphCompletionCotRetriever() + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) + + assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" + assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" + + answer = await retriever.get_completion("Who works at Canva?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_cot_context_complex(setup_test_environment_complex): + """Integration test: verify GraphCompletionCotRetriever can retrieve context (complex).""" + retriever = GraphCompletionCotRetriever(top_k=20) + + context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) + + assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" + assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" + assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio 
+async def test_get_graph_completion_cot_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify GraphCompletionCotRetriever handles empty graph correctly.""" + retriever = GraphCompletionCotRetriever() + + context = await retriever.get_context("Who works at Figma?") + assert context == [], "Context should be empty on an empty graph" + + answer = await retriever.get_completion("Who works at Figma?") + + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_graph_completion_cot_get_triplets_empty(setup_test_environment_empty): + """Integration test: verify GraphCompletionCotRetriever get_triplets handles empty graph.""" + retriever = GraphCompletionCotRetriever() + + triplets = await retriever.get_triplets("Who works at Figma?") + + assert isinstance(triplets, list), "Triplets should be a list" + assert len(triplets) == 0, "Should return empty list on empty graph" diff --git a/cognee/tests/integration/retrieval/test_rag_completion_retriever.py b/cognee/tests/integration/retrieval/test_rag_completion_retriever.py new file mode 100644 index 000000000..b01d58160 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_rag_completion_retriever.py @@ -0,0 +1,254 @@ +import os +from typing import List +import pytest +import pathlib +import pytest_asyncio +import cognee + +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.chunking.models import DocumentChunk +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.infrastructure.engine import DataPoint 
+from cognee.modules.data.processing.document_types import Document +from cognee.modules.engine.models import Entity + + +class DocumentChunkWithEntities(DataPoint): + text: str + chunk_size: int + chunk_index: int + cut_type: str + is_part_of: Document + contains: List[Entity] = None + + metadata: dict = {"index_fields": ["text"]} + + +@pytest_asyncio.fixture +async def setup_test_environment_with_chunks_simple(): + """Set up a clean test environment with simple chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_rag_completion_context_simple") + data_directory_path = str(base_dir / ".data_storage/test_rag_completion_context_simple") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_with_chunks_complex(): + """Set up a clean test environment with complex chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = 
str(base_dir / ".cognee_system/test_rag_completion_context_complex") + data_directory_path = str(base_dir / ".data_storage/test_rag_completion_context_complex") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document1 = TextDocument( + name="Employee List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + document2 = TextDocument( + name="Car List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + + chunk4 = DocumentChunk( + text="Range Rover", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk5 = DocumentChunk( + text="Hyundai", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk6 = DocumentChunk( + text="Chrysler", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without chunks.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + 
system_directory_path = str( + base_dir / ".cognee_system/test_get_rag_completion_context_on_empty_graph" + ) + data_directory_path = str( + base_dir / ".data_storage/test_get_rag_completion_context_on_empty_graph" + ) + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_rag_completion_context_simple(setup_test_environment_with_chunks_simple): + """Integration test: verify CompletionRetriever can retrieve context (simple).""" + retriever = CompletionRetriever() + + context = await retriever.get_context("Mike") + + assert isinstance(context, str), "Context should be a string" + assert "Mike Broski" in context, "Failed to get Mike Broski" + + +@pytest.mark.asyncio +async def test_rag_completion_context_multiple_chunks(setup_test_environment_with_chunks_simple): + """Integration test: verify CompletionRetriever can retrieve context from multiple chunks.""" + retriever = CompletionRetriever() + + context = await retriever.get_context("Steve") + + assert isinstance(context, str), "Context should be a string" + assert "Steve Rodger" in context, "Failed to get Steve Rodger" + + +@pytest.mark.asyncio +async def test_rag_completion_context_complex(setup_test_environment_with_chunks_complex): + """Integration test: verify CompletionRetriever can retrieve context (complex).""" + # TODO: top_k doesn't affect the output, it should be fixed. 
+ retriever = CompletionRetriever(top_k=20) + + context = await retriever.get_context("Christina") + + assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer" + + +@pytest.mark.asyncio +async def test_get_rag_completion_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify CompletionRetriever handles empty graph correctly.""" + retriever = CompletionRetriever() + + with pytest.raises(NoDataError): + await retriever.get_context("Christina Mayer") + + vector_engine = get_vector_engine() + await vector_engine.create_collection( + "DocumentChunk_text", payload_schema=DocumentChunkWithEntities + ) + + context = await retriever.get_context("Christina Mayer") + assert context == "", "Returned context should be empty on an empty graph" diff --git a/cognee/tests/unit/modules/retrieval/structured_output_test.py b/cognee/tests/integration/retrieval/test_structured_output.py similarity index 65% rename from cognee/tests/unit/modules/retrieval/structured_output_test.py rename to cognee/tests/integration/retrieval/test_structured_output.py index 4ad3019ff..13ffd8eef 100644 --- a/cognee/tests/unit/modules/retrieval/structured_output_test.py +++ b/cognee/tests/integration/retrieval/test_structured_output.py @@ -1,9 +1,9 @@ import asyncio - -import pytest -import cognee -import pathlib import os +import pytest +import pathlib +import pytest_asyncio +import cognee from pydantic import BaseModel from cognee.low_level import setup, DataPoint @@ -125,80 +125,90 @@ async def _test_get_structured_entity_completion(): _assert_structured_answer(structured_answer) -class TestStructuredOutputCompletion: - @pytest.mark.asyncio - async def test_get_structured_completion(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_get_structured_completion" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, 
".data_storage/test_get_structured_completion" - ) - cognee.config.data_root_directory(data_directory_path) +@pytest_asyncio.fixture +async def setup_test_environment(): + """Set up a clean test environment with graph and document data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_get_structured_completion") + data_directory_path = str(base_dir / ".data_storage/test_get_structured_completion") + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + + class Person(DataPoint): + name: str + works_for: Company + works_since: int + + company1 = Company(name="Figma") + person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) + + entities = [company1, person1] + await add_data_points(entities) + + document = TextDocument( + name="Steve Rodger's career", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document, + contains=[], + ) + + entities = [chunk1, chunk2, chunk3] + await add_data_points(entities) + + entity_type = EntityType(name="Person", description="A human individual") + entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") + + entities = [entity] + await add_data_points(entities) + + yield + + try: await cognee.prune.prune_data() await 
cognee.prune.prune_system(metadata=True) - await setup() + except Exception: + pass - class Company(DataPoint): - name: str - class Person(DataPoint): - name: str - works_for: Company - works_since: int - - company1 = Company(name="Figma") - person1 = Person(name="Steve Rodger", works_for=company1, works_since=2015) - - entities = [company1, person1] - await add_data_points(entities) - - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3] - await add_data_points(entities) - - entity_type = EntityType(name="Person", description="A human individual") - entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist") - - entities = [entity] - await add_data_points(entities) - - await _test_get_structured_graph_completion_cot() - await _test_get_structured_graph_completion() - await _test_get_structured_graph_completion_temporal() - await _test_get_structured_graph_completion_rag() - await _test_get_structured_graph_completion_context_extension() - await _test_get_structured_entity_completion() +@pytest.mark.asyncio +async def test_get_structured_completion(setup_test_environment): + """Integration test: verify structured output completion for all retrievers.""" + await _test_get_structured_graph_completion_cot() + await _test_get_structured_graph_completion() + await _test_get_structured_graph_completion_temporal() + await 
_test_get_structured_graph_completion_rag() + await _test_get_structured_graph_completion_context_extension() + await _test_get_structured_entity_completion() diff --git a/cognee/tests/integration/retrieval/test_summaries_retriever.py b/cognee/tests/integration/retrieval/test_summaries_retriever.py new file mode 100644 index 000000000..a2f4e40b3 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_summaries_retriever.py @@ -0,0 +1,184 @@ +import os +import pytest +import pathlib +import pytest_asyncio +import cognee + +from cognee.low_level import setup +from cognee.tasks.storage import add_data_points +from cognee.infrastructure.databases.vector import get_vector_engine +from cognee.modules.chunking.models import DocumentChunk +from cognee.tasks.summarization.models import TextSummary +from cognee.modules.data.processing.document_types import TextDocument +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.modules.retrieval.summaries_retriever import SummariesRetriever + + +@pytest_asyncio.fixture +async def setup_test_environment_with_summaries(): + """Set up a clean test environment with summaries.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_summaries_retriever_context") + data_directory_path = str(base_dir / ".data_storage/test_summaries_retriever_context") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + document1 = TextDocument( + name="Employee List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + document2 = TextDocument( + name="Car List", + raw_data_location="somewhere", + external_metadata="", + mime_type="text/plain", + ) + + chunk1 = DocumentChunk( + text="Steve Rodger", + chunk_size=2, + chunk_index=0, + 
cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk1_summary = TextSummary( + text="S.R.", + made_from=chunk1, + ) + chunk2 = DocumentChunk( + text="Mike Broski", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk2_summary = TextSummary( + text="M.B.", + made_from=chunk2, + ) + chunk3 = DocumentChunk( + text="Christina Mayer", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document1, + contains=[], + ) + chunk3_summary = TextSummary( + text="C.M.", + made_from=chunk3, + ) + chunk4 = DocumentChunk( + text="Range Rover", + chunk_size=2, + chunk_index=0, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk4_summary = TextSummary( + text="R.R.", + made_from=chunk4, + ) + chunk5 = DocumentChunk( + text="Hyundai", + chunk_size=2, + chunk_index=1, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk5_summary = TextSummary( + text="H.Y.", + made_from=chunk5, + ) + chunk6 = DocumentChunk( + text="Chrysler", + chunk_size=2, + chunk_index=2, + cut_type="sentence_end", + is_part_of=document2, + contains=[], + ) + chunk6_summary = TextSummary( + text="C.H.", + made_from=chunk6, + ) + + entities = [ + chunk1_summary, + chunk2_summary, + chunk3_summary, + chunk4_summary, + chunk5_summary, + chunk6_summary, + ] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without summaries.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_summaries_retriever_context_empty") + data_directory_path = str(base_dir / ".data_storage/test_summaries_retriever_context_empty") + + cognee.config.system_root_directory(system_directory_path) + 
cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_summaries_retriever_context(setup_test_environment_with_summaries): + """Integration test: verify SummariesRetriever can retrieve summary context.""" + retriever = SummariesRetriever(top_k=20) + + context = await retriever.get_context("Christina") + + assert isinstance(context, list), "Context should be a list" + assert len(context) > 0, "Context should not be empty" + assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer" + + +@pytest.mark.asyncio +async def test_summaries_retriever_context_on_empty_graph(setup_test_environment_empty): + """Integration test: verify SummariesRetriever handles empty graph correctly.""" + retriever = SummariesRetriever() + + with pytest.raises(NoDataError): + await retriever.get_context("Christina Mayer") + + vector_engine = get_vector_engine() + await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary) + + context = await retriever.get_context("Christina Mayer") + assert context == [], "Returned context should be empty on an empty graph" diff --git a/cognee/tests/integration/retrieval/test_temporal_retriever.py b/cognee/tests/integration/retrieval/test_temporal_retriever.py new file mode 100644 index 000000000..8ce3b32f4 --- /dev/null +++ b/cognee/tests/integration/retrieval/test_temporal_retriever.py @@ -0,0 +1,306 @@ +import os +import pytest +import pathlib +import pytest_asyncio +import cognee + +from cognee.low_level import setup, DataPoint +from cognee.tasks.storage import add_data_points +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.modules.engine.models.Event import Event +from cognee.modules.engine.models.Timestamp import Timestamp 
+from cognee.modules.engine.models.Interval import Interval + + +@pytest_asyncio.fixture +async def setup_test_environment_with_events(): + """Set up a clean test environment with temporal events.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_events") + data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_events") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + # Create timestamps for events + timestamp1 = Timestamp( + time_at=1609459200, # 2021-01-01 00:00:00 + year=2021, + month=1, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-01-01T00:00:00", + ) + + timestamp2 = Timestamp( + time_at=1612137600, # 2021-02-01 00:00:00 + year=2021, + month=2, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-02-01T00:00:00", + ) + + timestamp3 = Timestamp( + time_at=1614556800, # 2021-03-01 00:00:00 + year=2021, + month=3, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-03-01T00:00:00", + ) + + timestamp4 = Timestamp( + time_at=1625097600, # 2021-07-01 00:00:00 + year=2021, + month=7, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-07-01T00:00:00", + ) + + timestamp5 = Timestamp( + time_at=1633046400, # 2021-10-01 00:00:00 + year=2021, + month=10, + day=1, + hour=0, + minute=0, + second=0, + timestamp_str="2021-10-01T00:00:00", + ) + + # Create interval for event spanning multiple timestamps + interval1 = Interval(time_from=timestamp2, time_to=timestamp3) + + # Create events with timestamps + event1 = Event( + name="Project Alpha Launch", + description="Launched Project Alpha at the beginning of 2021", + at=timestamp1, + location="San Francisco", + ) + + event2 = Event( + name="Team Meeting", + 
description="Monthly team meeting discussing Q1 goals", + during=interval1, + location="New York", + ) + + event3 = Event( + name="Product Release", + description="Released new product features in July", + at=timestamp4, + location="Remote", + ) + + event4 = Event( + name="Company Retreat", + description="Annual company retreat in October", + at=timestamp5, + location="Lake Tahoe", + ) + + entities = [event1, event2, event3, event4] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_with_graph_data(): + """Set up a clean test environment with graph data (for fallback to triplets).""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = str(base_dir / ".cognee_system/test_temporal_retriever_with_graph") + data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_with_graph") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + class Company(DataPoint): + name: str + description: str + + class Person(DataPoint): + name: str + description: str + works_for: Company + + company1 = Company(name="Figma", description="Figma is a company") + person1 = Person( + name="Steve Rodger", + description="This is description about Steve Rodger", + works_for=company1, + ) + + entities = [company1, person1] + + await add_data_points(entities) + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest_asyncio.fixture +async def setup_test_environment_empty(): + """Set up a clean test environment without data.""" + base_dir = pathlib.Path(__file__).parent.parent.parent.parent + system_directory_path = 
str(base_dir / ".cognee_system/test_temporal_retriever_empty") + data_directory_path = str(base_dir / ".data_storage/test_temporal_retriever_empty") + + cognee.config.system_root_directory(system_directory_path) + cognee.config.data_root_directory(data_directory_path) + + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + await setup() + + yield + + try: + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + except Exception: + pass + + +@pytest.mark.asyncio +async def test_temporal_retriever_context_with_time_range(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever can retrieve events within time range.""" + retriever = TemporalRetriever(top_k=5) + + context = await retriever.get_context("What happened in January 2021?") + + assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + assert "Project Alpha" in context or "Launch" in context, ( + "Should retrieve Project Alpha Launch event from January 2021" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_context_with_single_time(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever can retrieve events at specific time.""" + retriever = TemporalRetriever(top_k=5) + + context = await retriever.get_context("What happened in July 2021?") + + assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + assert "Product Release" in context or "July" in context, ( + "Should retrieve Product Release event from July 2021" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_context_fallback_to_triplets( + setup_test_environment_with_graph_data, +): + """Integration test: verify TemporalRetriever falls back to triplets when no time extracted.""" + retriever = TemporalRetriever(top_k=5) + + context = await retriever.get_context("Who works at Figma?") + + 
assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + assert "Steve" in context or "Figma" in context, ( + "Should retrieve graph data via triplet search fallback" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_context_empty_graph(setup_test_environment_empty): + """Integration test: verify TemporalRetriever handles empty graph correctly.""" + retriever = TemporalRetriever() + + context = await retriever.get_context("What happened?") + + assert isinstance(context, str), "Context should be a string" + assert len(context) >= 0, "Context should be a string (possibly empty)" + + +@pytest.mark.asyncio +async def test_temporal_retriever_get_completion(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever can generate completions.""" + retriever = TemporalRetriever() + + completion = await retriever.get_completion("What happened in January 2021?") + + assert isinstance(completion, list), "Completion should be a list" + assert len(completion) > 0, "Completion should not be empty" + assert all(isinstance(item, str) and item.strip() for item in completion), ( + "Completion items should be non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_get_completion_fallback(setup_test_environment_with_graph_data): + """Integration test: verify TemporalRetriever get_completion works with triplet fallback.""" + retriever = TemporalRetriever() + + completion = await retriever.get_completion("Who works at Figma?") + + assert isinstance(completion, list), "Completion should be a list" + assert len(completion) > 0, "Completion should not be empty" + assert all(isinstance(item, str) and item.strip() for item in completion), ( + "Completion items should be non-empty strings" + ) + + +@pytest.mark.asyncio +async def test_temporal_retriever_top_k_limit(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever respects top_k 
parameter.""" + retriever = TemporalRetriever(top_k=2) + + context = await retriever.get_context("What happened in 2021?") + + assert isinstance(context, str), "Context should be a string" + separator_count = context.count("#####################") + assert separator_count <= 1, "Should respect top_k limit of 2 events" + + +@pytest.mark.asyncio +async def test_temporal_retriever_multiple_events(setup_test_environment_with_events): + """Integration test: verify TemporalRetriever can retrieve multiple events.""" + retriever = TemporalRetriever(top_k=10) + + context = await retriever.get_context("What events occurred in 2021?") + + assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + + assert ( + "Project Alpha" in context + or "Team Meeting" in context + or "Product Release" in context + or "Company Retreat" in context + ), "Should retrieve at least one event from 2021" diff --git a/cognee/tests/integration/retrieval/test_triplet_retriever.py b/cognee/tests/integration/retrieval/test_triplet_retriever.py index e547b6cbe..ebe853e08 100644 --- a/cognee/tests/integration/retrieval/test_triplet_retriever.py +++ b/cognee/tests/integration/retrieval/test_triplet_retriever.py @@ -82,3 +82,38 @@ async def test_triplet_retriever_context_simple(setup_test_environment_with_trip context = await retriever.get_context("Alice") assert "Alice knows Bob" in context, "Failed to get Alice triplet" + assert isinstance(context, str), "Context should be a string" + assert len(context) > 0, "Context should not be empty" + + +@pytest.mark.asyncio +async def test_triplet_retriever_context_multiple_triplets(setup_test_environment_with_triplets): + """Integration test: verify TripletRetriever can retrieve multiple triplets.""" + retriever = TripletRetriever(top_k=5) + + context = await retriever.get_context("Bob") + + assert "Alice knows Bob" in context or "Bob works at Tech Corp" in context, ( + "Failed to get Bob-related 
triplets" + ) + + +@pytest.mark.asyncio +async def test_triplet_retriever_top_k_limit(setup_test_environment_with_triplets): + """Integration test: verify TripletRetriever respects top_k parameter.""" + retriever = TripletRetriever(top_k=1) + + context = await retriever.get_context("Alice") + + assert isinstance(context, str), "Context should be a string" + + +@pytest.mark.asyncio +async def test_triplet_retriever_context_empty(setup_test_environment_empty): + """Integration test: verify TripletRetriever handles empty graph correctly.""" + await setup() + + retriever = TripletRetriever() + + with pytest.raises(NoDataError): + await retriever.get_context("Alice") diff --git a/cognee/tests/unit/eval_framework/benchmark_adapters_test.py b/cognee/tests/unit/eval_framework/benchmark_adapters_test.py index 70ec43cf8..b18012594 100644 --- a/cognee/tests/unit/eval_framework/benchmark_adapters_test.py +++ b/cognee/tests/unit/eval_framework/benchmark_adapters_test.py @@ -11,6 +11,22 @@ MOCK_JSONL_DATA = """\ {"id": "2", "question": "What is ML?", "answer": "Machine Learning", "paragraphs": [{"paragraph_text": "ML is a subset of AI."}]} """ +MOCK_HOTPOT_CORPUS = [ + { + "_id": "1", + "question": "Next to which country is Germany located?", + "answer": "Netherlands", + # HotpotQA uses "level"; TwoWikiMultiHop uses "type". 
+ "level": "easy", + "type": "comparison", + "context": [ + ["Germany", ["Germany is in Europe."]], + ["Netherlands", ["The Netherlands borders Germany."]], + ], + "supporting_facts": [["Netherlands", 0]], + } +] + ADAPTER_CLASSES = [ HotpotQAAdapter, @@ -35,6 +51,11 @@ def test_adapter_can_instantiate_and_load(AdapterClass): adapter = AdapterClass() result = adapter.load_corpus() + elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter): + with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS): + adapter = AdapterClass() + result = adapter.load_corpus() + else: adapter = AdapterClass() result = adapter.load_corpus() @@ -64,6 +85,10 @@ def test_adapter_returns_some_content(AdapterClass): ): adapter = AdapterClass() corpus_list, qa_pairs = adapter.load_corpus(limit=limit) + elif AdapterClass in (HotpotQAAdapter, TwoWikiMultihopAdapter): + with patch.object(AdapterClass, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS): + adapter = AdapterClass() + corpus_list, qa_pairs = adapter.load_corpus(limit=limit) else: adapter = AdapterClass() corpus_list, qa_pairs = adapter.load_corpus(limit=limit) diff --git a/cognee/tests/unit/eval_framework/corpus_builder_test.py b/cognee/tests/unit/eval_framework/corpus_builder_test.py index 14136bea5..53f886b58 100644 --- a/cognee/tests/unit/eval_framework/corpus_builder_test.py +++ b/cognee/tests/unit/eval_framework/corpus_builder_test.py @@ -2,15 +2,38 @@ import pytest from cognee.eval_framework.corpus_builder.corpus_builder_executor import CorpusBuilderExecutor from cognee.infrastructure.databases.graph import get_graph_engine from unittest.mock import AsyncMock, patch +from cognee.eval_framework.benchmark_adapters.hotpot_qa_adapter import HotpotQAAdapter benchmark_options = ["HotPotQA", "Dummy", "TwoWikiMultiHop"] +MOCK_HOTPOT_CORPUS = [ + { + "_id": "1", + "question": "Next to which country is Germany located?", + "answer": "Netherlands", + # HotpotQA uses "level"; TwoWikiMultiHop uses "type". 
+ "level": "easy", + "type": "comparison", + "context": [ + ["Germany", ["Germany is in Europe."]], + ["Netherlands", ["The Netherlands borders Germany."]], + ], + "supporting_facts": [["Netherlands", 0]], + } +] + @pytest.mark.parametrize("benchmark", benchmark_options) def test_corpus_builder_load_corpus(benchmark): limit = 2 - corpus_builder = CorpusBuilderExecutor(benchmark, "Default") - raw_corpus, questions = corpus_builder.load_corpus(limit=limit) + if benchmark in ("HotPotQA", "TwoWikiMultiHop"): + with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS): + corpus_builder = CorpusBuilderExecutor(benchmark, "Default") + raw_corpus, questions = corpus_builder.load_corpus(limit=limit) + else: + corpus_builder = CorpusBuilderExecutor(benchmark, "Default") + raw_corpus, questions = corpus_builder.load_corpus(limit=limit) + assert len(raw_corpus) > 0, f"Corpus builder loads empty corpus for {benchmark}" assert len(questions) <= 2, ( f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}" @@ -22,8 +45,14 @@ def test_corpus_builder_load_corpus(benchmark): @patch.object(CorpusBuilderExecutor, "run_cognee", new_callable=AsyncMock) async def test_corpus_builder_build_corpus(mock_run_cognee, benchmark): limit = 2 - corpus_builder = CorpusBuilderExecutor(benchmark, "Default") - questions = await corpus_builder.build_corpus(limit=limit) + if benchmark in ("HotPotQA", "TwoWikiMultiHop"): + with patch.object(HotpotQAAdapter, "_get_raw_corpus", return_value=MOCK_HOTPOT_CORPUS): + corpus_builder = CorpusBuilderExecutor(benchmark, "Default") + questions = await corpus_builder.build_corpus(limit=limit) + else: + corpus_builder = CorpusBuilderExecutor(benchmark, "Default") + questions = await corpus_builder.build_corpus(limit=limit) + assert len(questions) <= 2, ( f"Corpus builder loads {len(questions)} for {benchmark} when limit is {limit}" ) diff --git a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py 
b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py deleted file mode 100644 index 44786f79d..000000000 --- a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +++ /dev/null @@ -1,201 +0,0 @@ -import os -import pytest -import pathlib -from typing import List -import cognee -from cognee.low_level import setup -from cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.modules.retrieval.chunks_retriever import ChunksRetriever -from cognee.infrastructure.engine import DataPoint -from cognee.modules.data.processing.document_types import Document -from cognee.modules.engine.models import Entity - - -class DocumentChunkWithEntities(DataPoint): - text: str - chunk_size: int - chunk_index: int - cut_type: str - is_part_of: Document - contains: List[Entity] = None - - metadata: dict = {"index_fields": ["text"]} - - -class TestChunksRetriever: - @pytest.mark.asyncio - async def test_chunk_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - 
text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3] - - await add_data_points(entities) - - retriever = ChunksRetriever() - - context = await retriever.get_context("Mike") - - assert context[0]["text"] == "Mike Broski", "Failed to get Mike Broski" - - @pytest.mark.asyncio - async def test_chunk_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_complex" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - 
chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3, chunk4, chunk5, chunk6] - - await add_data_points(entities) - - retriever = ChunksRetriever(top_k=20) - - context = await retriever.get_context("Christina") - - assert context[0]["text"] == "Christina Mayer", "Failed to get Christina Mayer" - - @pytest.mark.asyncio - async def test_chunk_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = ChunksRetriever() - - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") - - vector_engine = get_vector_engine() - await vector_engine.create_collection( - "DocumentChunk_text", payload_schema=DocumentChunkWithEntities - ) - - context = await retriever.get_context("Christina Mayer") - assert len(context) == 0, "Found chunks when none should exist" diff --git a/cognee/tests/unit/modules/retrieval/conversation_history_test.py b/cognee/tests/unit/modules/retrieval/conversation_history_test.py deleted file mode 100644 index d464a99d8..000000000 --- a/cognee/tests/unit/modules/retrieval/conversation_history_test.py +++ /dev/null @@ -1,154 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock -from cognee.context_global_variables import session_user -import importlib - - -def create_mock_cache_engine(qa_history=None): - mock_cache = AsyncMock() - if qa_history is None: - 
qa_history = [] - mock_cache.get_latest_qa = AsyncMock(return_value=qa_history) - mock_cache.add_qa = AsyncMock(return_value=None) - return mock_cache - - -def create_mock_user(): - mock_user = MagicMock() - mock_user.id = "test-user-id-123" - return mock_user - - -class TestConversationHistoryUtils: - @pytest.mark.asyncio - async def test_get_conversation_history_returns_empty_when_no_history(self): - user = create_mock_user() - session_user.set(user) - mock_cache = create_mock_cache_engine([]) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - from cognee.modules.retrieval.utils.session_cache import get_conversation_history - - result = await get_conversation_history(session_id="test_session") - - assert result == "" - - @pytest.mark.asyncio - async def test_get_conversation_history_formats_history_correctly(self): - """Test get_conversation_history formats Q&A history with correct structure.""" - user = create_mock_user() - session_user.set(user) - - mock_history = [ - { - "time": "2024-01-15 10:30:45", - "question": "What is AI?", - "context": "AI is artificial intelligence", - "answer": "AI stands for Artificial Intelligence", - } - ] - mock_cache = create_mock_cache_engine(mock_history) - - # Import the real module to patch safely - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - get_conversation_history, - ) - - result = await get_conversation_history(session_id="test_session") - - assert "Previous conversation:" in result - 
assert "[2024-01-15 10:30:45]" in result - assert "QUESTION: What is AI?" in result - assert "CONTEXT: AI is artificial intelligence" in result - assert "ANSWER: AI stands for Artificial Intelligence" in result - - @pytest.mark.asyncio - async def test_save_to_session_cache_saves_correctly(self): - """Test save_conversation_history calls add_qa with correct parameters.""" - user = create_mock_user() - session_user.set(user) - - mock_cache = create_mock_cache_engine([]) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - save_conversation_history, - ) - - result = await save_conversation_history( - query="What is Python?", - context_summary="Python is a programming language", - answer="Python is a high-level programming language", - session_id="my_session", - ) - - assert result is True - mock_cache.add_qa.assert_called_once() - - call_kwargs = mock_cache.add_qa.call_args.kwargs - assert call_kwargs["question"] == "What is Python?" 
- assert call_kwargs["context"] == "Python is a programming language" - assert call_kwargs["answer"] == "Python is a high-level programming language" - assert call_kwargs["session_id"] == "my_session" - - @pytest.mark.asyncio - async def test_save_to_session_cache_uses_default_session_when_none(self): - """Test save_conversation_history uses 'default_session' when session_id is None.""" - user = create_mock_user() - session_user.set(user) - - mock_cache = create_mock_cache_engine([]) - - cache_module = importlib.import_module( - "cognee.infrastructure.databases.cache.get_cache_engine" - ) - - with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): - with patch( - "cognee.modules.retrieval.utils.session_cache.CacheConfig" - ) as MockCacheConfig: - mock_config = MagicMock() - mock_config.caching = True - MockCacheConfig.return_value = mock_config - - from cognee.modules.retrieval.utils.session_cache import ( - save_conversation_history, - ) - - result = await save_conversation_history( - query="Test question", - context_summary="Test context", - answer="Test answer", - session_id=None, - ) - - assert result is True - call_kwargs = mock_cache.add_qa.call_args.kwargs - assert call_kwargs["session_id"] == "default_session" diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py deleted file mode 100644 index 0e21fe351..000000000 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ /dev/null @@ -1,177 +0,0 @@ -import os -import pytest -import pathlib -from typing import Optional, Union - -import cognee -from cognee.low_level import setup, DataPoint -from cognee.tasks.storage import add_data_points -from cognee.modules.graph.utils import resolve_edges_to_text -from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( - 
GraphCompletionContextExtensionRetriever, -) - - -class TestGraphCompletionWithContextExtensionRetriever: - @pytest.mark.asyncio - async def test_graph_completion_extension_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_extension_context_simple", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_graph_completion_extension_context_simple", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - person1 = Person(name="Steve Rodger", works_for=company1) - person2 = Person(name="Ike Loma", works_for=company1) - person3 = Person(name="Jason Statham", works_for=company1) - person4 = Person(name="Mike Broski", works_for=company2) - person5 = Person(name="Christina Mayer", works_for=company2) - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionContextExtensionRetriever() - - context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) - - assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" - assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" - - answer = await retriever.get_completion("Who works at Canva?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) - - @pytest.mark.asyncio - async def test_graph_completion_extension_context_complex(self): - 
system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_extension_context_complex", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_graph_completion_extension_context_complex", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} - - class Car(DataPoint): - brand: str - model: str - year: int - - class Location(DataPoint): - country: str - city: str - - class Home(DataPoint): - location: Location - rooms: int - sqm: int - - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] - - person3 = Person(name="Jason Statham", works_for=company1) - - person4 = Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] - - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionContextExtensionRetriever(top_k=20) - - context = await resolve_edges_to_text( - await retriever.get_context("Who works at Figma and drives Tesla?") - ) - - print(context) - - assert "Mike Rodger --[works_for]--> Figma" in context, 
"Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) - - @pytest.mark.asyncio - async def test_get_graph_completion_extension_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_extension_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_extension_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = GraphCompletionContextExtensionRetriever() - - await setup() - - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py deleted file mode 100644 index 206cfaf84..000000000 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ /dev/null @@ -1,170 +0,0 @@ -import os -import pytest -import pathlib -from typing import Optional, Union - -import cognee 
-from cognee.low_level import setup, DataPoint -from cognee.modules.graph.utils import resolve_edges_to_text -from cognee.tasks.storage import add_data_points -from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever - - -class TestGraphCompletionCoTRetriever: - @pytest.mark.asyncio - async def test_graph_completion_cot_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_cot_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - - class Person(DataPoint): - name: str - works_for: Company - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - person1 = Person(name="Steve Rodger", works_for=company1) - person2 = Person(name="Ike Loma", works_for=company1) - person3 = Person(name="Jason Statham", works_for=company1) - person4 = Person(name="Mike Broski", works_for=company2) - person5 = Person(name="Christina Mayer", works_for=company2) - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionCotRetriever() - - context = await resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) - - assert "Mike Broski --[works_for]--> Canva" in context, "Failed to get Mike Broski" - assert "Christina Mayer --[works_for]--> Canva" in context, "Failed to get Christina Mayer" - - answer = await retriever.get_completion("Who works at Canva?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in 
answer), ( - "Answer must contain only non-empty strings" - ) - - @pytest.mark.asyncio - async def test_graph_completion_cot_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_graph_completion_cot_context_complex", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_cot_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} - - class Car(DataPoint): - brand: str - model: str - year: int - - class Location(DataPoint): - country: str - city: str - - class Home(DataPoint): - location: Location - rooms: int - sqm: int - - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] - - person3 = Person(name="Jason Statham", works_for=company1) - - person4 = Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] - - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionCotRetriever(top_k=20) - - context = await resolve_edges_to_text(await retriever.get_context("Who 
works at Figma?")) - - print(context) - - assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) - - @pytest.mark.asyncio - async def test_get_graph_completion_cot_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_cot_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_cot_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = GraphCompletionCotRetriever() - - await setup() - - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" - - answer = await retriever.get_completion("Who works at Figma?") - - assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" - assert all(isinstance(item, str) and item.strip() for item in answer), ( - "Answer must contain only non-empty strings" - ) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py deleted file mode 100644 index f462baced..000000000 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +++ /dev/null @@ -1,223 +0,0 @@ -import os -import pytest -import 
pathlib -from typing import Optional, Union - -import cognee -from cognee.low_level import setup, DataPoint -from cognee.modules.graph.utils import resolve_edges_to_text -from cognee.tasks.storage import add_data_points -from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever - - -class TestGraphCompletionRetriever: - @pytest.mark.asyncio - async def test_graph_completion_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - description: str - - class Person(DataPoint): - name: str - description: str - works_for: Company - - company1 = Company(name="Figma", description="Figma is a company") - company2 = Company(name="Canva", description="Canvas is a company") - person1 = Person( - name="Steve Rodger", - description="This is description about Steve Rodger", - works_for=company1, - ) - person2 = Person( - name="Ike Loma", description="This is description about Ike Loma", works_for=company1 - ) - person3 = Person( - name="Jason Statham", - description="This is description about Jason Statham", - works_for=company1, - ) - person4 = Person( - name="Mike Broski", - description="This is description about Mike Broski", - works_for=company2, - ) - person5 = Person( - name="Christina Mayer", - description="This is description about Christina Mayer", - works_for=company2, - ) - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionRetriever() - - context = await 
resolve_edges_to_text(await retriever.get_context("Who works at Canva?")) - - # Ensure the top-level sections are present - assert "Nodes:" in context, "Missing 'Nodes:' section in context" - assert "Connections:" in context, "Missing 'Connections:' section in context" - - # --- Nodes headers --- - assert "Node: Steve Rodger" in context, "Missing node header for Steve Rodger" - assert "Node: Figma" in context, "Missing node header for Figma" - assert "Node: Ike Loma" in context, "Missing node header for Ike Loma" - assert "Node: Jason Statham" in context, "Missing node header for Jason Statham" - assert "Node: Mike Broski" in context, "Missing node header for Mike Broski" - assert "Node: Canva" in context, "Missing node header for Canva" - assert "Node: Christina Mayer" in context, "Missing node header for Christina Mayer" - - # --- Node contents --- - assert ( - "__node_content_start__\nThis is description about Steve Rodger\n__node_content_end__" - in context - ), "Description block for Steve Rodger altered" - assert "__node_content_start__\nFigma is a company\n__node_content_end__" in context, ( - "Description block for Figma altered" - ) - assert ( - "__node_content_start__\nThis is description about Ike Loma\n__node_content_end__" - in context - ), "Description block for Ike Loma altered" - assert ( - "__node_content_start__\nThis is description about Jason Statham\n__node_content_end__" - in context - ), "Description block for Jason Statham altered" - assert ( - "__node_content_start__\nThis is description about Mike Broski\n__node_content_end__" - in context - ), "Description block for Mike Broski altered" - assert "__node_content_start__\nCanvas is a company\n__node_content_end__" in context, ( - "Description block for Canva altered" - ) - assert ( - "__node_content_start__\nThis is description about Christina Mayer\n__node_content_end__" - in context - ), "Description block for Christina Mayer altered" - - # --- Connections --- - assert "Steve Rodger 
--[works_for]--> Figma" in context, ( - "Connection Steve Rodger→Figma missing or changed" - ) - assert "Ike Loma --[works_for]--> Figma" in context, ( - "Connection Ike Loma→Figma missing or changed" - ) - assert "Jason Statham --[works_for]--> Figma" in context, ( - "Connection Jason Statham→Figma missing or changed" - ) - assert "Mike Broski --[works_for]--> Canva" in context, ( - "Connection Mike Broski→Canva missing or changed" - ) - assert "Christina Mayer --[works_for]--> Canva" in context, ( - "Connection Christina Mayer→Canva missing or changed" - ) - - @pytest.mark.asyncio - async def test_graph_completion_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_graph_completion_context_complex" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_graph_completion_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - class Company(DataPoint): - name: str - metadata: dict = {"index_fields": ["name"]} - - class Car(DataPoint): - brand: str - model: str - year: int - - class Location(DataPoint): - country: str - city: str - - class Home(DataPoint): - location: Location - rooms: int - sqm: int - - class Person(DataPoint): - name: str - works_for: Company - owns: Optional[list[Union[Car, Home]]] = None - - company1 = Company(name="Figma") - company2 = Company(name="Canva") - - person1 = Person(name="Mike Rodger", works_for=company1) - person1.owns = [Car(brand="Toyota", model="Camry", year=2020)] - - person2 = Person(name="Ike Loma", works_for=company1) - person2.owns = [ - Car(brand="Tesla", model="Model S", year=2021), - Home(location=Location(country="USA", city="New York"), sqm=80, rooms=4), - ] - - person3 = Person(name="Jason Statham", works_for=company1) - - person4 
= Person(name="Mike Broski", works_for=company2) - person4.owns = [Car(brand="Ford", model="Mustang", year=1978)] - - person5 = Person(name="Christina Mayer", works_for=company2) - person5.owns = [Car(brand="Honda", model="Civic", year=2023)] - - entities = [company1, company2, person1, person2, person3, person4, person5] - - await add_data_points(entities) - - retriever = GraphCompletionRetriever(top_k=20) - - context = await resolve_edges_to_text(await retriever.get_context("Who works at Figma?")) - - print(context) - - assert "Mike Rodger --[works_for]--> Figma" in context, "Failed to get Mike Rodger" - assert "Ike Loma --[works_for]--> Figma" in context, "Failed to get Ike Loma" - assert "Jason Statham --[works_for]--> Figma" in context, "Failed to get Jason Statham" - - @pytest.mark.asyncio - async def test_get_graph_completion_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_graph_completion_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_graph_completion_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = GraphCompletionRetriever() - - await setup() - - context = await retriever.get_context("Who works at Figma?") - assert context == [], "Context should be empty on an empty graph" diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py deleted file mode 100644 index 9bfed68f3..000000000 --- a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +++ /dev/null @@ -1,205 +0,0 @@ -import os -from typing import List -import pytest -import pathlib -import cognee - -from cognee.low_level import setup -from 
cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.modules.data.processing.document_types import TextDocument -from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.modules.retrieval.completion_retriever import CompletionRetriever -from cognee.infrastructure.engine import DataPoint -from cognee.modules.data.processing.document_types import Document -from cognee.modules.engine.models import Entity - - -class DocumentChunkWithEntities(DataPoint): - text: str - chunk_size: int - chunk_index: int - cut_type: str - is_part_of: Document - contains: List[Entity] = None - - metadata: dict = {"index_fields": ["text"]} - - -class TestRAGCompletionRetriever: - @pytest.mark.asyncio - async def test_rag_completion_context_simple(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_simple" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_simple" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document = TextDocument( - name="Steve Rodger's career", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document, - contains=[], - ) - - entities = [chunk1, 
chunk2, chunk3] - - await add_data_points(entities) - - retriever = CompletionRetriever() - - context = await retriever.get_context("Mike") - - assert context == "Mike Broski", "Failed to get Mike Broski" - - @pytest.mark.asyncio - async def test_rag_completion_context_complex(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_rag_completion_context_complex" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_rag_completion_context_complex" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - - entities = [chunk1, chunk2, chunk3, chunk4, 
chunk5, chunk6] - - await add_data_points(entities) - - # TODO: top_k doesn't affect the output, it should be fixed. - retriever = CompletionRetriever(top_k=20) - - context = await retriever.get_context("Christina") - - assert context[0:15] == "Christina Mayer", "Failed to get Christina Mayer" - - @pytest.mark.asyncio - async def test_get_rag_completion_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".cognee_system/test_get_rag_completion_context_on_empty_graph", - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, - ".data_storage/test_get_rag_completion_context_on_empty_graph", - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = CompletionRetriever() - - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") - - vector_engine = get_vector_engine() - await vector_engine.create_collection( - "DocumentChunk_text", payload_schema=DocumentChunkWithEntities - ) - - context = await retriever.get_context("Christina Mayer") - assert context == "", "Returned context should be empty on an empty graph" diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py deleted file mode 100644 index 5f4b93425..000000000 --- a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +++ /dev/null @@ -1,159 +0,0 @@ -import os -import pytest -import pathlib - -import cognee -from cognee.low_level import setup -from cognee.tasks.storage import add_data_points -from cognee.infrastructure.databases.vector import get_vector_engine -from cognee.modules.chunking.models import DocumentChunk -from cognee.tasks.summarization.models import TextSummary -from cognee.modules.data.processing.document_types import TextDocument -from 
cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.modules.retrieval.summaries_retriever import SummariesRetriever - - -class TestSummariesRetriever: - @pytest.mark.asyncio - async def test_chunk_context(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - await setup() - - document1 = TextDocument( - name="Employee List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - document2 = TextDocument( - name="Car List", - raw_data_location="somewhere", - external_metadata="", - mime_type="text/plain", - ) - - chunk1 = DocumentChunk( - text="Steve Rodger", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk1_summary = TextSummary( - text="S.R.", - made_from=chunk1, - ) - chunk2 = DocumentChunk( - text="Mike Broski", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk2_summary = TextSummary( - text="M.B.", - made_from=chunk2, - ) - chunk3 = DocumentChunk( - text="Christina Mayer", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document1, - contains=[], - ) - chunk3_summary = TextSummary( - text="C.M.", - made_from=chunk3, - ) - chunk4 = DocumentChunk( - text="Range Rover", - chunk_size=2, - chunk_index=0, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk4_summary = TextSummary( - text="R.R.", - made_from=chunk4, - ) - chunk5 = DocumentChunk( - text="Hyundai", - chunk_size=2, - chunk_index=1, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) 
- chunk5_summary = TextSummary( - text="H.Y.", - made_from=chunk5, - ) - chunk6 = DocumentChunk( - text="Chrysler", - chunk_size=2, - chunk_index=2, - cut_type="sentence_end", - is_part_of=document2, - contains=[], - ) - chunk6_summary = TextSummary( - text="C.H.", - made_from=chunk6, - ) - - entities = [ - chunk1_summary, - chunk2_summary, - chunk3_summary, - chunk4_summary, - chunk5_summary, - chunk6_summary, - ] - - await add_data_points(entities) - - retriever = SummariesRetriever(top_k=20) - - context = await retriever.get_context("Christina") - - assert context[0]["text"] == "C.M.", "Failed to get Christina Mayer" - - @pytest.mark.asyncio - async def test_chunk_context_on_empty_graph(self): - system_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".cognee_system/test_chunk_context_on_empty_graph" - ) - cognee.config.system_root_directory(system_directory_path) - data_directory_path = os.path.join( - pathlib.Path(__file__).parent, ".data_storage/test_chunk_context_on_empty_graph" - ) - cognee.config.data_root_directory(data_directory_path) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - retriever = SummariesRetriever() - - with pytest.raises(NoDataError): - await retriever.get_context("Christina Mayer") - - vector_engine = get_vector_engine() - await vector_engine.create_collection("TextSummary_text", payload_schema=TextSummary) - - context = await retriever.get_context("Christina Mayer") - assert context == [], "Returned context should be empty on an empty graph" diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py deleted file mode 100644 index c3c6a47f6..000000000 --- a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +++ /dev/null @@ -1,224 +0,0 @@ -from types import SimpleNamespace -import pytest - -from cognee.modules.retrieval.temporal_retriever import TemporalRetriever - - -# Test 
TemporalRetriever initialization defaults and overrides -def test_init_defaults_and_overrides(): - tr = TemporalRetriever() - assert tr.top_k == 5 - assert tr.user_prompt_path == "graph_context_for_question.txt" - assert tr.system_prompt_path == "answer_simple_question.txt" - assert tr.time_extraction_prompt_path == "extract_query_time.txt" - - tr2 = TemporalRetriever( - top_k=3, - user_prompt_path="u.txt", - system_prompt_path="s.txt", - time_extraction_prompt_path="t.txt", - ) - assert tr2.top_k == 3 - assert tr2.user_prompt_path == "u.txt" - assert tr2.system_prompt_path == "s.txt" - assert tr2.time_extraction_prompt_path == "t.txt" - - -# Test descriptions_to_string with basic and empty results -def test_descriptions_to_string_basic_and_empty(): - tr = TemporalRetriever() - - results = [ - {"description": " First "}, - {"nope": "no description"}, - {"description": "Second"}, - {"description": ""}, - {"description": " Third line "}, - ] - - s = tr.descriptions_to_string(results) - assert s == "First\n#####################\nSecond\n#####################\nThird line" - - assert tr.descriptions_to_string([]) == "" - - -# Test filter_top_k_events sorts and limits correctly -@pytest.mark.asyncio -async def test_filter_top_k_events_sorts_and_limits(): - tr = TemporalRetriever(top_k=2) - - relevant_events = [ - { - "events": [ - {"id": "e1", "description": "E1"}, - {"id": "e2", "description": "E2"}, - {"id": "e3", "description": "E3 - not in vector results"}, - ] - } - ] - - scored_results = [ - SimpleNamespace(payload={"id": "e2"}, score=0.10), - SimpleNamespace(payload={"id": "e1"}, score=0.20), - ] - - top = await tr.filter_top_k_events(relevant_events, scored_results) - - assert [e["id"] for e in top] == ["e2", "e1"] - assert all("score" in e for e in top) - assert top[0]["score"] == 0.10 - assert top[1]["score"] == 0.20 - - -# Test filter_top_k_events handles unknown ids as infinite scores -@pytest.mark.asyncio -async def 
test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k(): - tr = TemporalRetriever(top_k=2) - - relevant_events = [ - { - "events": [ - {"id": "known1", "description": "Known 1"}, - {"id": "unknown", "description": "Unknown"}, - {"id": "known2", "description": "Known 2"}, - ] - } - ] - - scored_results = [ - SimpleNamespace(payload={"id": "known2"}, score=0.05), - SimpleNamespace(payload={"id": "known1"}, score=0.50), - ] - - top = await tr.filter_top_k_events(relevant_events, scored_results) - assert [e["id"] for e in top] == ["known2", "known1"] - assert all(e["score"] != float("inf") for e in top) - - -# Test descriptions_to_string with unicode and newlines -def test_descriptions_to_string_unicode_and_newlines(): - tr = TemporalRetriever() - results = [ - {"description": "Line A\nwith newline"}, - {"description": "This is a description"}, - ] - s = tr.descriptions_to_string(results) - assert "Line A\nwith newline" in s - assert "This is a description" in s - assert s.count("#####################") == 1 - - -# Test filter_top_k_events when top_k is larger than available events -@pytest.mark.asyncio -async def test_filter_top_k_events_limits_when_top_k_exceeds_events(): - tr = TemporalRetriever(top_k=10) - relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}] - scored_results = [ - SimpleNamespace(payload={"id": "a"}, score=0.1), - SimpleNamespace(payload={"id": "b"}, score=0.2), - ] - out = await tr.filter_top_k_events(relevant_events, scored_results) - assert [e["id"] for e in out] == ["a", "b"] - - -# Test filter_top_k_events when scored_results is empty -@pytest.mark.asyncio -async def test_filter_top_k_events_handles_empty_scored_results(): - tr = TemporalRetriever(top_k=2) - relevant_events = [{"events": [{"id": "x"}, {"id": "y"}]}] - scored_results = [] - out = await tr.filter_top_k_events(relevant_events, scored_results) - assert [e["id"] for e in out] == ["x", "y"] - assert all(e["score"] == float("inf") for e in out) - - -# Test 
filter_top_k_events error handling for missing structure -@pytest.mark.asyncio -async def test_filter_top_k_events_error_handling(): - tr = TemporalRetriever(top_k=2) - with pytest.raises((KeyError, TypeError)): - await tr.filter_top_k_events([{}], []) - - -class _FakeRetriever(TemporalRetriever): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._calls = [] - - async def extract_time_from_query(self, query: str): - if "both" in query: - return "2024-01-01", "2024-12-31" - if "from_only" in query: - return "2024-01-01", None - if "to_only" in query: - return None, "2024-12-31" - return None, None - - async def get_triplets(self, query: str): - self._calls.append(("get_triplets", query)) - return [{"s": "a", "p": "b", "o": "c"}] - - async def resolve_edges_to_text(self, triplets): - self._calls.append(("resolve_edges_to_text", len(triplets))) - return "edges->text" - - async def _fake_graph_collect_ids(self, **kwargs): - return ["e1", "e2"] - - async def _fake_graph_collect_events(self, ids): - return [ - { - "events": [ - {"id": "e1", "description": "E1"}, - {"id": "e2", "description": "E2"}, - {"id": "e3", "description": "E3"}, - ] - } - ] - - async def _fake_vector_embed(self, texts): - assert isinstance(texts, list) and texts - return [[0.0, 1.0, 2.0]] - - async def _fake_vector_search(self, **kwargs): - return [ - SimpleNamespace(payload={"id": "e2"}, score=0.05), - SimpleNamespace(payload={"id": "e1"}, score=0.10), - ] - - async def get_context(self, query: str): - time_from, time_to = await self.extract_time_from_query(query) - - if not (time_from or time_to): - triplets = await self.get_triplets(query) - return await self.resolve_edges_to_text(triplets) - - ids = await self._fake_graph_collect_ids(time_from=time_from, time_to=time_to) - relevant_events = await self._fake_graph_collect_events(ids) - - _ = await self._fake_vector_embed([query]) - vector_search_results = await self._fake_vector_search( - 
collection_name="Event_name", query_vector=[0.0], limit=0 - ) - top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results) - return self.descriptions_to_string(top_k_events) - - -# Test get_context fallback to triplets when no time is extracted -@pytest.mark.asyncio -async def test_fake_get_context_falls_back_to_triplets_when_no_time(): - tr = _FakeRetriever(top_k=2) - ctx = await tr.get_context("no_time") - assert ctx == "edges->text" - assert tr._calls[0][0] == "get_triplets" - assert tr._calls[1][0] == "resolve_edges_to_text" - - -# Test get_context when time is extracted and vector ranking is applied -@pytest.mark.asyncio -async def test_fake_get_context_with_time_filters_and_vector_ranking(): - tr = _FakeRetriever(top_k=2) - ctx = await tr.get_context("both time") - assert ctx.startswith("E2") - assert "#####################" in ctx - assert "E1" in ctx and "E3" not in ctx diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py deleted file mode 100644 index 3dc9f38d9..000000000 --- a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +++ /dev/null @@ -1,608 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch - -from cognee.modules.retrieval.utils.brute_force_triplet_search import ( - brute_force_triplet_search, - get_memory_fragment, -) -from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph -from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError - - -class MockScoredResult: - """Mock class for vector search results.""" - - def __init__(self, id, score, payload=None): - self.id = id - self.score = score - self.payload = payload or {} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_empty_query(): - """Test that empty query raises ValueError.""" - with pytest.raises(ValueError, match="The query must be a non-empty string."): - await 
brute_force_triplet_search(query="") - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_none_query(): - """Test that None query raises ValueError.""" - with pytest.raises(ValueError, match="The query must be a non-empty string."): - await brute_force_triplet_search(query=None) - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_negative_top_k(): - """Test that negative top_k raises ValueError.""" - with pytest.raises(ValueError, match="top_k must be a positive integer."): - await brute_force_triplet_search(query="test query", top_k=-1) - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_zero_top_k(): - """Test that zero top_k raises ValueError.""" - with pytest.raises(ValueError, match="top_k must be a positive integer."): - await brute_force_triplet_search(query="test query", top_k=0) - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_wide_search_limit_global_search(): - """Test that wide_search_limit is applied for global search (node_name=None).""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search( - query="test", - node_name=None, # Global search - wide_search_top_k=75, - ) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["limit"] == 75 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_wide_search_limit_filtered_search(): - """Test that wide_search_limit is None for filtered search (node_name provided).""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - 
mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search( - query="test", - node_name=["Node1"], - wide_search_top_k=50, - ) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["limit"] is None - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_wide_search_default(): - """Test that wide_search_top_k defaults to 100.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test", node_name=None) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["limit"] == 100 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_default_collections(): - """Test that default collections are used when none provided.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test") - - expected_collections = [ - "Entity_name", - "TextSummary_text", - "EntityType_name", - "DocumentChunk_text", - "EdgeType_relationship_name", - ] - - call_collections = [ - call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list - ] - assert call_collections == expected_collections - - -@pytest.mark.asyncio -async def 
test_brute_force_triplet_search_custom_collections(): - """Test that custom collections are used when provided.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - custom_collections = ["CustomCol1", "CustomCol2"] - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test", collections=custom_collections) - - call_collections = [ - call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list - ] - assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_always_includes_edge_collection(): - """Test that EdgeType_relationship_name is always searched even when not in collections.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - collections_without_edge = ["Entity_name", "TextSummary_text"] - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query="test", collections=collections_without_edge) - - call_collections = [ - call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list - ] - assert "EdgeType_relationship_name" in call_collections - assert set(call_collections) == set(collections_without_edge) | { - "EdgeType_relationship_name" - } - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_all_collections_empty(): - """Test that empty list is returned when all collections return no results.""" 
- mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - results = await brute_force_triplet_search(query="test") - assert results == [] - - -# Tests for query embedding - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_embeds_query(): - """Test that query is embedded before searching.""" - query_text = "test query" - expected_vector = [0.1, 0.2, 0.3] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector]) - mock_vector_engine.search = AsyncMock(return_value=[]) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ): - await brute_force_triplet_search(query=query_text) - - mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text]) - - for call in mock_vector_engine.search.call_args_list: - assert call[1]["query_vector"] == expected_vector - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_extracts_node_ids_global_search(): - """Test that node IDs are extracted from search results for global search.""" - scored_results = [ - MockScoredResult("node1", 0.95), - MockScoredResult("node2", 0.87), - MockScoredResult("node3", 0.92), - ] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=scored_results) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - 
map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_reuses_provided_fragment(): - """Test that provided memory fragment is reused instead of creating new one.""" - provided_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment" - ) as mock_get_fragment, - ): - await brute_force_triplet_search( - query="test", - memory_fragment=provided_fragment, - node_name=["node"], - ) - - mock_get_fragment.assert_not_called() - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_creates_fragment_when_not_provided(): - """Test that memory fragment is created when not provided.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = 
AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment, - ): - await brute_force_triplet_search(query="test", node_name=["node"]) - - mock_get_fragment.assert_called_once() - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation(): - """Test that custom top_k is passed to importance calculation.""" - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ), - ): - custom_top_k = 15 - await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"]) - - mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k) - - -@pytest.mark.asyncio -async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found(): - """Test that get_memory_fragment returns 
empty graph when entity not found.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.project_graph_from_db = AsyncMock( - side_effect=EntityNotFoundError("Entity not found") - ) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", - return_value=mock_graph_engine, - ): - fragment = await get_memory_fragment() - - assert isinstance(fragment, CogneeGraph) - assert len(fragment.nodes) == 0 - - -@pytest.mark.asyncio -async def test_get_memory_fragment_returns_empty_graph_on_error(): - """Test that get_memory_fragment returns empty graph on generic error.""" - mock_graph_engine = AsyncMock() - mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error")) - - with patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", - return_value=mock_graph_engine, - ): - fragment = await get_memory_fragment() - - assert isinstance(fragment, CogneeGraph) - assert len(fragment.nodes) == 0 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_deduplicates_node_ids(): - """Test that duplicate node IDs across collections are deduplicated.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [ - MockScoredResult("node1", 0.95), - MockScoredResult("node2", 0.87), - ] - elif collection_name == "TextSummary_text": - return [ - MockScoredResult("node1", 0.90), - MockScoredResult("node3", 0.92), - ] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - 
patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} - assert len(call_kwargs["relevant_ids_to_filter"]) == 3 - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_excludes_edge_collection(): - """Test that EdgeType_relationship_name collection is excluded from ID extraction.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [MockScoredResult("node1", 0.95)] - elif collection_name == "EdgeType_relationship_name": - return [MockScoredResult("edge1", 0.88)] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search( - query="test", - node_name=None, - collections=["Entity_name", "EdgeType_relationship_name"], - ) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert 
call_kwargs["relevant_ids_to_filter"] == ["node1"] - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_skips_nodes_without_ids(): - """Test that nodes without ID attribute are skipped.""" - - class ScoredResultNoId: - """Mock result without id attribute.""" - - def __init__(self, score): - self.score = score - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [ - MockScoredResult("node1", 0.95), - ScoredResultNoId(0.90), - MockScoredResult("node2", 0.87), - ] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_handles_tuple_results(): - """Test that both list and tuple results are handled correctly.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return ( - MockScoredResult("node1", 0.95), - MockScoredResult("node2", 0.87), - ) - else: - return [] - - mock_vector_engine = AsyncMock() - 
mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} - - -@pytest.mark.asyncio -async def test_brute_force_triplet_search_mixed_empty_collections(): - """Test ID extraction with mixed empty and non-empty collections.""" - - def search_side_effect(*args, **kwargs): - collection_name = kwargs.get("collection_name") - if collection_name == "Entity_name": - return [MockScoredResult("node1", 0.95)] - elif collection_name == "TextSummary_text": - return [] - elif collection_name == "EntityType_name": - return [MockScoredResult("node2", 0.92)] - else: - return [] - - mock_vector_engine = AsyncMock() - mock_vector_engine.embedding_engine = AsyncMock() - mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) - mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) - - mock_fragment = AsyncMock( - map_vector_distances_to_graph_nodes=AsyncMock(), - map_vector_distances_to_graph_edges=AsyncMock(), - calculate_top_triplet_importances=AsyncMock(return_value=[]), - ) - - with ( - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", - 
return_value=mock_vector_engine, - ), - patch( - "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", - return_value=mock_fragment, - ) as mock_get_fragment_fn, - ): - await brute_force_triplet_search(query="test", node_name=None) - - call_kwargs = mock_get_fragment_fn.call_args[1] - assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} diff --git a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py deleted file mode 100644 index d79aca428..000000000 --- a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch, MagicMock - -from cognee.modules.retrieval.triplet_retriever import TripletRetriever -from cognee.modules.retrieval.exceptions.exceptions import NoDataError -from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError - - -@pytest.fixture -def mock_vector_engine(): - """Create a mock vector engine.""" - engine = AsyncMock() - engine.has_collection = AsyncMock(return_value=True) - engine.search = AsyncMock() - return engine - - -@pytest.mark.asyncio -async def test_get_context_success(mock_vector_engine): - """Test successful retrieval of triplet context.""" - mock_result1 = MagicMock() - mock_result1.payload = {"text": "Alice knows Bob"} - mock_result2 = MagicMock() - mock_result2.payload = {"text": "Bob works at Tech Corp"} - - mock_vector_engine.search.return_value = [mock_result1, mock_result2] - - retriever = TripletRetriever(top_k=5) - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == "Alice knows Bob\nBob works at Tech Corp" - mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5) - - -@pytest.mark.asyncio -async def 
test_get_context_no_collection(mock_vector_engine): - """Test that NoDataError is raised when Triplet_text collection doesn't exist.""" - mock_vector_engine.has_collection.return_value = False - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - with pytest.raises(NoDataError, match="create_triplet_embeddings"): - await retriever.get_context("test query") - - -@pytest.mark.asyncio -async def test_get_context_empty_results(mock_vector_engine): - """Test that empty string is returned when no triplets are found.""" - mock_vector_engine.search.return_value = [] - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - context = await retriever.get_context("test query") - - assert context == "" - - -@pytest.mark.asyncio -async def test_get_context_collection_not_found_error(mock_vector_engine): - """Test that CollectionNotFoundError is converted to NoDataError.""" - mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found") - - retriever = TripletRetriever() - - with patch( - "cognee.modules.retrieval.triplet_retriever.get_vector_engine", - return_value=mock_vector_engine, - ): - with pytest.raises(NoDataError, match="No data found"): - await retriever.get_context("test query") From b4aaa7faefce804d9ad6fee93d9907b352206f25 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 11:59:33 +0100 Subject: [PATCH 058/176] chore: retriever test reorganization + adding new tests (smoke e2e) (STEP 1.5) (#1888) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR restructures the end-to-end tests for the multi-database search layer to improve maintainability, consistency, and coverage across supported Python versions and database settings. 
Key Changes -Migrates the existing E2E tests to pytest for a more standard and extensible testing framework. -Introduces pytest fixtures to centralize and reuse test setup logic. -Implements proper event loop management to support multiple asynchronous pytest tests reliably. -Improves SQLAlchemy handling in tests, ensuring clean setup and teardown of database state. -Extends multi-database E2E test coverage across all supported Python versions. Benefits -Cleaner and more modular test structure. -Reduced duplication and clearer test intent through fixtures. -More reliable async test execution. -Better alignment with our supported Python version matrix. ## Type of Change - [ ] Bug fix (non-breaking change that fixes an issue) - [x] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) ## Pre-submission Checklist - [x] **I have tested my changes thoroughly before submitting this PR** - [x] **This PR contains minimal changes necessary to address the issue/feature** - [x] My code follows the project's coding standards and style guidelines - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have added necessary documentation (if applicable) - [x] All new and existing tests pass - [x] I have searched existing PRs to ensure this change hasn't been submitted already - [x] I have linked any relevant issues in the description - [x] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. 
## Summary by CodeRabbit * **Tests** * Expanded end-to-end test suite for the search database with comprehensive setup/teardown, new session-scoped fixtures, and multiple tests validating graph/vector consistency, retriever contexts, triplet metadata, search result shapes, side effects, and feedback-weight behavior. * **Chores** * CI updated to run matrixed test jobs across multiple Python versions and standardize test execution for more consistent, parallelized runs. ✏️ Tip: You can customize this high-level summary in your review settings. --- .github/workflows/search_db_tests.yml | 46 ++- cognee/tests/test_search_db.py | 529 +++++++++++++++++--------- 2 files changed, 374 insertions(+), 201 deletions(-) diff --git a/.github/workflows/search_db_tests.yml b/.github/workflows/search_db_tests.yml index 118c1c06c..f0c7817cd 100644 --- a/.github/workflows/search_db_tests.yml +++ b/.github/workflows/search_db_tests.yml @@ -11,12 +11,21 @@ on: type: string default: "all" description: "Which vector databases to test (comma-separated list or 'all')" + python-versions: + required: false + type: string + default: '["3.10", "3.11", "3.12", "3.13"]' + description: "Python versions to test (JSON array)" jobs: run-kuzu-lance-sqlite-search-tests: - name: Search test for Kuzu/LanceDB/Sqlite + name: Search test for Kuzu/LanceDB/Sqlite (Python ${{ matrix.python-version }}) runs-on: ubuntu-22.04 if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/lance/sqlite') }} + strategy: + matrix: + python-version: ${{ fromJSON(inputs.python-versions) }} + fail-fast: false steps: - name: Check out uses: actions/checkout@v4 @@ -26,7 +35,7 @@ jobs: - name: Cognee Setup uses: ./.github/actions/cognee_setup with: - python-version: ${{ inputs.python-version }} + python-version: ${{ matrix.python-version }} - name: Dependencies already installed run: echo "Dependencies already installed in setup" @@ -45,13 +54,16 @@ jobs: GRAPH_DATABASE_PROVIDER: 'kuzu' VECTOR_DB_PROVIDER: 'lancedb' 
DB_PROVIDER: 'sqlite' - run: uv run python ./cognee/tests/test_search_db.py + run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO run-neo4j-lance-sqlite-search-tests: - name: Search test for Neo4j/LanceDB/Sqlite + name: Search test for Neo4j/LanceDB/Sqlite (Python ${{ matrix.python-version }}) runs-on: ubuntu-22.04 if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/lance/sqlite') }} - + strategy: + matrix: + python-version: ${{ fromJSON(inputs.python-versions) }} + fail-fast: false steps: - name: Check out uses: actions/checkout@v4 @@ -61,7 +73,7 @@ jobs: - name: Cognee Setup uses: ./.github/actions/cognee_setup with: - python-version: ${{ inputs.python-version }} + python-version: ${{ matrix.python-version }} - name: Setup Neo4j with GDS uses: ./.github/actions/setup_neo4j @@ -88,12 +100,16 @@ jobs: GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }} GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }} GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }} - run: uv run python ./cognee/tests/test_search_db.py + run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO run-kuzu-pgvector-postgres-search-tests: - name: Search test for Kuzu/PGVector/Postgres + name: Search test for Kuzu/PGVector/Postgres (Python ${{ matrix.python-version }}) runs-on: ubuntu-22.04 if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'kuzu/pgvector/postgres') }} + strategy: + matrix: + python-version: ${{ fromJSON(inputs.python-versions) }} + fail-fast: false services: postgres: image: pgvector/pgvector:pg17 @@ -117,7 +133,7 @@ jobs: - name: Cognee Setup uses: ./.github/actions/cognee_setup with: - python-version: ${{ inputs.python-version }} + python-version: ${{ matrix.python-version }} extra-dependencies: "postgres" - name: Dependencies already installed @@ -143,12 +159,16 @@ jobs: DB_PORT: 5432 DB_USERNAME: cognee DB_PASSWORD: cognee - run: uv run python ./cognee/tests/test_search_db.py + 
run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO run-neo4j-pgvector-postgres-search-tests: - name: Search test for Neo4j/PGVector/Postgres + name: Search test for Neo4j/PGVector/Postgres (Python ${{ matrix.python-version }}) runs-on: ubuntu-22.04 if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'neo4j/pgvector/postgres') }} + strategy: + matrix: + python-version: ${{ fromJSON(inputs.python-versions) }} + fail-fast: false services: postgres: image: pgvector/pgvector:pg17 @@ -172,7 +192,7 @@ jobs: - name: Cognee Setup uses: ./.github/actions/cognee_setup with: - python-version: ${{ inputs.python-version }} + python-version: ${{ matrix.python-version }} extra-dependencies: "postgres" - name: Setup Neo4j with GDS @@ -205,4 +225,4 @@ jobs: DB_PORT: 5432 DB_USERNAME: cognee DB_PASSWORD: cognee - run: uv run python ./cognee/tests/test_search_db.py + run: uv run pytest cognee/tests/test_search_db.py -v --log-level=INFO diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index ba150f813..0916be322 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -1,5 +1,10 @@ import pathlib import os +import asyncio +import pytest +import pytest_asyncio +from collections import Counter + import cognee from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector import get_vector_engine @@ -13,127 +18,172 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet from cognee.modules.retrieval.graph_summary_completion_retriever import ( GraphSummaryCompletionRetriever, ) +from cognee.modules.retrieval.chunks_retriever import ChunksRetriever +from cognee.modules.retrieval.summaries_retriever import SummariesRetriever +from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever from cognee.modules.retrieval.triplet_retriever import 
TripletRetriever from cognee.shared.logging_utils import get_logger from cognee.modules.search.types import SearchType from cognee.modules.users.methods import get_default_user -from collections import Counter logger = get_logger() -async def main(): - # This test runs for multiple db settings, to run this locally set the corresponding db envs +async def _reset_engines_and_prune() -> None: + """Reset db engine caches and prune data/system. + + Kept intentionally identical to the inlined setup logic to avoid event loop issues when + using deployed databases (Neo4j, PostgreSQL) and to ensure fresh instances per run. + """ + # Dispose of existing engines and clear caches to ensure fresh instances for each test + try: + from cognee.infrastructure.databases.vector import get_vector_engine + + vector_engine = get_vector_engine() + # Dispose SQLAlchemy engine connection pool if it exists + if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"): + await vector_engine.engine.dispose(close=True) + except Exception: + # Engine might not exist yet + pass + + from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine + from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine + from cognee.infrastructure.databases.relational.create_relational_engine import ( + create_relational_engine, + ) + + create_graph_engine.cache_clear() + create_vector_engine.cache_clear() + create_relational_engine.cache_clear() + await cognee.prune.prune_data() await cognee.prune.prune_system(metadata=True) - dataset_name = "test_dataset" +async def _seed_default_dataset(dataset_name: str) -> dict: + """Add the shared test dataset contents and run cognify (same steps/order as before).""" text_1 = """Germany is located in europe right next to the Netherlands""" + + logger.info(f"Adding text data to dataset: {dataset_name}") await cognee.add(text_1, dataset_name) explanation_file_path_quantum = os.path.join( 
pathlib.Path(__file__).parent, "test_data/Quantum_computers.txt" ) + logger.info(f"Adding file data to dataset: {dataset_name}") await cognee.add([explanation_file_path_quantum], dataset_name) + logger.info(f"Running cognify on dataset: {dataset_name}") await cognee.cognify([dataset_name]) + return { + "dataset_name": dataset_name, + "text_1": text_1, + "explanation_file_path_quantum": explanation_file_path_quantum, + } + + +@pytest.fixture(scope="session") +def event_loop(): + """Use a single asyncio event loop for this test module. + + This helps avoid "Future attached to a different loop" when running multiple async + tests that share clients/engines. + """ + loop = asyncio.new_event_loop() + try: + yield loop + finally: + loop.close() + + +async def setup_test_environment(): + """Helper function to set up test environment with data, cognify, and triplet embeddings.""" + # This test runs for multiple db settings, to run this locally set the corresponding db envs + + dataset_name = "test_dataset" + logger.info("Starting test setup: pruning data and system") + await _reset_engines_and_prune() + state = await _seed_default_dataset(dataset_name=dataset_name) + user = await get_default_user() from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings + logger.info("Creating triplet embeddings") await create_triplet_embeddings(user=user, dataset=dataset_name, triplets_batch_size=5) + # Check if Triplet_text collection was created + vector_engine = get_vector_engine() + has_collection = await vector_engine.has_collection(collection_name="Triplet_text") + logger.info(f"Triplet_text collection exists after creation: {has_collection}") + + if has_collection: + collection = await vector_engine.get_collection("Triplet_text") + count = await collection.count_rows() if hasattr(collection, "count_rows") else "unknown" + logger.info(f"Triplet_text collection row count: {count}") + + return state + + +async def setup_test_environment_for_feedback(): 
+ """Helper function to set up test environment for feedback weight calculation test.""" + dataset_name = "test_dataset" + await _reset_engines_and_prune() + return await _seed_default_dataset(dataset_name=dataset_name) + + +@pytest_asyncio.fixture(scope="session") +async def e2e_state(): + """Compute E2E artifacts once; tests only assert. + + This avoids repeating expensive setup and LLM calls across multiple tests. + """ + await setup_test_environment() + + # --- Graph/vector engine consistency --- graph_engine = await get_graph_engine() - nodes, edges = await graph_engine.get_graph_data() + _nodes, edges = await graph_engine.get_graph_data() vector_engine = get_vector_engine() collection = await vector_engine.search( - query_text="Test", limit=None, collection_name="Triplet_text" + collection_name="Triplet_text", query_text="Test", limit=None ) - assert len(edges) == len(collection), ( - f"Expected {len(edges)} edges but got {len(collection)} in Triplet_text collection" - ) + # --- Retriever contexts --- + query = "Next to which country is Germany located?" - context_gk = await GraphCompletionRetriever().get_context( - query="Next to which country is Germany located?" - ) - context_gk_cot = await GraphCompletionCotRetriever().get_context( - query="Next to which country is Germany located?" - ) - context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context( - query="Next to which country is Germany located?" - ) - context_gk_sum = await GraphSummaryCompletionRetriever().get_context( - query="Next to which country is Germany located?" - ) - context_triplet = await TripletRetriever().get_context( - query="Next to which country is Germany located?" 
- ) + contexts = { + "graph_completion": await GraphCompletionRetriever().get_context(query=query), + "graph_completion_cot": await GraphCompletionCotRetriever().get_context(query=query), + "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_context( + query=query + ), + "graph_summary_completion": await GraphSummaryCompletionRetriever().get_context( + query=query + ), + "chunks": await ChunksRetriever(top_k=5).get_context(query=query), + "summaries": await SummariesRetriever(top_k=5).get_context(query=query), + "rag_completion": await CompletionRetriever(top_k=3).get_context(query=query), + "temporal": await TemporalRetriever(top_k=5).get_context(query=query), + "triplet": await TripletRetriever().get_context(query=query), + } - for name, context in [ - ("GraphCompletionRetriever", context_gk), - ("GraphCompletionCotRetriever", context_gk_cot), - ("GraphCompletionContextExtensionRetriever", context_gk_ext), - ("GraphSummaryCompletionRetriever", context_gk_sum), - ]: - assert isinstance(context, list), f"{name}: Context should be a list" - assert len(context) > 0, f"{name}: Context should not be empty" - - context_text = await resolve_edges_to_text(context) - lower = context_text.lower() - assert "germany" in lower or "netherlands" in lower, ( - f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}" - ) - - assert isinstance(context_triplet, str), "TripletRetriever: Context should be a string" - assert len(context_triplet) > 0, "TripletRetriever: Context should not be empty" - lower_triplet = context_triplet.lower() - assert "germany" in lower_triplet or "netherlands" in lower_triplet, ( - f"TripletRetriever: Context did not contain 'germany' or 'netherlands'; got: {context_triplet!r}" - ) - - triplets_gk = await GraphCompletionRetriever().get_triplets( - query="Next to which country is Germany located?" 
- ) - triplets_gk_cot = await GraphCompletionCotRetriever().get_triplets( - query="Next to which country is Germany located?" - ) - triplets_gk_ext = await GraphCompletionContextExtensionRetriever().get_triplets( - query="Next to which country is Germany located?" - ) - triplets_gk_sum = await GraphSummaryCompletionRetriever().get_triplets( - query="Next to which country is Germany located?" - ) - - for name, triplets in [ - ("GraphCompletionRetriever", triplets_gk), - ("GraphCompletionCotRetriever", triplets_gk_cot), - ("GraphCompletionContextExtensionRetriever", triplets_gk_ext), - ("GraphSummaryCompletionRetriever", triplets_gk_sum), - ]: - assert isinstance(triplets, list), f"{name}: Triplets should be a list" - assert triplets, f"{name}: Triplets list should not be empty" - for edge in triplets: - assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances" - distance = edge.attributes.get("vector_distance") - node1_distance = edge.node1.attributes.get("vector_distance") - node2_distance = edge.node2.attributes.get("vector_distance") - assert isinstance(distance, float), ( - f"{name}: vector_distance should be float, got {type(distance)}" - ) - assert 0 <= distance <= 1, ( - f"{name}: edge vector_distance {distance} out of [0,1], this shouldn't happen" - ) - assert 0 <= node1_distance <= 1, ( - f"{name}: node_1 vector_distance {distance} out of [0,1], this shouldn't happen" - ) - assert 0 <= node2_distance <= 1, ( - f"{name}: node_2 vector_distance {distance} out of [0,1], this shouldn't happen" - ) + # --- Retriever triplets + vector distance validation --- + triplets = { + "graph_completion": await GraphCompletionRetriever().get_triplets(query=query), + "graph_completion_cot": await GraphCompletionCotRetriever().get_triplets(query=query), + "graph_completion_context_extension": await GraphCompletionContextExtensionRetriever().get_triplets( + query=query + ), + "graph_summary_completion": await GraphSummaryCompletionRetriever().get_triplets( + 
query=query + ), + } + # --- Search operations + graph side effects --- completion_gk = await cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text="Where is germany located, next to which country?", @@ -164,6 +214,26 @@ async def main(): query_text="Next to which country is Germany located?", save_interaction=True, ) + completion_chunks = await cognee.search( + query_type=SearchType.CHUNKS, + query_text="Germany", + save_interaction=False, + ) + completion_summaries = await cognee.search( + query_type=SearchType.SUMMARIES, + query_text="Germany", + save_interaction=False, + ) + completion_rag = await cognee.search( + query_type=SearchType.RAG_COMPLETION, + query_text="Next to which country is Germany located?", + save_interaction=False, + ) + completion_temporal = await cognee.search( + query_type=SearchType.TEMPORAL, + query_text="Next to which country is Germany located?", + save_interaction=False, + ) await cognee.search( query_type=SearchType.FEEDBACK, @@ -171,134 +241,217 @@ async def main(): last_k=1, ) - for name, search_results in [ - ("GRAPH_COMPLETION", completion_gk), - ("GRAPH_COMPLETION_COT", completion_cot), - ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext), - ("GRAPH_SUMMARY_COMPLETION", completion_sum), - ("TRIPLET_COMPLETION", completion_triplet), - ]: - assert isinstance(search_results, list), f"{name}: should return a list" - assert len(search_results) == 1, ( - f"{name}: expected single-element list, got {len(search_results)}" - ) + # Snapshot after all E2E operations above (used by assertion-only tests). 
+ graph_snapshot = await (await get_graph_engine()).get_graph_data() - from cognee.context_global_variables import backend_access_control_enabled + return { + "graph_edges": edges, + "triplet_collection": collection, + "vector_collection_edges_count": len(collection), + "graph_edges_count": len(edges), + "contexts": contexts, + "triplets": triplets, + "search_results": { + "graph_completion": completion_gk, + "graph_completion_cot": completion_cot, + "graph_completion_context_extension": completion_ext, + "graph_summary_completion": completion_sum, + "triplet_completion": completion_triplet, + "chunks": completion_chunks, + "summaries": completion_summaries, + "rag_completion": completion_rag, + "temporal": completion_temporal, + }, + "graph_snapshot": graph_snapshot, + } - if backend_access_control_enabled(): - text = search_results[0]["search_result"][0] - else: - text = search_results[0] - assert isinstance(text, str), f"{name}: element should be a string" - assert text.strip(), f"{name}: string should not be empty" - assert "netherlands" in text.lower(), ( - f"{name}: expected 'netherlands' in result, got: {text!r}" - ) - graph_engine = await get_graph_engine() - graph = await graph_engine.get_graph_data() - - type_counts = Counter(node_data[1].get("type", {}) for node_data in graph[0]) - - edge_type_counts = Counter(edge_type[2] for edge_type in graph[1]) - - # Assert there are exactly 4 CogneeUserInteraction nodes. - assert type_counts.get("CogneeUserInteraction", 0) == 4, ( - f"Expected exactly four CogneeUserInteraction nodes, but found {type_counts.get('CogneeUserInteraction', 0)}" - ) - - # Assert there is exactly two CogneeUserFeedback nodes. - assert type_counts.get("CogneeUserFeedback", 0) == 2, ( - f"Expected exactly two CogneeUserFeedback nodes, but found {type_counts.get('CogneeUserFeedback', 0)}" - ) - - # Assert there is exactly two NodeSet. 
- assert type_counts.get("NodeSet", 0) == 2, ( - f"Expected exactly two NodeSet nodes, but found {type_counts.get('NodeSet', 0)}" - ) - - # Assert that there are at least 10 'used_graph_element_to_answer' edges. - assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10, ( - f"Expected at least ten 'used_graph_element_to_answer' edges, but found {edge_type_counts.get('used_graph_element_to_answer', 0)}" - ) - - # Assert that there are exactly 2 'gives_feedback_to' edges. - assert edge_type_counts.get("gives_feedback_to", 0) == 2, ( - f"Expected exactly two 'gives_feedback_to' edges, but found {edge_type_counts.get('gives_feedback_to', 0)}" - ) - - # Assert that there are at least 6 'belongs_to_set' edges. - assert edge_type_counts.get("belongs_to_set", 0) == 6, ( - f"Expected at least six 'belongs_to_set' edges, but found {edge_type_counts.get('belongs_to_set', 0)}" - ) - - nodes = graph[0] - - required_fields_user_interaction = {"question", "answer", "context"} - required_fields_feedback = {"feedback", "sentiment"} - - for node_id, data in nodes: - if data.get("type") == "CogneeUserInteraction": - assert required_fields_user_interaction.issubset(data.keys()), ( - f"Node {node_id} is missing fields: {required_fields_user_interaction - set(data.keys())}" - ) - - for field in required_fields_user_interaction: - value = data[field] - assert isinstance(value, str) and value.strip(), ( - f"Node {node_id} has invalid value for '{field}': {value!r}" - ) - - if data.get("type") == "CogneeUserFeedback": - assert required_fields_feedback.issubset(data.keys()), ( - f"Node {node_id} is missing fields: {required_fields_feedback - set(data.keys())}" - ) - - for field in required_fields_feedback: - value = data[field] - assert isinstance(value, str) and value.strip(), ( - f"Node {node_id} has invalid value for '{field}': {value!r}" - ) - - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - - await cognee.add(text_1, dataset_name) - - 
await cognee.add([text], dataset_name) - - await cognee.cognify([dataset_name]) +@pytest_asyncio.fixture(scope="session") +async def feedback_state(): + """Feedback-weight scenario computed once (fresh environment).""" + await setup_test_environment_for_feedback() await cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text="Next to which country is Germany located?", save_interaction=True, ) - await cognee.search( query_type=SearchType.FEEDBACK, query_text="This was the best answer I've ever seen", last_k=1, ) - await cognee.search( query_type=SearchType.FEEDBACK, query_text="Wow the correctness of this answer blows my mind", last_k=1, ) + graph_engine = await get_graph_engine() graph = await graph_engine.get_graph_data() + return {"graph_snapshot": graph} - edges = graph[1] - for from_node, to_node, relationship_name, properties in edges: +@pytest.mark.asyncio +async def test_e2e_graph_vector_consistency(e2e_state): + """Graph and vector stores contain the same triplet edges.""" + assert e2e_state["graph_edges_count"] == e2e_state["vector_collection_edges_count"] + + +@pytest.mark.asyncio +async def test_e2e_retriever_contexts(e2e_state): + """All retrievers return non-empty, well-typed contexts.""" + contexts = e2e_state["contexts"] + + for name in [ + "graph_completion", + "graph_completion_cot", + "graph_completion_context_extension", + "graph_summary_completion", + ]: + ctx = contexts[name] + assert isinstance(ctx, list), f"{name}: Context should be a list" + assert ctx, f"{name}: Context should not be empty" + ctx_text = await resolve_edges_to_text(ctx) + lower = ctx_text.lower() + assert "germany" in lower or "netherlands" in lower, ( + f"{name}: Context did not contain 'germany' or 'netherlands'; got: {ctx!r}" + ) + + triplet_ctx = contexts["triplet"] + assert isinstance(triplet_ctx, str), "triplet: Context should be a string" + assert triplet_ctx.strip(), "triplet: Context should not be empty" + + chunks_ctx = contexts["chunks"] + assert 
isinstance(chunks_ctx, list), "chunks: Context should be a list" + assert chunks_ctx, "chunks: Context should not be empty" + chunks_text = "\n".join(str(item.get("text", "")) for item in chunks_ctx).lower() + assert "germany" in chunks_text or "netherlands" in chunks_text + + summaries_ctx = contexts["summaries"] + assert isinstance(summaries_ctx, list), "summaries: Context should be a list" + assert summaries_ctx, "summaries: Context should not be empty" + assert any(str(item.get("text", "")).strip() for item in summaries_ctx) + + rag_ctx = contexts["rag_completion"] + assert isinstance(rag_ctx, str), "rag_completion: Context should be a string" + assert rag_ctx.strip(), "rag_completion: Context should not be empty" + + temporal_ctx = contexts["temporal"] + assert isinstance(temporal_ctx, str), "temporal: Context should be a string" + assert temporal_ctx.strip(), "temporal: Context should not be empty" + + +@pytest.mark.asyncio +async def test_e2e_retriever_triplets_have_vector_distances(e2e_state): + """Graph retriever triplets include sane vector_distance metadata.""" + for name, triplets in e2e_state["triplets"].items(): + assert isinstance(triplets, list), f"{name}: Triplets should be a list" + assert triplets, f"{name}: Triplets list should not be empty" + for edge in triplets: + assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances" + distance = edge.attributes.get("vector_distance") + node1_distance = edge.node1.attributes.get("vector_distance") + node2_distance = edge.node2.attributes.get("vector_distance") + assert isinstance(distance, float), f"{name}: vector_distance should be float" + assert 0 <= distance <= 1 + assert 0 <= node1_distance <= 1 + assert 0 <= node2_distance <= 1 + + +@pytest.mark.asyncio +async def test_e2e_search_results_and_wrappers(e2e_state): + """Search returns expected shapes across search types and access modes.""" + from cognee.context_global_variables import backend_access_control_enabled + + sr = 
e2e_state["search_results"] + + # Completion-like search types: validate wrapper + content + for name in [ + "graph_completion", + "graph_completion_cot", + "graph_completion_context_extension", + "graph_summary_completion", + "triplet_completion", + "rag_completion", + "temporal", + ]: + search_results = sr[name] + assert isinstance(search_results, list), f"{name}: should return a list" + assert len(search_results) == 1, f"{name}: expected single-element list" + + if backend_access_control_enabled(): + wrapper = search_results[0] + assert isinstance(wrapper, dict), ( + f"{name}: expected wrapper dict in access control mode" + ) + assert wrapper.get("dataset_id"), f"{name}: missing dataset_id in wrapper" + assert wrapper.get("dataset_name") == "test_dataset" + assert "graphs" in wrapper + text = wrapper["search_result"][0] + else: + text = search_results[0] + + assert isinstance(text, str) and text.strip() + assert "netherlands" in text.lower() + + # Non-LLM search types: CHUNKS / SUMMARIES validate payload list + text + for name in ["chunks", "summaries"]: + search_results = sr[name] + assert isinstance(search_results, list), f"{name}: should return a list" + assert search_results, f"{name}: should not be empty" + + first = search_results[0] + assert isinstance(first, dict), f"{name}: expected dict entries" + + payloads = search_results + if "search_result" in first and "text" not in first: + payloads = (first.get("search_result") or [None])[0] + + assert isinstance(payloads, list) and payloads + assert isinstance(payloads[0], dict) + assert str(payloads[0].get("text", "")).strip() + + +@pytest.mark.asyncio +async def test_e2e_graph_side_effects_and_node_fields(e2e_state): + """Search interactions create expected graph nodes/edges and required fields.""" + graph = e2e_state["graph_snapshot"] + nodes, edges = graph + + type_counts = Counter(node_data[1].get("type", {}) for node_data in nodes) + edge_type_counts = Counter(edge_type[2] for edge_type in edges) + + 
assert type_counts.get("CogneeUserInteraction", 0) == 4 + assert type_counts.get("CogneeUserFeedback", 0) == 2 + assert type_counts.get("NodeSet", 0) == 2 + assert edge_type_counts.get("used_graph_element_to_answer", 0) >= 10 + assert edge_type_counts.get("gives_feedback_to", 0) == 2 + assert edge_type_counts.get("belongs_to_set", 0) >= 6 + + required_fields_user_interaction = {"question", "answer", "context"} + required_fields_feedback = {"feedback", "sentiment"} + + for node_id, data in nodes: + if data.get("type") == "CogneeUserInteraction": + assert required_fields_user_interaction.issubset(data.keys()) + for field in required_fields_user_interaction: + value = data[field] + assert isinstance(value, str) and value.strip() + + if data.get("type") == "CogneeUserFeedback": + assert required_fields_feedback.issubset(data.keys()) + for field in required_fields_feedback: + value = data[field] + assert isinstance(value, str) and value.strip() + + +@pytest.mark.asyncio +async def test_e2e_feedback_weight_calculation(feedback_state): + """Positive feedback increases used_graph_element_to_answer feedback_weight.""" + _nodes, edges = feedback_state["graph_snapshot"] + for _from_node, _to_node, relationship_name, properties in edges: if relationship_name == "used_graph_element_to_answer": assert properties["feedback_weight"] >= 6, ( "Feedback weight calculation is not correct, it should be more then 6." 
) - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main()) From 21407dd9ed593c476b0234e2f2ef6a3112433cb3 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 16 Dec 2025 12:03:49 +0100 Subject: [PATCH 059/176] test: Resolve silent mcp test --- cognee-mcp/src/test_client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cognee-mcp/src/test_client.py b/cognee-mcp/src/test_client.py index bce7f807f..e3329ba91 100755 --- a/cognee-mcp/src/test_client.py +++ b/cognee-mcp/src/test_client.py @@ -627,8 +627,7 @@ class TestModel: print(f"Failed: {failed}") print(f"Success Rate: {(passed / total_tests * 100):.1f}%") - if failed > 0: - print(f"\n ⚠️ {failed} test(s) failed - review results above for details") + assert failed == 0, f"\n ⚠️ {failed} test(s) failed - review results above for details" async def main(): From c61ff60e40eedcf892b4ed6a08621aa2a7adfcc4 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 15:37:33 +0100 Subject: [PATCH 060/176] feat: add unit tests for get_search_type_tools --- .../search/test_get_search_type_tools.py | 221 ++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 cognee/tests/unit/modules/search/test_get_search_type_tools.py diff --git a/cognee/tests/unit/modules/search/test_get_search_type_tools.py b/cognee/tests/unit/modules/search/test_get_search_type_tools.py new file mode 100644 index 000000000..15b489bfa --- /dev/null +++ b/cognee/tests/unit/modules/search/test_get_search_type_tools.py @@ -0,0 +1,221 @@ +import pytest + +from cognee.modules.search.exceptions import UnsupportedSearchTypeError +from cognee.modules.search.types import SearchType + + +class _DummyCommunityRetriever: + def __init__(self, *args, **kwargs): + self.kwargs = kwargs + + def get_completion(self, *args, **kwargs): + return {"kind": "completion", "init": self.kwargs, "args": args, "kwargs": kwargs} + + def get_context(self, *args, **kwargs): + return 
{"kind": "context", "init": self.kwargs, "args": args, "kwargs": kwargs} + + +@pytest.mark.asyncio +async def test_feeling_lucky_delegates_to_select_search_type(monkeypatch): + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval.chunks_retriever import ChunksRetriever + + async def _fake_select_search_type(query_text: str): + assert query_text == "hello" + return SearchType.CHUNKS + + monkeypatch.setattr(mod, "select_search_type", _fake_select_search_type) + + tools = await mod.get_search_type_tools(SearchType.FEELING_LUCKY, query_text="hello") + + assert len(tools) == 2 + assert all(callable(t) for t in tools) + assert tools[0].__name__ == "get_completion" + assert tools[1].__name__ == "get_context" + assert tools[0].__self__.__class__ is ChunksRetriever + assert tools[1].__self__.__class__ is ChunksRetriever + + +@pytest.mark.asyncio +async def test_disallowed_cypher_search_types_raise(monkeypatch): + import cognee.modules.search.methods.get_search_type_tools as mod + + monkeypatch.setenv("ALLOW_CYPHER_QUERY", "false") + + with pytest.raises(UnsupportedSearchTypeError, match="disabled"): + await mod.get_search_type_tools(SearchType.CYPHER, query_text="MATCH (n) RETURN n") + + with pytest.raises(UnsupportedSearchTypeError, match="disabled"): + await mod.get_search_type_tools(SearchType.NATURAL_LANGUAGE, query_text="Find nodes") + + +@pytest.mark.asyncio +async def test_allowed_cypher_search_types_return_tools(monkeypatch): + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever + + monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true") + + tools = await mod.get_search_type_tools(SearchType.CYPHER, query_text="q") + assert len(tools) == 2 + assert tools[0].__name__ == "get_completion" + assert tools[1].__name__ == "get_context" + assert tools[0].__self__.__class__ is CypherSearchRetriever + assert tools[1].__self__.__class__ is 
CypherSearchRetriever + + +@pytest.mark.asyncio +async def test_registered_community_retriever_is_used(monkeypatch): + """ + Integration point: community retrievers are loaded from the registry module and should + override the default mapping when present. + """ + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval import registered_community_retrievers as registry + + monkeypatch.setattr( + registry, + "registered_community_retrievers", + {SearchType.SUMMARIES: _DummyCommunityRetriever}, + ) + + tools = await mod.get_search_type_tools(SearchType.SUMMARIES, query_text="q", top_k=7) + + assert len(tools) == 2 + assert tools[0].__self__.__class__ is _DummyCommunityRetriever + assert tools[0].__self__.kwargs["top_k"] == 7 + assert tools[1].__self__.__class__ is _DummyCommunityRetriever + assert tools[1].__self__.kwargs["top_k"] == 7 + + +@pytest.mark.asyncio +async def test_unknown_query_type_raises_unsupported(): + import cognee.modules.search.methods.get_search_type_tools as mod + + with pytest.raises(UnsupportedSearchTypeError, match="UNKNOWN_TYPE"): + await mod.get_search_type_tools("UNKNOWN_TYPE", query_text="q") # type: ignore[arg-type] + + +@pytest.mark.asyncio +async def test_default_mapping_passes_top_k_to_retrievers(): + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval.summaries_retriever import SummariesRetriever + + tools = await mod.get_search_type_tools(SearchType.SUMMARIES, query_text="q", top_k=4) + assert len(tools) == 2 + assert tools[0].__self__.__class__ is SummariesRetriever + assert tools[1].__self__.__class__ is SummariesRetriever + assert tools[0].__self__.top_k == 4 + assert tools[1].__self__.top_k == 4 + + +@pytest.mark.asyncio +async def test_chunks_lexical_returns_jaccard_tools(): + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval.jaccard_retrival import JaccardChunksRetriever + + tools = await 
mod.get_search_type_tools(SearchType.CHUNKS_LEXICAL, query_text="q", top_k=3) + assert len(tools) == 2 + assert tools[0].__self__.__class__ is JaccardChunksRetriever + assert tools[1].__self__.__class__ is JaccardChunksRetriever + assert tools[0].__self__ is tools[1].__self__ + + +@pytest.mark.asyncio +async def test_coding_rules_uses_node_name_as_rules_nodeset_name(): + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever + + tools = await mod.get_search_type_tools(SearchType.CODING_RULES, query_text="q", node_name=[]) + assert len(tools) == 1 + assert tools[0].__name__ == "get_existing_rules" + assert tools[0].__self__.__class__ is CodingRulesRetriever + # Empty list should default to ["coding_agent_rules"] + assert tools[0].__self__.rules_nodeset_name == ["coding_agent_rules"] + + +@pytest.mark.asyncio +async def test_feedback_uses_last_k(): + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback + + tools = await mod.get_search_type_tools(SearchType.FEEDBACK, query_text="q", last_k=11) + assert len(tools) == 1 + assert tools[0].__name__ == "add_feedback" + assert tools[0].__self__.__class__ is UserQAFeedback + assert tools[0].__self__.last_k == 11 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "query_type, expected_class_name, expected_method_names", + [ + (SearchType.CHUNKS, "ChunksRetriever", ("get_completion", "get_context")), + (SearchType.RAG_COMPLETION, "CompletionRetriever", ("get_completion", "get_context")), + (SearchType.TRIPLET_COMPLETION, "TripletRetriever", ("get_completion", "get_context")), + ( + SearchType.GRAPH_COMPLETION, + "GraphCompletionRetriever", + ("get_completion", "get_context"), + ), + ( + SearchType.GRAPH_COMPLETION_COT, + "GraphCompletionCotRetriever", + ("get_completion", "get_context"), + ), + ( + SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION, + 
"GraphCompletionContextExtensionRetriever", + ("get_completion", "get_context"), + ), + ( + SearchType.GRAPH_SUMMARY_COMPLETION, + "GraphSummaryCompletionRetriever", + ("get_completion", "get_context"), + ), + (SearchType.TEMPORAL, "TemporalRetriever", ("get_completion", "get_context")), + ( + SearchType.NATURAL_LANGUAGE, + "NaturalLanguageRetriever", + ("get_completion", "get_context"), + ), + ], +) +async def test_tool_construction_for_supported_search_types( + monkeypatch, query_type, expected_class_name, expected_method_names +): + import cognee.modules.search.methods.get_search_type_tools as mod + + # Natural language is guarded by ALLOW_CYPHER_QUERY too + monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true") + + tools = await mod.get_search_type_tools(query_type, query_text="q") + + assert len(tools) == 2 + assert tools[0].__name__ == expected_method_names[0] + assert tools[1].__name__ == expected_method_names[1] + assert tools[0].__self__.__class__.__name__ == expected_class_name + assert tools[1].__self__.__class__.__name__ == expected_class_name + + +@pytest.mark.asyncio +async def test_some_completion_tools_are_callable_without_backends(monkeypatch): + """ + "Making search tools" should include that the returned callables are usable. + For retrievers that accept an explicit `context`, we can call get_completion without touching + DB/LLM backends. 
+ """ + import cognee.modules.search.methods.get_search_type_tools as mod + + monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true") + + for query_type in [ + SearchType.CHUNKS, + SearchType.SUMMARIES, + SearchType.CYPHER, + SearchType.NATURAL_LANGUAGE, + ]: + tools = await mod.get_search_type_tools(query_type, query_text="q") + completion = tools[0] + result = await completion("q", context=["ok"]) # type: ignore[call-arg] + assert result == ["ok"] From 89ef7d7d151cb18e81ad676ca1c205d6995e61d4 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 15:41:13 +0100 Subject: [PATCH 061/176] feat: adds integration test for community registered retriever case --- .../test_get_search_type_tools_integration.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 cognee/tests/integration/search/test_get_search_type_tools_integration.py diff --git a/cognee/tests/integration/search/test_get_search_type_tools_integration.py b/cognee/tests/integration/search/test_get_search_type_tools_integration.py new file mode 100644 index 000000000..4380a3bdb --- /dev/null +++ b/cognee/tests/integration/search/test_get_search_type_tools_integration.py @@ -0,0 +1,36 @@ +import pytest + +from cognee.modules.search.types import SearchType + + +class _DummyCompletionContextRetriever: + def __init__(self, *args, **kwargs): + self.kwargs = kwargs + + def get_completion(self, *args, **kwargs): + return None + + def get_context(self, *args, **kwargs): + return None + + +@pytest.mark.asyncio +async def test_community_registry_is_consulted(monkeypatch): + """ + This test covers the dynamic import + lookup of community retrievers in + cognee.modules.retrieval.registered_community_retrievers. 
+ """ + import cognee.modules.search.methods.get_search_type_tools as mod + from cognee.modules.retrieval import registered_community_retrievers as registry + + monkeypatch.setattr( + registry, + "registered_community_retrievers", + {SearchType.NATURAL_LANGUAGE: _DummyCompletionContextRetriever}, + ) + + tools = await mod.get_search_type_tools(SearchType.NATURAL_LANGUAGE, query_text="q", top_k=9) + + assert len(tools) == 2 + assert tools[0].__self__.kwargs["top_k"] == 9 + assert tools[1].__self__.kwargs["top_k"] == 9 From 48c2040f3ddc3ba3dcdd0b7f3f7b9005aa3350c3 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 15:45:32 +0100 Subject: [PATCH 062/176] Delete test_get_search_type_tools_integration.py --- .../test_get_search_type_tools_integration.py | 36 ------------------- 1 file changed, 36 deletions(-) delete mode 100644 cognee/tests/integration/search/test_get_search_type_tools_integration.py diff --git a/cognee/tests/integration/search/test_get_search_type_tools_integration.py b/cognee/tests/integration/search/test_get_search_type_tools_integration.py deleted file mode 100644 index 4380a3bdb..000000000 --- a/cognee/tests/integration/search/test_get_search_type_tools_integration.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest - -from cognee.modules.search.types import SearchType - - -class _DummyCompletionContextRetriever: - def __init__(self, *args, **kwargs): - self.kwargs = kwargs - - def get_completion(self, *args, **kwargs): - return None - - def get_context(self, *args, **kwargs): - return None - - -@pytest.mark.asyncio -async def test_community_registry_is_consulted(monkeypatch): - """ - This test covers the dynamic import + lookup of community retrievers in - cognee.modules.retrieval.registered_community_retrievers. 
- """ - import cognee.modules.search.methods.get_search_type_tools as mod - from cognee.modules.retrieval import registered_community_retrievers as registry - - monkeypatch.setattr( - registry, - "registered_community_retrievers", - {SearchType.NATURAL_LANGUAGE: _DummyCompletionContextRetriever}, - ) - - tools = await mod.get_search_type_tools(SearchType.NATURAL_LANGUAGE, query_text="q", top_k=9) - - assert len(tools) == 2 - assert tools[0].__self__.kwargs["top_k"] == 9 - assert tools[1].__self__.kwargs["top_k"] == 9 From a52873a71f82315931758273ce1ec822817e30b4 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 16 Dec 2025 15:54:25 +0100 Subject: [PATCH 063/176] refactor: make return type mandatory for transcription --- .../litellm_instructor/llm/generic_llm_api/adapter.py | 11 ++++------- .../litellm_instructor/llm/llm_interface.py | 10 ++++++---- .../litellm_instructor/llm/openai/adapter.py | 11 +++++------ 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index 84a88f9d6..dc30b7310 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -193,7 +193,7 @@ class GenericAPIAdapter(LLMInterface): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) - async def create_transcript(self, input) -> Optional[TranscriptionReturnType]: + async def create_transcript(self, input) -> TranscriptionReturnType: """ Generate an audio transcript from a user query. @@ -216,7 +216,7 @@ class GenericAPIAdapter(LLMInterface): raise ValueError( f"Could not determine MIME type for audio file: {input}. Is the extension correct?" 
) - response = litellm.completion( + response = litellm.completion( model=self.transcription_model, messages=[ { @@ -236,11 +236,8 @@ class GenericAPIAdapter(LLMInterface): api_base=self.endpoint, max_retries=self.MAX_RETRIES, ) - if response and response.choices and len(response.choices) > 0: - return TranscriptionReturnType(response.choices[0].message.content,response) - else: - return None + return TranscriptionReturnType(response.choices[0].message.content, response) @observe(as_type="transcribe_image") @retry( @@ -250,7 +247,7 @@ class GenericAPIAdapter(LLMInterface): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) - async def transcribe_image(self, input) -> Optional[BaseModel]: + async def transcribe_image(self, input) -> BaseModel: """ Generate a transcription of an image from a user query. diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py index f8352737d..ae4fb25b0 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py @@ -1,9 +1,11 @@ """LLM Interface""" -from typing import Type, Protocol, Optional +from typing import Type, Protocol from abc import abstractmethod from pydantic import BaseModel -from cognee.infrastructure.llm.LLMGateway import LLMGateway +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import ( + TranscriptionReturnType, +) class LLMInterface(Protocol): @@ -37,7 +39,7 @@ class LLMInterface(Protocol): raise NotImplementedError @abstractmethod - async def create_transcript(self, input) -> Optional[BaseModel]: + async def create_transcript(self, input) -> TranscriptionReturnType: """ Transcribe audio content to text. 
@@ -55,7 +57,7 @@ class LLMInterface(Protocol): raise NotImplementedError @abstractmethod - async def transcribe_image(self, input) -> Optional[BaseModel]: + async def transcribe_image(self, input) -> BaseModel: """ Analyze image content and return text description. diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index eba859802..d708f03e8 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -25,7 +25,9 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.modules.observability.get_observe import get_observe from cognee.shared.logging_utils import get_logger -from ..types import TranscriptionReturnType +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import ( + TranscriptionReturnType, +) logger = get_logger() @@ -203,7 +205,7 @@ class OpenAIAdapter(GenericAPIAdapter): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) - async def create_transcript(self, input, **kwargs) -> Optional[TranscriptionReturnType]: + async def create_transcript(self, input, **kwargs) -> TranscriptionReturnType: """ Generate an audio transcript from a user query. 
@@ -232,9 +234,6 @@ class OpenAIAdapter(GenericAPIAdapter): max_retries=self.MAX_RETRIES, **kwargs, ) - if transcription: - return TranscriptionReturnType(transcription.text, transcription) - - return None + return TranscriptionReturnType(transcription.text, transcription) # transcribe_image is inherited from GenericAPIAdapter From 7892b48afe08ea338a753cb6ddc45e48fb884ff7 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 15:59:15 +0100 Subject: [PATCH 064/176] Update test_get_search_type_tools.py --- .../unit/modules/search/test_get_search_type_tools.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cognee/tests/unit/modules/search/test_get_search_type_tools.py b/cognee/tests/unit/modules/search/test_get_search_type_tools.py index 15b489bfa..3748a4e4b 100644 --- a/cognee/tests/unit/modules/search/test_get_search_type_tools.py +++ b/cognee/tests/unit/modules/search/test_get_search_type_tools.py @@ -93,7 +93,7 @@ async def test_unknown_query_type_raises_unsupported(): import cognee.modules.search.methods.get_search_type_tools as mod with pytest.raises(UnsupportedSearchTypeError, match="UNKNOWN_TYPE"): - await mod.get_search_type_tools("UNKNOWN_TYPE", query_text="q") # type: ignore[arg-type] + await mod.get_search_type_tools("UNKNOWN_TYPE", query_text="q") @pytest.mark.asyncio @@ -130,7 +130,7 @@ async def test_coding_rules_uses_node_name_as_rules_nodeset_name(): assert len(tools) == 1 assert tools[0].__name__ == "get_existing_rules" assert tools[0].__self__.__class__ is CodingRulesRetriever - # Empty list should default to ["coding_agent_rules"] + assert tools[0].__self__.rules_nodeset_name == ["coding_agent_rules"] @@ -186,7 +186,6 @@ async def test_tool_construction_for_supported_search_types( ): import cognee.modules.search.methods.get_search_type_tools as mod - # Natural language is guarded by ALLOW_CYPHER_QUERY too monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true") tools = await 
mod.get_search_type_tools(query_type, query_text="q") @@ -217,5 +216,5 @@ async def test_some_completion_tools_are_callable_without_backends(monkeypatch): ]: tools = await mod.get_search_type_tools(query_type, query_text="q") completion = tools[0] - result = await completion("q", context=["ok"]) # type: ignore[call-arg] + result = await completion("q", context=["ok"]) assert result == ["ok"] From d92d6b9d8f4062424997f0a524f22fcf70a84edc Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 16 Dec 2025 16:02:15 +0100 Subject: [PATCH 065/176] refactor: remove optional return value --- .../litellm_instructor/llm/mistral/adapter.py | 5 ++--- .../litellm_instructor/llm/openai/adapter.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py index e3d0b72da..a52bbe281 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -162,6 +162,5 @@ class MistralAdapter(GenericAPIAdapter): "file_name": file_name, }, ) - if transcription_response: - return TranscriptionReturnType(transcription_response.text, transcription_response) - return None + + return TranscriptionReturnType(transcription_response.text, transcription_response) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index d708f03e8..582c3a08f 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -1,6 +1,6 @@ import litellm import instructor -from typing 
import Type, Optional +from typing import Type from pydantic import BaseModel from openai import ContentFilterFinishReasonError from litellm.exceptions import ContentPolicyViolationError From f2cb68dd5e6ddf3aae34070e789a01547335adde Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 16 Dec 2025 16:27:13 +0100 Subject: [PATCH 066/176] refactor: use async image and transcription handling --- .../litellm_instructor/llm/generic_llm_api/adapter.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index dc30b7310..4fd7b45c1 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -27,7 +27,9 @@ from tenacity import ( before_sleep_log, ) -from ..types import TranscriptionReturnType +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import ( + TranscriptionReturnType, +) logger = get_logger() observe = get_observe() @@ -216,7 +218,7 @@ class GenericAPIAdapter(LLMInterface): raise ValueError( f"Could not determine MIME type for audio file: {input}. Is the extension correct?" ) - response = litellm.completion( + response = await litellm.acompletion( model=self.transcription_model, messages=[ { @@ -270,7 +272,7 @@ class GenericAPIAdapter(LLMInterface): raise ValueError( f"Could not determine MIME type for image file: {input}. Is the extension correct?" 
) - return litellm.completion( + response = await litellm.acompletion( model=self.image_transcribe_model, messages=[ { @@ -295,3 +297,4 @@ class GenericAPIAdapter(LLMInterface): max_completion_tokens=300, max_retries=self.MAX_RETRIES, ) + return response From 3e041ec12f1f6ed8358f26d3d4b3aeeea6f2db0e Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 16 Dec 2025 16:28:30 +0100 Subject: [PATCH 067/176] refactor: format code --- .../litellm_instructor/llm/types.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py index 887cdd88d..fc850830d 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py @@ -1,9 +1,10 @@ from pydantic import BaseModel + class TranscriptionReturnType: text: str payload: BaseModel - def __init__(self, text:str, payload: BaseModel): + def __init__(self, text: str, payload: BaseModel): self.text = text - self.payload = payload \ No newline at end of file + self.payload = payload From 789fa9079052b0b04274cd992a294c98b3668de2 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 16:39:31 +0100 Subject: [PATCH 068/176] chore: covering search.py behavior with unit tests --- .../tests/unit/modules/search/test_search.py | 465 ++++++++++++++++++ 1 file changed, 465 insertions(+) create mode 100644 cognee/tests/unit/modules/search/test_search.py diff --git a/cognee/tests/unit/modules/search/test_search.py b/cognee/tests/unit/modules/search/test_search.py new file mode 100644 index 000000000..9c36f07da --- /dev/null +++ b/cognee/tests/unit/modules/search/test_search.py @@ -0,0 +1,465 @@ +import types +from uuid import uuid4 + +import pytest + +from cognee.modules.search.types 
import SearchType + + +def _make_user(user_id: str = "u1", tenant_id=None): + return types.SimpleNamespace(id=user_id, tenant_id=tenant_id) + + +def _make_dataset(*, name="ds", tenant_id="t1", dataset_id=None, owner_id=None): + return types.SimpleNamespace( + id=dataset_id or uuid4(), + name=name, + tenant_id=tenant_id, + owner_id=owner_id or uuid4(), + ) + + +@pytest.fixture +def search_mod(): + import importlib + + return importlib.import_module("cognee.modules.search.methods.search") + + +@pytest.fixture(autouse=True) +def _patch_side_effect_boundaries(monkeypatch, search_mod): + """ + Keep production logic; patch only unavoidable side-effect boundaries. + """ + + async def dummy_log_query(_query_text, _query_type, _user_id): + return types.SimpleNamespace(id="qid-1") + + async def dummy_log_result(*_args, **_kwargs): + return None + + async def dummy_prepare_search_result(search_result): + # search() and helpers mostly exchange tuples: (result, context, datasets) + if isinstance(search_result, tuple) and len(search_result) == 3: + result, context, datasets = search_result + return {"result": result, "context": context, "graphs": {}, "datasets": datasets} + return {"result": None, "context": None, "graphs": {}, "datasets": []} + + monkeypatch.setattr(search_mod, "send_telemetry", lambda *a, **k: None) + monkeypatch.setattr(search_mod, "log_query", dummy_log_query) + monkeypatch.setattr(search_mod, "log_result", dummy_log_result) + monkeypatch.setattr(search_mod, "prepare_search_result", dummy_prepare_search_result) + + yield + + +@pytest.mark.asyncio +async def test_search_no_access_control_flattens_single_list_result(monkeypatch, search_mod): + user = _make_user() + + async def dummy_no_access_control_search(**_kwargs): + return (["r"], ["ctx"], []) + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False) + monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search) + + out = await search_mod.search( + 
query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=None, + user=user, + ) + + assert out == ["r"] + + +@pytest.mark.asyncio +async def test_search_no_access_control_non_list_result_returns_list(monkeypatch, search_mod): + """ + Covers the non-flattening back-compat branch in `search()`: if the single returned result is + not a list, `search()` returns a list of results instead of flattening. + """ + user = _make_user() + + async def dummy_no_access_control_search(**_kwargs): + return ("r", ["ctx"], []) + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False) + monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=None, + user=user, + ) + + assert out == ["r"] + + +@pytest.mark.asyncio +async def test_search_no_access_control_only_context_returns_context(monkeypatch, search_mod): + user = _make_user() + + async def dummy_no_access_control_search(**_kwargs): + return (None, ["ctx"], []) + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False) + monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=None, + user=user, + only_context=True, + ) + + assert out == ["ctx"] + + +@pytest.mark.asyncio +async def test_search_access_control_returns_dataset_shaped_dicts(monkeypatch, search_mod): + user = _make_user() + ds = _make_dataset(name="ds1", tenant_id="t1") + + async def dummy_authorized_search(**kwargs): + assert kwargs["dataset_ids"] == [ds.id] + return [("r", ["ctx"], [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + 
dataset_ids=[ds.id], + user=user, + ) + + assert out == [ + { + "search_result": ["r"], + "dataset_id": ds.id, + "dataset_name": "ds1", + "dataset_tenant_id": "t1", + "graphs": {}, + } + ] + + +@pytest.mark.asyncio +async def test_search_access_control_only_context_returns_dataset_shaped_dicts( + monkeypatch, search_mod +): + user = _make_user() + ds = _make_dataset(name="ds1", tenant_id="t1") + + async def dummy_authorized_search(**_kwargs): + return [(None, "ctx", [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + only_context=True, + ) + + assert out == [ + { + "search_result": ["ctx"], + "dataset_id": ds.id, + "dataset_name": "ds1", + "dataset_tenant_id": "t1", + "graphs": {}, + } + ] + + +@pytest.mark.asyncio +async def test_search_access_control_use_combined_context_returns_combined_model( + monkeypatch, search_mod +): + user = _make_user() + ds1 = _make_dataset(name="ds1", tenant_id="t1") + ds2 = _make_dataset(name="ds2", tenant_id="t1") + + async def dummy_authorized_search(**_kwargs): + return ("answer", {"k": "v"}, [ds1, ds2]) + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds1.id, ds2.id], + user=user, + use_combined_context=True, + ) + + assert out.result == "answer" + assert out.context == {"k": "v"} + assert out.graphs == {} + assert [d.id for d in out.datasets] == [ds1.id, ds2.id] + + +@pytest.mark.asyncio +async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod): + user = _make_user() + ds = _make_dataset(name="ds1") + + async def 
dummy_get_authorized_existing_datasets(*_args, **_kwargs): + return [ds] + + expected = [("r", ["ctx"], [ds])] + + async def dummy_search_in_datasets_context(**kwargs): + assert kwargs["use_combined_context"] is False if "use_combined_context" in kwargs else True + return expected + + monkeypatch.setattr( + search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets + ) + monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context) + + out = await search_mod.authorized_search( + query_type=SearchType.CHUNKS, + query_text="q", + user=user, + dataset_ids=[ds.id], + use_combined_context=False, + only_context=False, + ) + + assert out == expected + + +@pytest.mark.asyncio +async def test_authorized_search_use_combined_context_joins_string_context(monkeypatch, search_mod): + user = _make_user() + ds1 = _make_dataset(name="ds1") + ds2 = _make_dataset(name="ds2") + + async def dummy_get_authorized_existing_datasets(*_args, **_kwargs): + return [ds1, ds2] + + async def dummy_search_in_datasets_context(**kwargs): + assert kwargs["only_context"] is True + return [(None, ["a"], [ds1]), (None, ["b"], [ds2])] + + seen = {} + + async def dummy_get_completion(query_text, context, session_id=None): + seen["query_text"] = query_text + seen["context"] = context + seen["session_id"] = session_id + return ["answer"] + + async def dummy_get_search_type_tools(**_kwargs): + return [dummy_get_completion, lambda *_a, **_k: None] + + monkeypatch.setattr( + search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets + ) + monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context) + monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools) + + completion, combined_context, datasets = await search_mod.authorized_search( + query_type=SearchType.CHUNKS, + query_text="q", + user=user, + dataset_ids=[ds1.id, ds2.id], + use_combined_context=True, 
+ session_id="s1", + ) + + assert combined_context == "a\nb" + assert completion == ["answer"] + assert datasets == [ds1, ds2] + assert seen == {"query_text": "q", "context": "a\nb", "session_id": "s1"} + + +@pytest.mark.asyncio +async def test_authorized_search_use_combined_context_keeps_non_string_context( + monkeypatch, search_mod +): + user = _make_user() + ds1 = _make_dataset(name="ds1") + ds2 = _make_dataset(name="ds2") + + class DummyEdge: + pass + + e1, e2 = DummyEdge(), DummyEdge() + + async def dummy_get_authorized_existing_datasets(*_args, **_kwargs): + return [ds1, ds2] + + async def dummy_search_in_datasets_context(**_kwargs): + return [(None, [e1], [ds1]), (None, [e2], [ds2])] + + async def dummy_get_completion(query_text, context, session_id=None): + assert query_text == "q" + assert context == [e1, e2] + return ["answer"] + + async def dummy_get_search_type_tools(**_kwargs): + return [dummy_get_completion] + + monkeypatch.setattr( + search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets + ) + monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context) + monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools) + + completion, combined_context, datasets = await search_mod.authorized_search( + query_type=SearchType.CHUNKS, + query_text="q", + user=user, + dataset_ids=[ds1.id, ds2.id], + use_combined_context=True, + ) + + assert combined_context == [e1, e2] + assert completion == ["answer"] + assert datasets == [ds1, ds2] + + +@pytest.mark.asyncio +async def test_search_in_datasets_context_two_tool_context_override_and_is_empty_branches( + monkeypatch, search_mod +): + ds1 = _make_dataset(name="ds1") + ds2 = _make_dataset(name="ds2") + + async def dummy_set_database_global_context_variables(*_args, **_kwargs): + return None + + class DummyGraphEngine: + async def is_empty(self): + return True + + async def dummy_get_graph_engine(): + return 
DummyGraphEngine() + + async def dummy_get_dataset_data(dataset_id): + return [1] if dataset_id == ds1.id else [] + + calls = {"completion": 0, "context": 0} + + async def dummy_get_context(_query_text: str): + calls["context"] += 1 + return ["ctx"] + + async def dummy_get_completion(_query_text: str, _context, session_id=None): + calls["completion"] += 1 + assert session_id == "s1" + return ["r"] + + async def dummy_get_search_type_tools(**_kwargs): + return [dummy_get_completion, dummy_get_context] + + monkeypatch.setattr( + search_mod, + "set_database_global_context_variables", + dummy_set_database_global_context_variables, + ) + monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine) + monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools) + monkeypatch.setattr("cognee.modules.data.methods.get_dataset_data", dummy_get_dataset_data) + + out = await search_mod.search_in_datasets_context( + search_datasets=[ds1, ds2], + query_type=SearchType.CHUNKS, + query_text="q", + context=["pre_ctx"], + session_id="s1", + ) + + assert out == [(["r"], ["pre_ctx"], [ds1]), (["r"], ["pre_ctx"], [ds2])] + assert calls == {"completion": 2, "context": 0} + + +@pytest.mark.asyncio +async def test_search_in_datasets_context_two_tool_only_context_true(monkeypatch, search_mod): + ds = _make_dataset(name="ds1") + + async def dummy_set_database_global_context_variables(*_args, **_kwargs): + return None + + class DummyGraphEngine: + async def is_empty(self): + return False + + async def dummy_get_graph_engine(): + return DummyGraphEngine() + + async def dummy_get_context(query_text: str): + assert query_text == "q" + return ["ctx"] + + async def dummy_get_completion(*_args, **_kwargs): + raise AssertionError("Completion should not be called when only_context=True") + + async def dummy_get_search_type_tools(**_kwargs): + return [dummy_get_completion, dummy_get_context] + + monkeypatch.setattr( + search_mod, + 
"set_database_global_context_variables", + dummy_set_database_global_context_variables, + ) + monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine) + monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools) + + out = await search_mod.search_in_datasets_context( + search_datasets=[ds], + query_type=SearchType.CHUNKS, + query_text="q", + only_context=True, + ) + + assert out == [(None, ["ctx"], [ds])] + + +@pytest.mark.asyncio +async def test_search_in_datasets_context_unknown_tool_path(monkeypatch, search_mod): + ds = _make_dataset(name="ds1") + + async def dummy_set_database_global_context_variables(*_args, **_kwargs): + return None + + class DummyGraphEngine: + async def is_empty(self): + return False + + async def dummy_get_graph_engine(): + return DummyGraphEngine() + + async def dummy_unknown_tool(query_text: str): + assert query_text == "q" + return ["u"] + + async def dummy_get_search_type_tools(**_kwargs): + return [dummy_unknown_tool] + + monkeypatch.setattr( + search_mod, + "set_database_global_context_variables", + dummy_set_database_global_context_variables, + ) + monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine) + monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools) + + out = await search_mod.search_in_datasets_context( + search_datasets=[ds], + query_type=SearchType.CODING_RULES, + query_text="q", + ) + + assert out == [(["u"], "", [ds])] From f27d07d902a5a15b2ba1847f1aefc9221179d1d0 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 16 Dec 2025 16:53:56 +0100 Subject: [PATCH 069/176] refactor: remove mandatory transcription and image methods in LLMInterface --- .../litellm_instructor/llm/llm_interface.py | 38 ------------------- 1 file changed, 38 deletions(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py 
b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py index ae4fb25b0..da538aad8 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py @@ -14,8 +14,6 @@ class LLMInterface(Protocol): Methods: - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) - - create_transcript(input): Transcribe audio files to text - - transcribe_image(input): Analyze image files and return text description """ @abstractmethod @@ -37,39 +35,3 @@ class LLMInterface(Protocol): output. """ raise NotImplementedError - - @abstractmethod - async def create_transcript(self, input) -> TranscriptionReturnType: - """ - Transcribe audio content to text. - - This method should be implemented by subclasses that support audio transcription. - If not implemented, returns None and should be handled gracefully by callers. - - Parameters: - ----------- - - input: The path to the audio file that needs to be transcribed. - - Returns: - -------- - - BaseModel: A structured output containing the transcription, or None if not supported. - """ - raise NotImplementedError - - @abstractmethod - async def transcribe_image(self, input) -> BaseModel: - """ - Analyze image content and return text description. - - This method should be implemented by subclasses that support image analysis. - If not implemented, returns None and should be handled gracefully by callers. - - Parameters: - ----------- - - input: The path to the image file that needs to be analyzed. - - Returns: - -------- - - BaseModel: A structured output containing the image description, or None if not supported. 
- """ - raise NotImplementedError From 8027263e8b4c209c0c73cba784e28f805fd90363 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 16 Dec 2025 16:54:27 +0100 Subject: [PATCH 070/176] refactor: remove unused import --- .../litellm_instructor/llm/llm_interface.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py index da538aad8..6afd4138c 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py @@ -3,9 +3,6 @@ from typing import Type, Protocol from abc import abstractmethod from pydantic import BaseModel -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import ( - TranscriptionReturnType, -) class LLMInterface(Protocol): From 4ff2a35476a3707630c3fbc783a5357907183b26 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 17:33:20 +0100 Subject: [PATCH 071/176] chore: moves unit tests into their correct directory --- .../{integration => unit/modules}/retrieval/test_completion.py | 0 .../modules}/retrieval/test_graph_summary_completion_retriever.py | 0 .../modules}/retrieval/test_user_qa_feedback.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename cognee/tests/{integration => unit/modules}/retrieval/test_completion.py (100%) rename cognee/tests/{integration => unit/modules}/retrieval/test_graph_summary_completion_retriever.py (100%) rename cognee/tests/{integration => unit/modules}/retrieval/test_user_qa_feedback.py (100%) diff --git a/cognee/tests/integration/retrieval/test_completion.py b/cognee/tests/unit/modules/retrieval/test_completion.py similarity index 100% rename from cognee/tests/integration/retrieval/test_completion.py 
rename to cognee/tests/unit/modules/retrieval/test_completion.py diff --git a/cognee/tests/integration/retrieval/test_graph_summary_completion_retriever.py b/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py similarity index 100% rename from cognee/tests/integration/retrieval/test_graph_summary_completion_retriever.py rename to cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py diff --git a/cognee/tests/integration/retrieval/test_user_qa_feedback.py b/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py similarity index 100% rename from cognee/tests/integration/retrieval/test_user_qa_feedback.py rename to cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py From 18d0a418505dd885d6ab7e2f8efd7a33091471be Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Tue, 16 Dec 2025 17:49:43 +0100 Subject: [PATCH 072/176] Update test_search.py --- cognee/tests/unit/modules/search/test_search.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cognee/tests/unit/modules/search/test_search.py b/cognee/tests/unit/modules/search/test_search.py index 9c36f07da..175fd9aa4 100644 --- a/cognee/tests/unit/modules/search/test_search.py +++ b/cognee/tests/unit/modules/search/test_search.py @@ -39,7 +39,6 @@ def _patch_side_effect_boundaries(monkeypatch, search_mod): return None async def dummy_prepare_search_result(search_result): - # search() and helpers mostly exchange tuples: (result, context, datasets) if isinstance(search_result, tuple) and len(search_result) == 3: result, context, datasets = search_result return {"result": result, "context": context, "graphs": {}, "datasets": datasets} From b77961b0f178e985c664e22ccca8e4f40e76f456 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 16 Dec 2025 20:59:17 +0100 Subject: [PATCH 073/176] fix: Resolve issues with data label PR, add tests and upgrade migration --- .github/workflows/e2e_tests.yml | 25 +++++++ 
.../a1b2c3d4e5f6_add_label_column_to_data.py | 23 +++++-- cognee/api/v1/add/add.py | 3 +- .../ingestion/save_data_item_to_storage.py | 5 ++ cognee/tests/test_custom_data_label.py | 68 +++++++++++++++++++ 5 files changed, 117 insertions(+), 7 deletions(-) create mode 100644 cognee/tests/test_custom_data_label.py diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml index 8cd62910c..5f5828da8 100644 --- a/.github/workflows/e2e_tests.yml +++ b/.github/workflows/e2e_tests.yml @@ -315,6 +315,31 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./cognee/tests/test_multi_tenancy.py + test-data-label: + name: Test adding of label for data in Cognee + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run custom data label test + env: + ENV: 'dev' + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/test_custom_data_label.py + test-graph-edges: name: Test graph edge ingestion runs-on: ubuntu-22.04 diff --git a/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py b/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py index 814467954..c127e078b 100644 --- a/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py +++ b/alembic/versions/a1b2c3d4e5f6_add_label_column_to_data.py @@ -13,15 +13,26 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision: str = "a1b2c3d4e5f6" -down_revision: Union[str, None] = "211ab850ef3d" +down_revision: Union[str, None] = "46a6ce2bd2b2" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None + +def _get_column(inspector, table, name, schema=None): + for col in inspector.get_columns(table, schema=schema): + if col["name"] == name: + return col + return None + + def upgrade() -> None: - op.add_column( - "data", - sa.Column("label", sa.String(), nullable=True) - ) + conn = op.get_bind() + insp = sa.inspect(conn) + + label_column = _get_column(insp, "data", "label") + if not label_column: + op.add_column("data", sa.Column("label", sa.String(), nullable=True)) + def downgrade() -> None: - op.drop_column("data", "label") \ No newline at end of file + op.drop_column("data", "label") diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py index 90ea32ae7..3b355f284 100644 --- a/cognee/api/v1/add/add.py +++ b/cognee/api/v1/add/add.py @@ -10,13 +10,14 @@ from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import ( ) from cognee.modules.engine.operations.setup import setup from cognee.tasks.ingestion import ingest_data, resolve_data_directories +from cognee.tasks.ingestion.data_item import DataItem from cognee.shared.logging_utils import get_logger logger = get_logger() async def add( - data: Union[BinaryIO, list[BinaryIO], str, list[str]], + data: Union[BinaryIO, list[BinaryIO], str, list[str], DataItem, list[DataItem]], dataset_name: str = "main_dataset", user: User = None, node_set: Optional[List[str]] = None, diff --git a/cognee/tasks/ingestion/save_data_item_to_storage.py b/cognee/tasks/ingestion/save_data_item_to_storage.py index 05d21e617..85eef2736 100644 --- a/cognee/tasks/ingestion/save_data_item_to_storage.py +++ b/cognee/tasks/ingestion/save_data_item_to_storage.py @@ -9,6 +9,7 @@ from cognee.shared.logging_utils import get_logger from pydantic_settings import BaseSettings, SettingsConfigDict 
# --- cognee/tests/test_custom_data_label.py --------------------------------
# End-to-end check that a label supplied through DataItem survives ingestion
# and is persisted on the stored Data row. Heavy cognee/fastapi imports are
# done lazily inside main() — hoisted out of the result loop, where the
# original re-imported them on every iteration — so the module itself can be
# imported without the full cognee stack.
import asyncio


async def main():
    import cognee
    from cognee.api.v1.search import SearchType
    from cognee.modules.data.methods.get_dataset_data import get_dataset_data
    from cognee.tasks.ingestion.data_item import DataItem
    from fastapi.encoders import jsonable_encoder

    # Create a clean slate for cognee -- reset data and system state.
    print("Resetting cognee data...")
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
    print("Data reset complete.\n")

    # cognee knowledge graph will be created based on this text.
    text = """
    Natural language processing (NLP) is an interdisciplinary
    subfield of computer science and information retrieval.
    """

    test_item = DataItem(text, "test_item")
    # Add the labeled item, and make it available for cognify.
    await cognee.add(test_item)

    # Use LLMs and cognee to create the knowledge graph.
    ret_val = await cognee.cognify()

    query_text = "Tell me about NLP"
    print(f"Searching cognee for insights with query: '{query_text}'")
    search_results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION, query_text=query_text
    )

    print("Search results:")
    for result_text in search_results:
        print(result_text)

    for pipeline in ret_val.values():
        dataset_id = pipeline.dataset_id
        dataset_data = await get_dataset_data(dataset_id=dataset_id)

        # NOTE: loop variable renamed from "data" — the original shadowed the
        # list it was building inside its own comprehension.
        data = [
            dict(**jsonable_encoder(row), dataset_id=dataset_id) for row in dataset_data
        ]

        # The label passed via DataItem must be stored on the data row.
        assert data[0]["label"] == "test_item"


if __name__ == "__main__":
    from cognee.shared.logging_utils import setup_logging, ERROR

    logger = setup_logging(log_level=ERROR)
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()  # release loop resources once async generators are closed


# --- cognee/tasks/ingestion/data_item.py -----------------------------------
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DataItem:
    """Wrapper pairing a raw data payload with an optional user-facing label.

    Attributes:
        data: The underlying payload (file handle, path, text, ...).
        label: Optional label persisted alongside the ingested data.
    """

    data: Any
    label: Optional[str] = None
94d5175570a2358b242feb76529577c5ee6024e2 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Wed, 17 Dec 2025 10:34:57 +0100 Subject: [PATCH 075/176] feat: adds unit test for the prepare search result - search contract --- ...t_search_prepare_search_result_contract.py | 296 ++++++++++++++++++ 1 file changed, 296 insertions(+) create mode 100644 cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py diff --git a/cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py b/cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py new file mode 100644 index 000000000..8700e6a1b --- /dev/null +++ b/cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py @@ -0,0 +1,296 @@ +## The Objective of these tests is to cover the search - prepare search results behavior (later to be removed) + +import types +from uuid import uuid4 + +import pytest +from pydantic import BaseModel + +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node +from cognee.modules.search.types import SearchType + + +class DummyDataset(BaseModel): + id: object + name: str + tenant_id: str | None = None + owner_id: object + + +def _ds(name="ds1", tenant_id="t1"): + return DummyDataset(id=uuid4(), name=name, tenant_id=tenant_id, owner_id=uuid4()) + + +def _edge(rel="rel", n1="A", n2="B"): + node1 = Node(str(uuid4()), attributes={"type": "Entity", "name": n1}) + node2 = Node(str(uuid4()), attributes={"type": "Entity", "name": n2}) + return Edge(node1, node2, attributes={"relationship_name": rel}) + + +@pytest.fixture +def search_mod(): + import importlib + + return importlib.import_module("cognee.modules.search.methods.search") + + +@pytest.fixture(autouse=True) +def _patch_search_side_effects(monkeypatch, search_mod): + """ + These tests validate prepare_search_result behavior *through* search.py. 
+ We only patch unavoidable side effects (telemetry + query/result logging). + """ + + async def dummy_log_query(_query_text, _query_type, _user_id): + return types.SimpleNamespace(id="qid-1") + + async def dummy_log_result(*_args, **_kwargs): + return None + + monkeypatch.setattr(search_mod, "send_telemetry", lambda *a, **k: None) + monkeypatch.setattr(search_mod, "log_query", dummy_log_query) + monkeypatch.setattr(search_mod, "log_result", dummy_log_result) + + yield + + +@pytest.fixture(autouse=True) +def _patch_resolve_edges_to_text(monkeypatch): + """ + Keep graph-text conversion deterministic and lightweight. + """ + import importlib + + psr_mod = importlib.import_module("cognee.modules.search.utils.prepare_search_result") + + async def dummy_resolve_edges_to_text(_edges): + return "EDGE_TEXT" + + monkeypatch.setattr(psr_mod, "resolve_edges_to_text", dummy_resolve_edges_to_text) + + yield + + +@pytest.mark.asyncio +async def test_search_access_control_edges_context_produces_graphs_and_context_map( + monkeypatch, search_mod +): + user = types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + context = [_edge("likes")] + + async def dummy_authorized_search(**_kwargs): + return [(["answer"], context, [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + ) + + assert out[0]["dataset_name"] == "ds1" + assert out[0]["dataset_tenant_id"] == "t1" + assert out[0]["graphs"] is not None + assert "ds1" in out[0]["graphs"] + assert out[0]["graphs"]["ds1"]["nodes"] + assert out[0]["graphs"]["ds1"]["edges"] + assert out[0]["search_result"] == ["answer"] + + +@pytest.mark.asyncio +async def test_search_access_control_insights_context_produces_graphs_and_null_result( + monkeypatch, search_mod +): + user = 
types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + insights = [ + ( + {"id": "n1", "type": "Entity", "name": "Alice"}, + {"relationship_name": "knows"}, + {"id": "n2", "type": "Entity", "name": "Bob"}, + ) + ] + + async def dummy_authorized_search(**_kwargs): + return [(["something"], insights, [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + ) + + assert out[0]["graphs"] is not None + assert "ds1" in out[0]["graphs"] + assert out[0]["search_result"] is None + + +@pytest.mark.asyncio +async def test_search_access_control_only_context_returns_context_text_map(monkeypatch, search_mod): + user = types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + + async def dummy_authorized_search(**_kwargs): + return [(None, ["a", "b"], [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + only_context=True, + ) + + assert out[0]["search_result"] == [{"ds1": "a\nb"}] + + +@pytest.mark.asyncio +async def test_search_access_control_results_edges_become_graph_result(monkeypatch, search_mod): + user = types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + results = [_edge("connected_to")] + + async def dummy_authorized_search(**_kwargs): + return [(results, "ctx", [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, 
+ ) + + assert isinstance(out[0]["search_result"][0], dict) + assert "nodes" in out[0]["search_result"][0] + assert "edges" in out[0]["search_result"][0] + + +@pytest.mark.asyncio +async def test_search_use_combined_context_defaults_empty_datasets(monkeypatch, search_mod): + user = types.SimpleNamespace(id="u1", tenant_id=None) + + async def dummy_authorized_search(**_kwargs): + return ("answer", "ctx", []) + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=None, + user=user, + use_combined_context=True, + ) + + assert out.result == "answer" + assert out.context == {"all available datasets": "ctx"} + assert out.datasets[0].name == "all available datasets" + + +@pytest.mark.asyncio +async def test_search_access_control_context_str_branch(monkeypatch, search_mod): + """Covers prepare_search_result(context is str) through search().""" + user = types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + + async def dummy_authorized_search(**_kwargs): + return [(["answer"], "plain context", [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + ) + + assert out[0]["graphs"] is None + assert out[0]["search_result"] == ["answer"] + + +@pytest.mark.asyncio +async def test_search_access_control_context_empty_list_branch(monkeypatch, search_mod): + """Covers prepare_search_result(context is empty list) through search().""" + user = types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + + async def dummy_authorized_search(**_kwargs): + return [(["answer"], [], [ds])] + + monkeypatch.setattr(search_mod, 
"backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + ) + + assert out[0]["graphs"] is None + assert out[0]["search_result"] == ["answer"] + + +@pytest.mark.asyncio +async def test_search_access_control_multiple_results_list_branch(monkeypatch, search_mod): + """Covers prepare_search_result(result list length > 1) through search().""" + user = types.SimpleNamespace(id="u1", tenant_id=None) + ds = _ds("ds1", "t1") + + async def dummy_authorized_search(**_kwargs): + return [(["r1", "r2"], "ctx", [ds])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + out = await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=[ds.id], + user=user, + ) + + assert out[0]["search_result"] == [["r1", "r2"]] + + +@pytest.mark.asyncio +async def test_search_access_control_defaults_empty_datasets(monkeypatch, search_mod): + """ + Covers prepare_search_result(datasets empty list) through search(). + + Note: in access-control mode, search.py expects datasets[0] to have `tenant_id`, + but prepare_search_result defaults to SearchResultDataset which doesn't define it. + We assert the current behavior (it raises) so refactors don't silently change it. 
+ """ + user = types.SimpleNamespace(id="u1", tenant_id=None) + + async def dummy_authorized_search(**_kwargs): + return [(["answer"], "ctx", [])] + + monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True) + monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search) + + with pytest.raises(AttributeError, match="tenant_id"): + await search_mod.search( + query_text="q", + query_type=SearchType.CHUNKS, + dataset_ids=None, + user=user, + ) From ee29dd1f8147fd8a7441bb3904ece7245a6ce86f Mon Sep 17 00:00:00 2001 From: Christina_Raichel_Francis Date: Wed, 17 Dec 2025 10:36:59 +0000 Subject: [PATCH 076/176] refactor: update cognee tasks to add frequency tracking script --- cognee/tasks/memify/extract_usage_frequency.py | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 cognee/tasks/memify/extract_usage_frequency.py diff --git a/cognee/tasks/memify/extract_usage_frequency.py b/cognee/tasks/memify/extract_usage_frequency.py new file mode 100644 index 000000000..d6ca3773f --- /dev/null +++ b/cognee/tasks/memify/extract_usage_frequency.py @@ -0,0 +1,7 @@ +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph + + +async def extract_subgraph(subgraphs: list[CogneeGraph]): + for subgraph in subgraphs: + for edge in subgraph.edges: + yield edge From f79ba53e1d97dfcb8843d1343f6f1fcad3c7b31f Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Wed, 17 Dec 2025 12:30:15 +0100 Subject: [PATCH 077/176] COG-3532 chore: retriever test reorganization + adding new tests (unit) (STEP 2) (#1892) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR restructures/adds unittests for the retrieval module. 
(STEP 2)
- Added missing unit tests for all core retrieval business logic

## Type of Change

- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)

## Pre-submission Checklist

- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the issue/feature**
- [x] My code follows the project's coding standards and style guidelines
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

## Summary by CodeRabbit

* **Tests**
  * Expanded and refactored retrieval module test suites with comprehensive unit test coverage for ChunksRetriever, SummariesRetriever, RagCompletionRetriever, TripletRetriever, GraphCompletionRetriever, TemporalRetriever, and related components.
  * Added new test modules for completion utilities, graph summary retrieval, and user feedback functionality.
  * Improved test robustness with edge case handling and error scenario coverage.

✏️ Tip: You can customize this high-level summary in your review settings.
--- .../retrieval/chunks_retriever_test.py | 183 ++++ .../retrieval/conversation_history_test.py | 492 +++++++++++ ...letion_retriever_context_extension_test.py | 469 ++++++++++ .../graph_completion_retriever_cot_test.py | 688 +++++++++++++++ .../graph_completion_retriever_test.py | 648 ++++++++++++++ .../rag_completion_retriever_test.py | 321 +++++++ .../retrieval/summaries_retriever_test.py | 193 +++++ .../retrieval/temporal_retriever_test.py | 705 +++++++++++++++ .../test_brute_force_triplet_search.py | 817 ++++++++++++++++++ .../unit/modules/retrieval/test_completion.py | 343 ++++++++ ...test_graph_summary_completion_retriever.py | 157 ++++ .../retrieval/test_user_qa_feedback.py | 312 +++++++ .../retrieval/triplet_retriever_test.py | 329 +++++++ 13 files changed, 5657 insertions(+) create mode 100644 cognee/tests/unit/modules/retrieval/chunks_retriever_test.py create mode 100644 cognee/tests/unit/modules/retrieval/conversation_history_test.py create mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py create mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py create mode 100644 cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py create mode 100644 cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py create mode 100644 cognee/tests/unit/modules/retrieval/summaries_retriever_test.py create mode 100644 cognee/tests/unit/modules/retrieval/temporal_retriever_test.py create mode 100644 cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py create mode 100644 cognee/tests/unit/modules/retrieval/test_completion.py create mode 100644 cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py create mode 100644 cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py create mode 100644 cognee/tests/unit/modules/retrieval/triplet_retriever_test.py diff --git 
a/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py new file mode 100644 index 000000000..98bfd48fe --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/chunks_retriever_test.py @@ -0,0 +1,183 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.chunks_retriever import ChunksRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError + + +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine + + +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval of chunk context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "Steve Rodger", "chunk_index": 0} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "Mike Broski", "chunk_index": 1} + + mock_vector_engine.search.return_value = [mock_result1, mock_result2] + + retriever = ChunksRetriever(top_k=5) + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 2 + assert context[0]["text"] == "Steve Rodger" + assert context[1]["text"] == "Mike Broski" + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5) + + +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with 
pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty list is returned when no chunks are found.""" + mock_vector_engine.search.return_value = [] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == [] + + +@pytest.mark.asyncio +async def test_get_context_top_k_limit(mock_vector_engine): + """Test that top_k parameter limits the number of results.""" + mock_results = [MagicMock() for _ in range(3)] + for i, result in enumerate(mock_results): + result.payload = {"text": f"Chunk {i}"} + + mock_vector_engine.search.return_value = mock_results + + retriever = ChunksRetriever(top_k=3) + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 3 + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3) + + +@pytest.mark.asyncio +async def test_get_completion_with_context(mock_vector_engine): + """Test get_completion returns provided context.""" + retriever = ChunksRetriever() + + provided_context = [{"text": "Steve Rodger"}, {"text": "Mike Broski"}] + completion = await retriever.get_completion("test query", context=provided_context) + + assert completion == provided_context + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Steve Rodger"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + 
"cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query") + + assert len(completion) == 1 + assert completion[0]["text"] == "Steve Rodger" + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test ChunksRetriever initialization with defaults.""" + retriever = ChunksRetriever() + + assert retriever.top_k == 5 + + +@pytest.mark.asyncio +async def test_init_custom_top_k(): + """Test ChunksRetriever initialization with custom top_k.""" + retriever = ChunksRetriever(top_k=10) + + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test ChunksRetriever initialization with None top_k.""" + retriever = ChunksRetriever(top_k=None) + + assert retriever.top_k is None + + +@pytest.mark.asyncio +async def test_get_context_empty_payload(mock_vector_engine): + """Test get_context handles empty payload.""" + mock_result = MagicMock() + mock_result.payload = {} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 1 + assert context[0] == {} + + +@pytest.mark.asyncio +async def test_get_completion_with_session_id(mock_vector_engine): + """Test get_completion with session_id parameter.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Steve Rodger"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = ChunksRetriever() + + with patch( + "cognee.modules.retrieval.chunks_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query", session_id="test_session") + + assert len(completion) == 1 + assert completion[0]["text"] == "Steve Rodger" diff --git 
a/cognee/tests/unit/modules/retrieval/conversation_history_test.py b/cognee/tests/unit/modules/retrieval/conversation_history_test.py new file mode 100644 index 000000000..f1ce9b370 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/conversation_history_test.py @@ -0,0 +1,492 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from cognee.context_global_variables import session_user +import importlib + + +def create_mock_cache_engine(qa_history=None): + mock_cache = AsyncMock() + if qa_history is None: + qa_history = [] + mock_cache.get_latest_qa = AsyncMock(return_value=qa_history) + mock_cache.add_qa = AsyncMock(return_value=None) + return mock_cache + + +def create_mock_user(): + mock_user = MagicMock() + mock_user.id = "test-user-id-123" + return mock_user + + +class TestConversationHistoryUtils: + @pytest.mark.asyncio + async def test_get_conversation_history_returns_empty_when_no_history(self): + user = create_mock_user() + session_user.set(user) + mock_cache = create_mock_cache_engine([]) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + from cognee.modules.retrieval.utils.session_cache import get_conversation_history + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_formats_history_correctly(self): + """Test get_conversation_history formats Q&A history with correct structure.""" + user = create_mock_user() + session_user.set(user) + + mock_history = [ + { + "time": "2024-01-15 10:30:45", + "question": "What is AI?", + "context": "AI is artificial intelligence", + "answer": "AI stands for Artificial Intelligence", + } + ] + mock_cache = create_mock_cache_engine(mock_history) + + # Import the real module to patch safely + cache_module = importlib.import_module( + 
"cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert "Previous conversation:" in result + assert "[2024-01-15 10:30:45]" in result + assert "QUESTION: What is AI?" in result + assert "CONTEXT: AI is artificial intelligence" in result + assert "ANSWER: AI stands for Artificial Intelligence" in result + + @pytest.mark.asyncio + async def test_save_to_session_cache_saves_correctly(self): + """Test save_conversation_history calls add_qa with correct parameters.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="What is Python?", + context_summary="Python is a programming language", + answer="Python is a high-level programming language", + session_id="my_session", + ) + + assert result is True + mock_cache.add_qa.assert_called_once() + + call_kwargs = mock_cache.add_qa.call_args.kwargs + assert call_kwargs["question"] == "What is Python?" 
+ assert call_kwargs["context"] == "Python is a programming language" + assert call_kwargs["answer"] == "Python is a high-level programming language" + assert call_kwargs["session_id"] == "my_session" + + @pytest.mark.asyncio + async def test_save_to_session_cache_uses_default_session_when_none(self): + """Test save_conversation_history uses 'default_session' when session_id is None.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + session_id=None, + ) + + assert result is True + call_kwargs = mock_cache.add_qa.call_args.kwargs + assert call_kwargs["session_id"] == "default_session" + + @pytest.mark.asyncio + async def test_save_conversation_history_no_user_id(self): + """Test save_conversation_history returns False when user_id is None.""" + session_user.set(None) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_caching_disabled(self): + 
"""Test save_conversation_history returns False when caching is disabled.""" + user = create_mock_user() + session_user.set(user) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = False + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_cache_engine_none(self): + """Test save_conversation_history returns False when cache_engine is None.""" + user = create_mock_user() + session_user.set(user) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=None): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_cache_connection_error(self): + """Test save_conversation_history handles CacheConnectionError gracefully.""" + user = create_mock_user() + session_user.set(user) + + from cognee.infrastructure.databases.exceptions import CacheConnectionError + + mock_cache = create_mock_cache_engine([]) + mock_cache.add_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + 
) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_save_conversation_history_generic_exception(self): + """Test save_conversation_history handles generic exceptions gracefully.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + mock_cache.add_qa = AsyncMock(side_effect=ValueError("Unexpected error")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + save_conversation_history, + ) + + result = await save_conversation_history( + query="Test question", + context_summary="Test context", + answer="Test answer", + ) + + assert result is False + + @pytest.mark.asyncio + async def test_get_conversation_history_no_user_id(self): + """Test get_conversation_history returns empty string when user_id is None.""" + session_user.set(None) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + 
get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_caching_disabled(self): + """Test get_conversation_history returns empty string when caching is disabled.""" + user = create_mock_user() + session_user.set(user) + + with patch("cognee.modules.retrieval.utils.session_cache.CacheConfig") as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = False + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_default_session(self): + """Test get_conversation_history uses 'default_session' when session_id is None.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + await get_conversation_history(session_id=None) + + mock_cache.get_latest_qa.assert_called_once_with(str(user.id), "default_session") + + @pytest.mark.asyncio + async def test_get_conversation_history_cache_engine_none(self): + """Test get_conversation_history returns empty string when cache_engine is None.""" + user = create_mock_user() + session_user.set(user) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with 
patch.object(cache_module, "get_cache_engine", return_value=None): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_cache_connection_error(self): + """Test get_conversation_history handles CacheConnectionError gracefully.""" + user = create_mock_user() + session_user.set(user) + + from cognee.infrastructure.databases.exceptions import CacheConnectionError + + mock_cache = create_mock_cache_engine([]) + mock_cache.get_latest_qa = AsyncMock(side_effect=CacheConnectionError("Connection failed")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_generic_exception(self): + """Test get_conversation_history handles generic exceptions gracefully.""" + user = create_mock_user() + session_user.set(user) + + mock_cache = create_mock_cache_engine([]) + mock_cache.get_latest_qa = AsyncMock(side_effect=ValueError("Unexpected error")) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, 
"get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert result == "" + + @pytest.mark.asyncio + async def test_get_conversation_history_missing_keys(self): + """Test get_conversation_history handles missing keys in history entries.""" + user = create_mock_user() + session_user.set(user) + + mock_history = [ + { + "time": "2024-01-15 10:30:45", + "question": "What is AI?", + }, + { + "context": "AI is artificial intelligence", + "answer": "AI stands for Artificial Intelligence", + }, + {}, + ] + mock_cache = create_mock_cache_engine(mock_history) + + cache_module = importlib.import_module( + "cognee.infrastructure.databases.cache.get_cache_engine" + ) + + with patch.object(cache_module, "get_cache_engine", return_value=mock_cache): + with patch( + "cognee.modules.retrieval.utils.session_cache.CacheConfig" + ) as MockCacheConfig: + mock_config = MagicMock() + mock_config.caching = True + MockCacheConfig.return_value = mock_config + + from cognee.modules.retrieval.utils.session_cache import ( + get_conversation_history, + ) + + result = await get_conversation_history(session_id="test_session") + + assert "Previous conversation:" in result + assert "[2024-01-15 10:30:45]" in result + assert "QUESTION: What is AI?" 
in result + assert "Unknown time" in result + assert "CONTEXT: AI is artificial intelligence" in result + assert "ANSWER: AI stands for Artificial Intelligence" in result diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py new file mode 100644 index 000000000..6a9b07d38 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -0,0 +1,469 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID + +from cognee.modules.retrieval.graph_completion_context_extension_retriever import ( + GraphCompletionContextExtensionRetriever, +) +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge + + +@pytest.mark.asyncio +async def test_get_triplets_inherited(mock_edge): + """Test that get_triplets is inherited from parent class.""" + retriever = GraphCompletionContextExtensionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ): + triplets = await retriever.get_triplets("test query") + + assert len(triplets) == 1 + assert triplets[0] == mock_edge + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionContextExtensionRetriever initialization with defaults.""" + retriever = GraphCompletionContextExtensionRetriever() + + assert retriever.top_k == 5 + assert retriever.user_prompt_path == "graph_context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test GraphCompletionContextExtensionRetriever initialization with custom parameters.""" + retriever = GraphCompletionContextExtensionRetriever( + top_k=10, + 
user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + node_type=str, + node_name=["node1"], + save_interaction=True, + wide_search_top_k=200, + triplet_distance_penalty=5.0, + ) + + assert retriever.top_k == 10 + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.node_type is str + assert retriever.node_name == ["node1"] + assert retriever.save_interaction is True + assert retriever.wide_search_top_k == 200 + assert retriever.triplet_distance_penalty == 5.0 + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context_extension_rounds=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def 
test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", context=[mock_edge], context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_context_extension_rounds(mock_edge): + """Test get_completion with multiple context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + # Create a second edge for extension rounds + mock_edge2 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[mock_edge], [mock_edge2]], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + side_effect=["Resolved context", "Extended context"], # Different contexts + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Query for extension, then final answer + ), + patch( + 
"cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context_extension_rounds=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_context_extension_stops_early(mock_edge): + """Test get_completion stops early when no new triplets found.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + # When get_context returns same triplets, the loop should stop early + completion = await retriever.get_completion( + "test query", context=[mock_edge], context_extension_rounds=4 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + mock_user = 
MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "test query", session_id="test_session", context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction(mock_edge): + """Test get_completion with save_interaction enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + mock_graph_engine.add_edges = AsyncMock() + + 
retriever = GraphCompletionContextExtensionRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", context=[mock_edge], context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + mock_add_data.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + TestModel(answer="Test answer"), + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", response_model=TestModel, context_extension_rounds=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_edge): + """Test get_completion with session config but no user ID.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + side_effect=[ + "Extension query", + "Generated answer", + ], # Extension query, then final answer + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as 
mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query", context_extension_rounds=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_zero_extension_rounds(mock_edge): + """Test get_completion with zero context extension rounds.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionContextExtensionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_context_extension_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context_extension_rounds=0) + + assert isinstance(completion, list) + assert len(completion) == 1 diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py new file mode 100644 index 000000000..9f3147512 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -0,0 +1,688 @@ +import 
pytest +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID + +from cognee.modules.retrieval.graph_completion_cot_retriever import ( + GraphCompletionCotRetriever, + _as_answer_text, +) +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge +from cognee.infrastructure.llm.LLMGateway import LLMGateway + + +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge + + +@pytest.mark.asyncio +async def test_get_triplets_inherited(mock_edge): + """Test that get_triplets is inherited from parent class.""" + retriever = GraphCompletionCotRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ): + triplets = await retriever.get_triplets("test query") + + assert len(triplets) == 1 + assert triplets[0] == mock_edge + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test GraphCompletionCotRetriever initialization with custom parameters.""" + retriever = GraphCompletionCotRetriever( + top_k=10, + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + validation_user_prompt_path="custom_validation_user.txt", + validation_system_prompt_path="custom_validation_system.txt", + followup_system_prompt_path="custom_followup_system.txt", + followup_user_prompt_path="custom_followup_user.txt", + ) + + assert retriever.top_k == 10 + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.validation_user_prompt_path == "custom_validation_user.txt" + assert retriever.validation_system_prompt_path == "custom_validation_system.txt" + assert retriever.followup_system_prompt_path == "custom_followup_system.txt" + assert retriever.followup_user_prompt_path == "custom_followup_user.txt" + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionCotRetriever initialization with 
defaults.""" + retriever = GraphCompletionCotRetriever() + + assert retriever.validation_user_prompt_path == "cot_validation_user_prompt.txt" + assert retriever.validation_system_prompt_path == "cot_validation_system_prompt.txt" + assert retriever.followup_system_prompt_path == "cot_followup_system_prompt.txt" + assert retriever.followup_user_prompt_path == "cot_followup_user_prompt.txt" + + +@pytest.mark.asyncio +async def test_run_cot_completion_round_zero_with_context(mock_edge): + """Test _run_cot_completion round 0 with provided context.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + max_iter=1, + ) + + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 + + +@pytest.mark.asyncio +async def test_run_cot_completion_round_zero_without_context(mock_edge): + """Test _run_cot_completion round 0 without provided context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = 
AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=None, + max_iter=1, + ) + + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 + + +@pytest.mark.asyncio +async def test_run_cot_completion_multiple_rounds(mock_edge): + """Test _run_cot_completion with multiple rounds.""" + retriever = GraphCompletionCotRetriever() + + mock_edge2 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object( + retriever, + "get_context", + new_callable=AsyncMock, + side_effect=[[mock_edge], [mock_edge2]], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=[ + "validation_result", + "followup_question", + "validation_result2", + "followup_question2", + ], + ), + patch( + 
"cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + max_iter=2, + ) + + assert completion == "Generated answer" + assert context_text == "Resolved context" + assert len(triplets) >= 1 + + +@pytest.mark.asyncio +async def test_run_cot_completion_with_conversation_history(mock_edge): + """Test _run_cot_completion with conversation history.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ) as mock_generate, + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + conversation_history="Previous conversation", + max_iter=1, + ) + + assert completion == "Generated answer" + call_kwargs = mock_generate.call_args[1] + assert call_kwargs.get("conversation_history") == "Previous conversation" + + +@pytest.mark.asyncio +async def test_run_cot_completion_with_response_model(mock_edge): + """Test _run_cot_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + response_model=TestModel, + max_iter=1, + ) + + assert isinstance(completion, TestModel) + 
assert completion.answer == "Test answer" + + +@pytest.mark.asyncio +async def test_run_cot_completion_empty_conversation_history(mock_edge): + """Test _run_cot_completion with empty conversation history.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ) as mock_generate, + ): + completion, context_text, triplets = await retriever._run_cot_completion( + query="test query", + context=[mock_edge], + conversation_history="", + max_iter=1, + ) + + assert completion == "Generated answer" + # Verify conversation_history was passed as None when empty + call_kwargs = mock_generate.call_args[1] + assert call_kwargs.get("conversation_history") is None + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + 
"cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + mock_user = MagicMock() + mock_user.id 
= "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "test query", session_id="test_session", max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction(mock_edge): + """Test get_completion with save_interaction enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionCotRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + 
mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + # Pass context so save_interaction condition is met + completion = await retriever.get_completion("test query", context=[mock_edge], max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + mock_add_data.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = 
AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "test query", response_model=TestModel, max_iter=1 + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_edge): + """Test get_completion with session config but no user ID.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionCotRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + patch( + 
"cognee.modules.retrieval.graph_completion_cot_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query", max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_context(mock_edge): + """Test get_completion with save_interaction but no context provided.""" + retriever = GraphCompletionCotRetriever(save_interaction=True) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.generate_completion", + return_value="Generated answer", + ), + patch.object(retriever, "get_context", new_callable=AsyncMock, return_value=[mock_edge]), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever._as_answer_text", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.render_prompt", + return_value="Rendered prompt", + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.read_query_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + side_effect=["validation_result", "followup_question"], + ), + patch( + "cognee.modules.retrieval.graph_completion_cot_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=None, max_iter=1) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_as_answer_text_with_typeerror(): + """Test 
_as_answer_text handles TypeError when json.dumps fails.""" + non_serializable = {1, 2, 3} + + result = _as_answer_text(non_serializable) + + assert isinstance(result, str) + assert result == str(non_serializable) + + +@pytest.mark.asyncio +async def test_as_answer_text_with_string(): + """Test _as_answer_text with string input.""" + result = _as_answer_text("test string") + assert result == "test string" + + +@pytest.mark.asyncio +async def test_as_answer_text_with_dict(): + """Test _as_answer_text with dictionary input.""" + test_dict = {"key": "value", "number": 42} + result = _as_answer_text(test_dict) + assert isinstance(result, str) + assert "key" in result + assert "value" in result + + +@pytest.mark.asyncio +async def test_as_answer_text_with_basemodel(): + """Test _as_answer_text with Pydantic BaseModel input.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + test_model = TestModel(answer="test answer") + result = _as_answer_text(test_model) + + assert isinstance(result, str) + assert "[Structured Response]" in result + assert "test answer" in result diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py new file mode 100644 index 000000000..c22f30fd0 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py @@ -0,0 +1,648 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID + +from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge + + +@pytest.mark.asyncio +async def test_get_triplets_success(mock_edge): + """Test successful retrieval of triplets.""" + retriever = GraphCompletionRetriever(top_k=5) + + with patch( + 
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ) as mock_search: + triplets = await retriever.get_triplets("test query") + + assert len(triplets) == 1 + assert triplets[0] == mock_edge + mock_search.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_triplets_empty_results(): + """Test that empty list is returned when no triplets are found.""" + retriever = GraphCompletionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ): + triplets = await retriever.get_triplets("test query") + + assert triplets == [] + + +@pytest.mark.asyncio +async def test_get_triplets_top_k_parameter(): + """Test that top_k parameter is passed to brute_force_triplet_search.""" + retriever = GraphCompletionRetriever(top_k=10) + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ) as mock_search: + await retriever.get_triplets("test query") + + call_kwargs = mock_search.call_args[1] + assert call_kwargs["top_k"] == 10 + + +@pytest.mark.asyncio +async def test_get_context_success(mock_edge): + """Test successful retrieval of context.""" + retriever = GraphCompletionRetriever() + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + ): + context = await retriever.get_context("test query") + + assert isinstance(context, list) + assert len(context) == 1 + assert context[0] == mock_edge + + +@pytest.mark.asyncio +async def test_get_context_empty_results(): + """Test that empty list is returned when no context is found.""" + retriever = GraphCompletionRetriever() + + 
mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ), + ): + context = await retriever.get_context("test query") + + assert context == [] + + +@pytest.mark.asyncio +async def test_get_context_empty_graph(): + """Test that empty list is returned when graph is empty.""" + retriever = GraphCompletionRetriever() + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=True) + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ): + context = await retriever.get_context("test query") + + assert context == [] + + +@pytest.mark.asyncio +async def test_resolve_edges_to_text(mock_edge): + """Test resolve_edges_to_text method.""" + retriever = GraphCompletionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved text", + ) as mock_resolve: + result = await retriever.resolve_edges_to_text([mock_edge]) + + assert result == "Resolved text" + mock_resolve.assert_awaited_once_with([mock_edge]) + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test GraphCompletionRetriever initialization with defaults.""" + retriever = GraphCompletionRetriever() + + assert retriever.top_k == 5 + assert retriever.user_prompt_path == "graph_context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.node_type is None + assert retriever.node_name is None + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test GraphCompletionRetriever initialization with custom parameters.""" + retriever = GraphCompletionRetriever( + top_k=10, + 
user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + node_type=str, + node_name=["node1"], + save_interaction=True, + wide_search_top_k=200, + triplet_distance_penalty=5.0, + ) + + assert retriever.top_k == 10 + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.node_type is str + assert retriever.node_name == ["node1"] + assert retriever.save_interaction is True + assert retriever.wide_search_top_k == 200 + assert retriever.triplet_distance_penalty == 5.0 + + +@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test GraphCompletionRetriever initialization with None top_k.""" + retriever = GraphCompletionRetriever(top_k=None) + + assert retriever.top_k == 5 # None defaults to 5 + + +@pytest.mark.asyncio +async def test_convert_retrieved_objects_to_context(mock_edge): + """Test convert_retrieved_objects_to_context method.""" + retriever = GraphCompletionRetriever() + + with patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved text", + ) as mock_resolve: + result = await retriever.convert_retrieved_objects_to_context([mock_edge]) + + assert result == "Resolved text" + mock_resolve.assert_awaited_once_with([mock_edge]) + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_edge): + """Test get_completion retrieves context when not provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_edge): + """Test get_completion uses provided context.""" + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=[mock_edge]) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_edge): + """Test get_completion with session caching enabled.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.save_conversation_history", + ) as mock_save, + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + patch( + "cognee.modules.retrieval.graph_completion_retriever.session_user" + ) as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_edge): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_get_completion_empty_context(mock_edge): + """Test get_completion with empty context.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_save_qa(mock_edge): + """Test save_qa method.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionRetriever() + + 
mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=["uuid1", "uuid2"], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[mock_edge], + ) + + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_save_qa_no_triplet_ids(mock_edge): + """Test save_qa when triplets have no extractable IDs.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionRetriever() + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + return_value=None, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[mock_edge], + ) + + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_not_called() + + +@pytest.mark.asyncio +async def test_save_qa_empty_triplets(): + """Test save_qa with empty triplets list.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.add_edges = AsyncMock() + + retriever = GraphCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + 
return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + ): + await retriever.save_qa( + question="Test question", + answer="Test answer", + context="Test context", + triplets=[], + ) + + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_completion(mock_edge): + """Test get_completion with save_interaction but no completion.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever(save_interaction=True) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value=None, # No completion + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] is None + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_no_context(mock_edge): + """Test get_completion with save_interaction but no context provided.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever(save_interaction=True) + + with ( + patch( + 
"cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=None) + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_save_interaction_all_conditions_met(mock_edge): + """Test get_completion with save_interaction when all conditions are met (line 216).""" + mock_graph_engine = AsyncMock() + mock_graph_engine.is_empty = AsyncMock(return_value=False) + + retriever = GraphCompletionRetriever(save_interaction=True) + + mock_node1 = MagicMock() + mock_node2 = MagicMock() + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.brute_force_triplet_search", + return_value=[mock_edge], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.resolve_edges_to_text", + return_value="Resolved context", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.extract_uuid_from_node", + side_effect=[ + UUID("550e8400-e29b-41d4-a716-446655440000"), + 
UUID("550e8400-e29b-41d4-a716-446655440001"), + ], + ), + patch( + "cognee.modules.retrieval.graph_completion_retriever.add_data_points", + ) as mock_add_data, + patch( + "cognee.modules.retrieval.graph_completion_retriever.CacheConfig" + ) as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context=[mock_edge]) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_add_data.assert_awaited_once() diff --git a/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py new file mode 100644 index 000000000..e998d419d --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py @@ -0,0 +1,321 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.completion_retriever import CompletionRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError + + +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine + + +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval of context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "Steve Rodger"} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "Mike Broski"} + + mock_vector_engine.search.return_value = [mock_result1, mock_result2] + + retriever = CompletionRetriever(top_k=2) + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "Steve 
Rodger\nMike Broski" + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) + + +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") + + retriever = CompletionRetriever() + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty string is returned when no chunks are found.""" + mock_vector_engine.search.return_value = [] + + retriever = CompletionRetriever() + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "" + + +@pytest.mark.asyncio +async def test_get_context_top_k_limit(mock_vector_engine): + """Test that top_k parameter limits the number of results.""" + mock_results = [MagicMock() for _ in range(2)] + for i, result in enumerate(mock_results): + result.payload = {"text": f"Chunk {i}"} + + mock_vector_engine.search.return_value = mock_results + + retriever = CompletionRetriever(top_k=2) + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "Chunk 0\nChunk 1" + mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2) + + +@pytest.mark.asyncio +async def test_get_context_single_chunk(mock_vector_engine): + """Test get_context with single chunk result.""" + mock_result = MagicMock() + mock_result.payload = {"text": 
"Single chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "Single chunk text" + + +@pytest.mark.asyncio +async def test_get_completion_without_session(mock_vector_engine): + """Test get_completion without session caching.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_vector_engine): + """Test get_completion with provided context.""" + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated 
answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_vector_engine): + """Test get_completion with session caching enabled.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.completion_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.completion_retriever.save_conversation_history", + ) as mock_save, + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_vector_engine): + """Test get_completion with session config but no user ID.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + 
return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.completion_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_result = MagicMock() + mock_result.payload = {"text": "Chunk text"} + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.completion_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.completion_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test CompletionRetriever initialization with defaults.""" + retriever = CompletionRetriever() + + assert retriever.user_prompt_path == "context_for_question.txt" + assert retriever.system_prompt_path == 
"answer_simple_question.txt" + assert retriever.top_k == 1 + assert retriever.system_prompt is None + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test CompletionRetriever initialization with custom parameters.""" + retriever = CompletionRetriever( + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + top_k=10, + ) + + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_get_context_missing_text_key(mock_vector_engine): + """Test get_context handles missing text key in payload.""" + mock_result = MagicMock() + mock_result.payload = {"other_key": "value"} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = CompletionRetriever() + + with patch( + "cognee.modules.retrieval.completion_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(KeyError): + await retriever.get_context("test query") diff --git a/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py new file mode 100644 index 000000000..e552ac74a --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/summaries_retriever_test.py @@ -0,0 +1,193 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.summaries_retriever import SummariesRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError + + +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.search = AsyncMock() + return engine + + +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval 
of summary context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "S.R.", "made_from": "chunk1"} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "M.B.", "made_from": "chunk2"} + + mock_vector_engine.search.return_value = [mock_result1, mock_result2] + + retriever = SummariesRetriever(top_k=5) + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 2 + assert context[0]["text"] == "S.R." + assert context[1]["text"] == "M.B." + mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5) + + +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + mock_vector_engine.search.side_effect = CollectionNotFoundError("Collection not found") + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty list is returned when no summaries are found.""" + mock_vector_engine.search.return_value = [] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == [] + + +@pytest.mark.asyncio +async def test_get_context_top_k_limit(mock_vector_engine): + """Test that top_k parameter limits the number of results.""" + mock_results = [MagicMock() for _ in range(3)] + for i, result in enumerate(mock_results): + result.payload = {"text": f"Summary {i}"} + + 
mock_vector_engine.search.return_value = mock_results + + retriever = SummariesRetriever(top_k=3) + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 3 + mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3) + + +@pytest.mark.asyncio +async def test_get_completion_with_context(mock_vector_engine): + """Test get_completion returns provided context.""" + retriever = SummariesRetriever() + + provided_context = [{"text": "S.R."}, {"text": "M.B."}] + completion = await retriever.get_completion("test query", context=provided_context) + + assert completion == provided_context + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + mock_result = MagicMock() + mock_result.payload = {"text": "S.R."} + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query") + + assert len(completion) == 1 + assert completion[0]["text"] == "S.R." 
+ + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test SummariesRetriever initialization with defaults.""" + retriever = SummariesRetriever() + + assert retriever.top_k == 5 + + +@pytest.mark.asyncio +async def test_init_custom_top_k(): + """Test SummariesRetriever initialization with custom top_k.""" + retriever = SummariesRetriever(top_k=10) + + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_get_context_empty_payload(mock_vector_engine): + """Test get_context handles empty payload.""" + mock_result = MagicMock() + mock_result.payload = {} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert len(context) == 1 + assert context[0] == {} + + +@pytest.mark.asyncio +async def test_get_completion_with_session_id(mock_vector_engine): + """Test get_completion with session_id parameter.""" + mock_result = MagicMock() + mock_result.payload = {"text": "S.R."} + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query", session_id="test_session") + + assert len(completion) == 1 + assert completion[0]["text"] == "S.R." 
+ + +@pytest.mark.asyncio +async def test_get_completion_with_kwargs(mock_vector_engine): + """Test get_completion accepts additional kwargs.""" + mock_result = MagicMock() + mock_result.payload = {"text": "S.R."} + mock_vector_engine.search.return_value = [mock_result] + + retriever = SummariesRetriever() + + with patch( + "cognee.modules.retrieval.summaries_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + completion = await retriever.get_completion("test query", extra_param="value") + + assert len(completion) == 1 diff --git a/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py new file mode 100644 index 000000000..1d2f4c84d --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/temporal_retriever_test.py @@ -0,0 +1,705 @@ +from types import SimpleNamespace +import pytest +import os +from unittest.mock import AsyncMock, patch, MagicMock +from datetime import datetime + +from cognee.modules.retrieval.temporal_retriever import TemporalRetriever +from cognee.tasks.temporal_graph.models import QueryInterval, Timestamp +from cognee.infrastructure.llm import LLMGateway + + +# Test TemporalRetriever initialization defaults and overrides +def test_init_defaults_and_overrides(): + tr = TemporalRetriever() + assert tr.top_k == 5 + assert tr.user_prompt_path == "graph_context_for_question.txt" + assert tr.system_prompt_path == "answer_simple_question.txt" + assert tr.time_extraction_prompt_path == "extract_query_time.txt" + + tr2 = TemporalRetriever( + top_k=3, + user_prompt_path="u.txt", + system_prompt_path="s.txt", + time_extraction_prompt_path="t.txt", + ) + assert tr2.top_k == 3 + assert tr2.user_prompt_path == "u.txt" + assert tr2.system_prompt_path == "s.txt" + assert tr2.time_extraction_prompt_path == "t.txt" + + +# Test descriptions_to_string with basic and empty results +def test_descriptions_to_string_basic_and_empty(): + tr = TemporalRetriever() + + results = [ + 
{"description": " First "}, + {"nope": "no description"}, + {"description": "Second"}, + {"description": ""}, + {"description": " Third line "}, + ] + + s = tr.descriptions_to_string(results) + assert s == "First\n#####################\nSecond\n#####################\nThird line" + + assert tr.descriptions_to_string([]) == "" + + +# Test filter_top_k_events sorts and limits correctly +@pytest.mark.asyncio +async def test_filter_top_k_events_sorts_and_limits(): + tr = TemporalRetriever(top_k=2) + + relevant_events = [ + { + "events": [ + {"id": "e1", "description": "E1"}, + {"id": "e2", "description": "E2"}, + {"id": "e3", "description": "E3 - not in vector results"}, + ] + } + ] + + scored_results = [ + SimpleNamespace(payload={"id": "e2"}, score=0.10), + SimpleNamespace(payload={"id": "e1"}, score=0.20), + ] + + top = await tr.filter_top_k_events(relevant_events, scored_results) + + assert [e["id"] for e in top] == ["e2", "e1"] + assert all("score" in e for e in top) + assert top[0]["score"] == 0.10 + assert top[1]["score"] == 0.20 + + +# Test filter_top_k_events handles unknown ids as infinite scores +@pytest.mark.asyncio +async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k(): + tr = TemporalRetriever(top_k=2) + + relevant_events = [ + { + "events": [ + {"id": "known1", "description": "Known 1"}, + {"id": "unknown", "description": "Unknown"}, + {"id": "known2", "description": "Known 2"}, + ] + } + ] + + scored_results = [ + SimpleNamespace(payload={"id": "known2"}, score=0.05), + SimpleNamespace(payload={"id": "known1"}, score=0.50), + ] + + top = await tr.filter_top_k_events(relevant_events, scored_results) + assert [e["id"] for e in top] == ["known2", "known1"] + assert all(e["score"] != float("inf") for e in top) + + +# Test descriptions_to_string with unicode and newlines +def test_descriptions_to_string_unicode_and_newlines(): + tr = TemporalRetriever() + results = [ + {"description": "Line A\nwith newline"}, + {"description": 
"This is a description"}, + ] + s = tr.descriptions_to_string(results) + assert "Line A\nwith newline" in s + assert "This is a description" in s + assert s.count("#####################") == 1 + + +# Test filter_top_k_events when top_k is larger than available events +@pytest.mark.asyncio +async def test_filter_top_k_events_limits_when_top_k_exceeds_events(): + tr = TemporalRetriever(top_k=10) + relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}] + scored_results = [ + SimpleNamespace(payload={"id": "a"}, score=0.1), + SimpleNamespace(payload={"id": "b"}, score=0.2), + ] + out = await tr.filter_top_k_events(relevant_events, scored_results) + assert [e["id"] for e in out] == ["a", "b"] + + +# Test filter_top_k_events when scored_results is empty +@pytest.mark.asyncio +async def test_filter_top_k_events_handles_empty_scored_results(): + tr = TemporalRetriever(top_k=2) + relevant_events = [{"events": [{"id": "x"}, {"id": "y"}]}] + scored_results = [] + out = await tr.filter_top_k_events(relevant_events, scored_results) + assert [e["id"] for e in out] == ["x", "y"] + assert all(e["score"] == float("inf") for e in out) + + +# Test filter_top_k_events error handling for missing structure +@pytest.mark.asyncio +async def test_filter_top_k_events_error_handling(): + tr = TemporalRetriever(top_k=2) + with pytest.raises((KeyError, TypeError)): + await tr.filter_top_k_events([{}], []) + + +@pytest.fixture +def mock_graph_engine(): + """Create a mock graph engine.""" + engine = AsyncMock() + engine.collect_time_ids = AsyncMock() + engine.collect_events = AsyncMock() + return engine + + +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.embedding_engine = AsyncMock() + engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + engine.search = AsyncMock() + return engine + + +@pytest.mark.asyncio +async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine): + """Test 
get_context when time range is extracted from query.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1", "e2"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + {"id": "e2", "description": "Event 2"}, + ] + } + ] + + mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05) + mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10) + mock_vector_engine.search.return_value = [mock_result1, mock_result2] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened in 2024?") + + assert isinstance(context, str) + assert len(context) > 0 + assert "Event" in context + + +@pytest.mark.asyncio +async def test_get_context_fallback_to_triplets_no_time(mock_graph_engine): + """Test get_context falls back to triplets when no time is extracted.""" + retriever = TemporalRetriever() + + with ( + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}] + ) as mock_get_triplets, + patch.object( + retriever, "resolve_edges_to_text", return_value="triplet text" + ) as mock_resolve, + ): + + async def mock_extract_time(query): + return None, None + + retriever.extract_time_from_query = mock_extract_time + + context = await retriever.get_context("test query") + + assert context == "triplet text" + mock_get_triplets.assert_awaited_once_with("test query") + mock_resolve.assert_awaited_once() + + +@pytest.mark.asyncio +async def 
test_get_context_no_events_found(mock_graph_engine): + """Test get_context falls back to triplets when no events are found.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = [] + + with ( + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch.object( + retriever, "get_triplets", return_value=[{"s": "a", "p": "b", "o": "c"}] + ) as mock_get_triplets, + patch.object( + retriever, "resolve_edges_to_text", return_value="triplet text" + ) as mock_resolve, + ): + + async def mock_extract_time(query): + return "2024-01-01", "2024-12-31" + + retriever.extract_time_from_query = mock_extract_time + + context = await retriever.get_context("test query") + + assert context == "triplet text" + mock_get_triplets.assert_awaited_once_with("test query") + mock_resolve.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine): + """Test get_context with only time_from.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object(retriever, "extract_time_from_query", return_value=("2024-01-01", None)), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened after 2024?") + + assert isinstance(context, str) + assert "Event 1" in context + + +@pytest.mark.asyncio +async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine): + """Test get_context 
with only time_to.""" + retriever = TemporalRetriever(top_k=5) + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object(retriever, "extract_time_from_query", return_value=(None, "2024-12-31")), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + context = await retriever.get_context("What happened before 2024?") + + assert isinstance(context, str) + assert "Event 1" in context + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_graph_engine, mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + 
mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("What happened in 2024?") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(): + """Test get_completion uses provided context.""" + retriever = TemporalRetriever() + + with ( + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine): + """Test get_completion with session caching enabled.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + 
"cognee.modules.retrieval.temporal_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.save_conversation_history", + ) as mock_save, + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion( + "What happened in 2024?", session_id="test_session" + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_vector_engine): + """Test get_completion with session config but no user ID.""" + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + 
patch("cognee.modules.retrieval.temporal_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("What happened in 2024?") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_context_retrieved_but_empty(mock_graph_engine): + """Test get_completion when get_context returns empty string.""" + retriever = TemporalRetriever() + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + ) as mock_get_vector, + patch.object(retriever, "filter_top_k_events", return_value=[]), + ): + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + mock_get_vector.return_value = mock_vector_engine + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": "e1", "description": ""}, + ] + } + ] + + with pytest.raises((UnboundLocalError, NameError)): + await retriever.get_completion("test query") + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_graph_engine, mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + retriever = TemporalRetriever() + + mock_graph_engine.collect_time_ids.return_value = ["e1"] + mock_graph_engine.collect_events.return_value = [ + { + "events": [ + {"id": 
"e1", "description": "Event 1"}, + ] + } + ] + + mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05) + mock_vector_engine.search.return_value = [mock_result] + + with ( + patch.object( + retriever, "extract_time_from_query", return_value=("2024-01-01", "2024-12-31") + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.temporal_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.temporal_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion( + "What happened in 2024?", response_model=TestModel + ) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_extract_time_from_query_relative_path(): + """Test extract_time_from_query with relative prompt path.""" + retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt") + + mock_timestamp_from = Timestamp(year=2024, month=1, day=1) + mock_timestamp_to = Timestamp(year=2024, month=12, day=31) + mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False), + patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch( + "cognee.modules.retrieval.temporal_retriever.render_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_interval, + ), + ): + mock_datetime.now.return_value.strftime.return_value = 
"11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?") + + assert time_from == mock_timestamp_from + assert time_to == mock_timestamp_to + + +@pytest.mark.asyncio +async def test_extract_time_from_query_absolute_path(): + """Test extract_time_from_query with absolute prompt path.""" + retriever = TemporalRetriever( + time_extraction_prompt_path="/absolute/path/to/extract_query_time.txt" + ) + + mock_timestamp_from = Timestamp(year=2024, month=1, day=1) + mock_timestamp_to = Timestamp(year=2024, month=12, day=31) + mock_interval = QueryInterval(starts_at=mock_timestamp_from, ends_at=mock_timestamp_to) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=True), + patch( + "cognee.modules.retrieval.temporal_retriever.os.path.dirname", + return_value="/absolute/path/to", + ), + patch( + "cognee.modules.retrieval.temporal_retriever.os.path.basename", + return_value="extract_query_time.txt", + ), + patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch( + "cognee.modules.retrieval.temporal_retriever.render_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_interval, + ), + ): + mock_datetime.now.return_value.strftime.return_value = "11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened in 2024?") + + assert time_from == mock_timestamp_from + assert time_to == mock_timestamp_to + + +@pytest.mark.asyncio +async def test_extract_time_from_query_with_none_values(): + """Test extract_time_from_query when interval has None values.""" + retriever = TemporalRetriever(time_extraction_prompt_path="extract_query_time.txt") + + mock_interval = QueryInterval(starts_at=None, ends_at=None) + + with ( + patch("cognee.modules.retrieval.temporal_retriever.os.path.isabs", return_value=False), + 
patch("cognee.modules.retrieval.temporal_retriever.datetime") as mock_datetime, + patch( + "cognee.modules.retrieval.temporal_retriever.render_prompt", + return_value="System prompt", + ), + patch.object( + LLMGateway, + "acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_interval, + ), + ): + mock_datetime.now.return_value.strftime.return_value = "11-12-2024" + + time_from, time_to = await retriever.extract_time_from_query("What happened?") + + assert time_from is None + assert time_to is None diff --git a/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py new file mode 100644 index 000000000..b7cbe08d7 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py @@ -0,0 +1,817 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.utils.brute_force_triplet_search import ( + brute_force_triplet_search, + get_memory_fragment, + format_triplets, +) +from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph +from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError +from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError + + +class MockScoredResult: + """Mock class for vector search results.""" + + def __init__(self, id, score, payload=None): + self.id = id + self.score = score + self.payload = payload or {} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_empty_query(): + """Test that empty query raises ValueError.""" + with pytest.raises(ValueError, match="The query must be a non-empty string."): + await brute_force_triplet_search(query="") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_none_query(): + """Test that None query raises ValueError.""" + with pytest.raises(ValueError, match="The query must be a non-empty string."): + await 
brute_force_triplet_search(query=None) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_negative_top_k(): + """Test that negative top_k raises ValueError.""" + with pytest.raises(ValueError, match="top_k must be a positive integer."): + await brute_force_triplet_search(query="test query", top_k=-1) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_zero_top_k(): + """Test that zero top_k raises ValueError.""" + with pytest.raises(ValueError, match="top_k must be a positive integer."): + await brute_force_triplet_search(query="test query", top_k=0) + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_limit_global_search(): + """Test that wide_search_limit is applied for global search (node_name=None).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search( + query="test", + node_name=None, # Global search + wide_search_top_k=75, + ) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] == 75 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_limit_filtered_search(): + """Test that wide_search_limit is None for filtered search (node_name provided).""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search( + query="test", + node_name=["Node1"], + 
wide_search_top_k=50, + ) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_wide_search_default(): + """Test that wide_search_top_k defaults to 100.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", node_name=None) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["limit"] == 100 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_default_collections(): + """Test that default collections are used when none provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test") + + expected_collections = [ + "Entity_name", + "TextSummary_text", + "EntityType_name", + "DocumentChunk_text", + "EdgeType_relationship_name", + ] + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert call_collections == expected_collections + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_custom_collections(): + """Test that custom collections are used when provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = 
AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + custom_collections = ["CustomCol1", "CustomCol2"] + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", collections=custom_collections) + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert set(call_collections) == set(custom_collections) | {"EdgeType_relationship_name"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_always_includes_edge_collection(): + """Test that EdgeType_relationship_name is always searched even when not in collections.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + collections_without_edge = ["Entity_name", "TextSummary_text"] + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query="test", collections=collections_without_edge) + + call_collections = [ + call[1]["collection_name"] for call in mock_vector_engine.search.call_args_list + ] + assert "EdgeType_relationship_name" in call_collections + assert set(call_collections) == set(collections_without_edge) | { + "EdgeType_relationship_name" + } + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_all_collections_empty(): + """Test that empty list is returned when all collections return no results.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( 
+ "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + results = await brute_force_triplet_search(query="test") + assert results == [] + + +# Tests for query embedding + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_embeds_query(): + """Test that query is embedded before searching.""" + query_text = "test query" + expected_vector = [0.1, 0.2, 0.3] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[expected_vector]) + mock_vector_engine.search = AsyncMock(return_value=[]) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ): + await brute_force_triplet_search(query=query_text) + + mock_vector_engine.embedding_engine.embed_text.assert_called_once_with([query_text]) + + for call in mock_vector_engine.search.call_args_list: + assert call[1]["query_vector"] == expected_vector + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_extracts_node_ids_global_search(): + """Test that node IDs are extracted from search results for global search.""" + scored_results = [ + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + MockScoredResult("node3", 0.92), + ] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=scored_results) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + 
"cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_reuses_provided_fragment(): + """Test that provided memory fragment is reused instead of creating new one.""" + provided_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment" + ) as mock_get_fragment, + ): + await brute_force_triplet_search( + query="test", + memory_fragment=provided_fragment, + node_name=["node"], + ) + + mock_get_fragment.assert_not_called() + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_creates_fragment_when_not_provided(): + """Test that memory fragment is created when not provided.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + 
calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + await brute_force_triplet_search(query="test", node_name=["node"]) + + mock_get_fragment.assert_called_once() + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_passes_top_k_to_importance_calculation(): + """Test that custom top_k is passed to importance calculation.""" + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(return_value=[MockScoredResult("n1", 0.95)]) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ), + ): + custom_top_k = 15 + await brute_force_triplet_search(query="test", top_k=custom_top_k, node_name=["n"]) + + mock_fragment.calculate_top_triplet_importances.assert_called_once_with(k=custom_top_k) + + +@pytest.mark.asyncio +async def test_get_memory_fragment_returns_empty_graph_on_entity_not_found(): + """Test that get_memory_fragment returns empty graph when entity not found (line 85).""" + mock_graph_engine = AsyncMock() + + # Create a mock fragment that will raise EntityNotFoundError when project_graph_from_db is called + mock_fragment = MagicMock(spec=CogneeGraph) + 
mock_fragment.project_graph_from_db = AsyncMock( + side_effect=EntityNotFoundError("Entity not found") + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.CogneeGraph", + return_value=mock_fragment, + ), + ): + result = await get_memory_fragment() + + # Fragment should be returned even though EntityNotFoundError was raised (pass statement on line 85) + assert result == mock_fragment + mock_fragment.project_graph_from_db.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_memory_fragment_returns_empty_graph_on_error(): + """Test that get_memory_fragment returns empty graph on generic error.""" + mock_graph_engine = AsyncMock() + mock_graph_engine.project_graph_from_db = AsyncMock(side_effect=Exception("Generic error")) + + with patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_graph_engine", + return_value=mock_graph_engine, + ): + fragment = await get_memory_fragment() + + assert isinstance(fragment, CogneeGraph) + assert len(fragment.nodes) == 0 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_deduplicates_node_ids(): + """Test that duplicate node IDs across collections are deduplicated.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + ] + elif collection_name == "TextSummary_text": + return [ + MockScoredResult("node1", 0.90), + MockScoredResult("node3", 0.92), + ] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + 
map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2", "node3"} + assert len(call_kwargs["relevant_ids_to_filter"]) == 3 + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_excludes_edge_collection(): + """Test that EdgeType_relationship_name collection is excluded from ID extraction.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [MockScoredResult("node1", 0.95)] + elif collection_name == "EdgeType_relationship_name": + return [MockScoredResult("edge1", 0.88)] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search( + 
query="test", + node_name=None, + collections=["Entity_name", "EdgeType_relationship_name"], + ) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert call_kwargs["relevant_ids_to_filter"] == ["node1"] + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_skips_nodes_without_ids(): + """Test that nodes without ID attribute are skipped.""" + + class ScoredResultNoId: + """Mock result without id attribute.""" + + def __init__(self, score): + self.score = score + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [ + MockScoredResult("node1", 0.95), + ScoredResultNoId(0.90), + MockScoredResult("node2", 0.87), + ] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_handles_tuple_results(): + """Test that both list and tuple results are handled correctly.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return 
( + MockScoredResult("node1", 0.95), + MockScoredResult("node2", 0.87), + ) + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_mixed_empty_collections(): + """Test ID extraction with mixed empty and non-empty collections.""" + + def search_side_effect(*args, **kwargs): + collection_name = kwargs.get("collection_name") + if collection_name == "Entity_name": + return [MockScoredResult("node1", 0.95)] + elif collection_name == "TextSummary_text": + return [] + elif collection_name == "EntityType_name": + return [MockScoredResult("node2", 0.92)] + else: + return [] + + mock_vector_engine = AsyncMock() + mock_vector_engine.embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + mock_vector_engine.search = AsyncMock(side_effect=search_side_effect) + + mock_fragment = AsyncMock( + map_vector_distances_to_graph_nodes=AsyncMock(), + map_vector_distances_to_graph_edges=AsyncMock(), + calculate_top_triplet_importances=AsyncMock(return_value=[]), 
+ ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment_fn, + ): + await brute_force_triplet_search(query="test", node_name=None) + + call_kwargs = mock_get_fragment_fn.call_args[1] + assert set(call_kwargs["relevant_ids_to_filter"]) == {"node1", "node2"} + + +def test_format_triplets(): + """Test format_triplets function.""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = MagicMock() + + mock_node1.attributes = {"name": "Node1", "type": "Entity", "id": "n1"} + mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": "n2"} + mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": "connects"} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert isinstance(result, str) + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + assert "connects" in result + + +def test_format_triplets_with_none_values(): + """Test format_triplets filters out None values.""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = MagicMock() + + mock_node1.attributes = {"name": "Node1", "type": None, "id": "n1"} + mock_node2.attributes = {"name": "Node2", "type": "Entity", "id": None} + mock_edge.attributes = {"relationship_name": "relates_to", "edge_text": None} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + assert "None" not in result or result.count("None") == 0 + + +def test_format_triplets_with_nested_dict(): + """Test format_triplets handles nested dict attributes (lines 23-35).""" + mock_edge = MagicMock() + mock_node1 = MagicMock() + mock_node2 = 
MagicMock() + + mock_node1.attributes = {"name": "Node1", "metadata": {"type": "Entity", "id": "n1"}} + mock_node2.attributes = {"name": "Node2", "metadata": {"type": "Entity", "id": "n2"}} + mock_edge.attributes = {"relationship_name": "relates_to"} + + mock_edge.node1 = mock_node1 + mock_edge.node2 = mock_node2 + + result = format_triplets([mock_edge]) + + assert isinstance(result, str) + assert "Node1" in result + assert "Node2" in result + assert "relates_to" in result + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_vector_engine_init_error(): + """Test brute_force_triplet_search handles vector engine initialization error (lines 145-147).""" + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine" + ) as mock_get_vector_engine, + ): + mock_get_vector_engine.side_effect = Exception("Initialization error") + + with pytest.raises(RuntimeError, match="Initialization error"): + await brute_force_triplet_search(query="test query") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_collection_not_found_error(): + """Test brute_force_triplet_search handles CollectionNotFoundError in search (lines 156-157).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_vector_engine.search = AsyncMock( + side_effect=[ + CollectionNotFoundError("Collection not found"), + [], + [], + ] + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=CogneeGraph(), + ), + ): + result = await brute_force_triplet_search( + query="test query", collections=["missing_collection", "existing_collection"] + ) + + assert result == [] + + +@pytest.mark.asyncio 
+async def test_brute_force_triplet_search_generic_exception(): + """Test brute_force_triplet_search handles generic exceptions (lines 209-217).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_vector_engine.search = AsyncMock(side_effect=Exception("Generic error")) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + ): + with pytest.raises(Exception, match="Generic error"): + await brute_force_triplet_search(query="test query") + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_with_node_name_sets_relevant_ids_to_none(): + """Test brute_force_triplet_search sets relevant_ids_to_filter to None when node_name is provided (line 191).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"}) + mock_vector_engine.search = AsyncMock(return_value=[mock_result]) + + mock_fragment = AsyncMock() + mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock() + mock_fragment.map_vector_distances_to_graph_edges = AsyncMock() + mock_fragment.calculate_top_triplet_importances = AsyncMock(return_value=[]) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ) as mock_get_fragment, + ): + await brute_force_triplet_search(query="test query", node_name=["Node1"]) + + assert mock_get_fragment.called + call_kwargs = mock_get_fragment.call_args.kwargs if 
mock_get_fragment.call_args else {} + assert call_kwargs.get("relevant_ids_to_filter") is None + + +@pytest.mark.asyncio +async def test_brute_force_triplet_search_collection_not_found_at_top_level(): + """Test brute_force_triplet_search handles CollectionNotFoundError at top level (line 210).""" + mock_vector_engine = AsyncMock() + mock_embedding_engine = AsyncMock() + mock_vector_engine.embedding_engine = mock_embedding_engine + mock_embedding_engine.embed_text = AsyncMock(return_value=[[0.1, 0.2, 0.3]]) + + mock_result = MockScoredResult(id="node1", score=0.8, payload={"id": "node1"}) + mock_vector_engine.search = AsyncMock(return_value=[mock_result]) + + mock_fragment = AsyncMock() + mock_fragment.map_vector_distances_to_graph_nodes = AsyncMock() + mock_fragment.map_vector_distances_to_graph_edges = AsyncMock() + mock_fragment.calculate_top_triplet_importances = AsyncMock( + side_effect=CollectionNotFoundError("Collection not found") + ) + + with ( + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.utils.brute_force_triplet_search.get_memory_fragment", + return_value=mock_fragment, + ), + ): + result = await brute_force_triplet_search(query="test query") + + assert result == [] diff --git a/cognee/tests/unit/modules/retrieval/test_completion.py b/cognee/tests/unit/modules/retrieval/test_completion.py new file mode 100644 index 000000000..9a836c2cc --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_completion.py @@ -0,0 +1,343 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from typing import Type + + +class TestGenerateCompletion: + @pytest.mark.asyncio + async def test_generate_completion_with_system_prompt(self): + """Test generate_completion with provided system_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + 
return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + system_prompt="Custom system prompt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt="Custom system prompt", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_without_system_prompt(self): + """Test generate_completion reads system_prompt from file when not provided.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt="System prompt from file", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_with_conversation_history(self): + """Test generate_completion includes conversation_history in system_prompt.""" + mock_llm_response = "Generated answer" + + with ( + 
patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + conversation_history="Previous conversation:\nQ: What is ML?\nA: ML is machine learning", + ) + + assert result == mock_llm_response + expected_system_prompt = ( + "Previous conversation:\nQ: What is ML?\nA: ML is machine learning" + + "\nTASK:" + + "System prompt from file" + ) + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt=expected_system_prompt, + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_with_conversation_history_and_custom_system_prompt(self): + """Test generate_completion includes conversation_history with custom system_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + system_prompt="Custom system prompt", + conversation_history="Previous conversation:\nQ: What is ML?\nA: 
ML is machine learning", + ) + + assert result == mock_llm_response + expected_system_prompt = ( + "Previous conversation:\nQ: What is ML?\nA: ML is machine learning" + + "\nTASK:" + + "Custom system prompt" + ) + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt=expected_system_prompt, + response_model=str, + ) + + @pytest.mark.asyncio + async def test_generate_completion_with_response_model(self): + """Test generate_completion with custom response_model.""" + mock_response_model = MagicMock() + mock_llm_response = {"answer": "Generated answer"} + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ), + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + result = await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + response_model=mock_response_model, + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="User prompt text", + system_prompt="System prompt from file", + response_model=mock_response_model, + ) + + @pytest.mark.asyncio + async def test_generate_completion_render_prompt_args(self): + """Test generate_completion passes correct args to render_prompt.""" + mock_llm_response = "Generated answer" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.render_prompt", + return_value="User prompt text", + ) as mock_render, + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + 
"cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ), + ): + from cognee.modules.retrieval.utils.completion import generate_completion + + await generate_completion( + query="What is AI?", + context="AI is artificial intelligence", + user_prompt_path="user_prompt.txt", + system_prompt_path="system_prompt.txt", + ) + + mock_render.assert_called_once_with( + "user_prompt.txt", + {"question": "What is AI?", "context": "AI is artificial intelligence"}, + ) + + +class TestSummarizeText: + @pytest.mark.asyncio + async def test_summarize_text_with_system_prompt(self): + """Test summarize_text with provided system_prompt.""" + mock_llm_response = "Summary text" + + with patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm: + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text( + text="Long text to summarize", + system_prompt_path="summarize_search_results.txt", + system_prompt="Custom summary prompt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="Custom summary prompt", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_summarize_text_without_system_prompt(self): + """Test summarize_text reads system_prompt from file when not provided.""" + mock_llm_response = "Summary text" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="System prompt from file", + ), + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text( + text="Long text to summarize", + 
system_prompt_path="summarize_search_results.txt", + ) + + assert result == mock_llm_response + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="System prompt from file", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_summarize_text_default_prompt_path(self): + """Test summarize_text uses default prompt path when not provided.""" + mock_llm_response = "Summary text" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="Default system prompt", + ) as mock_read, + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text(text="Long text to summarize") + + assert result == mock_llm_response + mock_read.assert_called_once_with("summarize_search_results.txt") + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + system_prompt="Default system prompt", + response_model=str, + ) + + @pytest.mark.asyncio + async def test_summarize_text_custom_prompt_path(self): + """Test summarize_text uses custom prompt path when provided.""" + mock_llm_response = "Summary text" + + with ( + patch( + "cognee.modules.retrieval.utils.completion.read_query_prompt", + return_value="Custom system prompt", + ) as mock_read, + patch( + "cognee.modules.retrieval.utils.completion.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_llm_response, + ) as mock_llm, + ): + from cognee.modules.retrieval.utils.completion import summarize_text + + result = await summarize_text( + text="Long text to summarize", + system_prompt_path="custom_prompt.txt", + ) + + assert result == mock_llm_response + mock_read.assert_called_once_with("custom_prompt.txt") + mock_llm.assert_awaited_once_with( + text_input="Long text to summarize", + 
system_prompt="Custom system prompt", + response_model=str, + ) diff --git a/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py b/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py new file mode 100644 index 000000000..2af10da5e --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py @@ -0,0 +1,157 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.graph_summary_completion_retriever import ( + GraphSummaryCompletionRetriever, +) +from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge + + +@pytest.fixture +def mock_edge(): + """Create a mock edge.""" + edge = MagicMock(spec=Edge) + return edge + + +class TestGraphSummaryCompletionRetriever: + @pytest.mark.asyncio + async def test_init_defaults(self): + """Test GraphSummaryCompletionRetriever initialization with defaults.""" + retriever = GraphSummaryCompletionRetriever() + + assert retriever.summarize_prompt_path == "summarize_search_results.txt" + assert retriever.user_prompt_path == "graph_context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.top_k == 5 + assert retriever.save_interaction is False + + @pytest.mark.asyncio + async def test_init_custom_params(self): + """Test GraphSummaryCompletionRetriever initialization with custom parameters.""" + retriever = GraphSummaryCompletionRetriever( + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + summarize_prompt_path="custom_summarize.txt", + system_prompt="Custom system prompt", + top_k=10, + save_interaction=True, + wide_search_top_k=200, + triplet_distance_penalty=2.5, + ) + + assert retriever.summarize_prompt_path == "custom_summarize.txt" + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.top_k == 10 + assert 
retriever.save_interaction is True + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_calls_super_and_summarizes(self, mock_edge): + """Test resolve_edges_to_text calls super method and then summarizes.""" + retriever = GraphSummaryCompletionRetriever( + summarize_prompt_path="custom_summarize.txt", + system_prompt="Custom system prompt", + ) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="Resolved edges text", + ) as mock_super_resolve, + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Summarized text", + ) as mock_summarize, + ): + result = await retriever.resolve_edges_to_text([mock_edge]) + + assert result == "Summarized text" + mock_super_resolve.assert_awaited_once_with([mock_edge]) + mock_summarize.assert_awaited_once_with( + "Resolved edges text", + "custom_summarize.txt", + "Custom system prompt", + ) + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_with_default_system_prompt(self, mock_edge): + """Test resolve_edges_to_text uses None for system_prompt when not provided.""" + retriever = GraphSummaryCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="Resolved edges text", + ), + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Summarized text", + ) as mock_summarize, + ): + await retriever.resolve_edges_to_text([mock_edge]) + + mock_summarize.assert_awaited_once_with( + "Resolved edges text", + "summarize_search_results.txt", + None, + ) + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_with_empty_edges(self): + """Test resolve_edges_to_text handles empty edges list.""" + retriever = 
GraphSummaryCompletionRetriever() + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="", + ), + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Empty summary", + ) as mock_summarize, + ): + result = await retriever.resolve_edges_to_text([]) + + assert result == "Empty summary" + mock_summarize.assert_awaited_once_with( + "", + "summarize_search_results.txt", + None, + ) + + @pytest.mark.asyncio + async def test_resolve_edges_to_text_with_multiple_edges(self, mock_edge): + """Test resolve_edges_to_text handles multiple edges.""" + retriever = GraphSummaryCompletionRetriever() + + mock_edge2 = MagicMock(spec=Edge) + mock_edge3 = MagicMock(spec=Edge) + + with ( + patch( + "cognee.modules.retrieval.graph_completion_retriever.GraphCompletionRetriever.resolve_edges_to_text", + new_callable=AsyncMock, + return_value="Multiple edges resolved text", + ), + patch( + "cognee.modules.retrieval.graph_summary_completion_retriever.summarize_text", + new_callable=AsyncMock, + return_value="Multiple edges summarized", + ) as mock_summarize, + ): + result = await retriever.resolve_edges_to_text([mock_edge, mock_edge2, mock_edge3]) + + assert result == "Multiple edges summarized" + mock_summarize.assert_awaited_once_with( + "Multiple edges resolved text", + "summarize_search_results.txt", + None, + ) diff --git a/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py b/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py new file mode 100644 index 000000000..a1e746bb9 --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py @@ -0,0 +1,312 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from uuid import UUID, NAMESPACE_OID, uuid5 + +from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback +from 
cognee.modules.retrieval.utils.models import UserFeedbackEvaluation, UserFeedbackSentiment +from cognee.modules.engine.models import NodeSet + + +@pytest.fixture +def mock_feedback_evaluation(): + """Create a mock feedback evaluation.""" + evaluation = MagicMock(spec=UserFeedbackEvaluation) + evaluation.evaluation = MagicMock() + evaluation.evaluation.value = "positive" + evaluation.score = 4.5 + return evaluation + + +@pytest.fixture +def mock_graph_engine(): + """Create a mock graph engine.""" + engine = AsyncMock() + engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + engine.add_edges = AsyncMock() + engine.apply_feedback_weight = AsyncMock() + return engine + + +class TestUserQAFeedback: + @pytest.mark.asyncio + async def test_init_default(self): + """Test UserQAFeedback initialization with default last_k.""" + retriever = UserQAFeedback() + assert retriever.last_k == 1 + + @pytest.mark.asyncio + async def test_init_custom_last_k(self): + """Test UserQAFeedback initialization with custom last_k.""" + retriever = UserQAFeedback(last_k=5) + assert retriever.last_k == 5 + + @pytest.mark.asyncio + async def test_add_feedback_success_with_relationships( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback successfully creates feedback with relationships.""" + interaction_id_1 = str(UUID("550e8400-e29b-41d4-a716-446655440000")) + interaction_id_2 = str(UUID("550e8400-e29b-41d4-a716-446655440001")) + mock_graph_engine.get_last_user_interaction_ids = AsyncMock( + return_value=[interaction_id_1, interaction_id_2] + ) + + feedback_text = "This answer was helpful" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + 
new_callable=AsyncMock, + ) as mock_add_data, + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ) as mock_index_edges, + ): + retriever = UserQAFeedback(last_k=2) + result = await retriever.add_feedback(feedback_text) + + assert result == [feedback_text] + mock_add_data.assert_awaited_once() + mock_graph_engine.add_edges.assert_awaited_once() + mock_index_edges.assert_awaited_once() + mock_graph_engine.apply_feedback_weight.assert_awaited_once() + + # Verify add_edges was called with correct relationships + call_args = mock_graph_engine.add_edges.call_args[0][0] + assert len(call_args) == 2 + assert call_args[0][0] == uuid5(NAMESPACE_OID, name=feedback_text) + assert call_args[0][1] == UUID(interaction_id_1) + assert call_args[0][2] == "gives_feedback_to" + assert call_args[0][3]["relationship_name"] == "gives_feedback_to" + assert call_args[0][3]["ontology_valid"] is False + + # Verify apply_feedback_weight was called with correct node IDs + weight_call_args = mock_graph_engine.apply_feedback_weight.call_args[1]["node_ids"] + assert len(weight_call_args) == 2 + assert interaction_id_1 in weight_call_args + assert interaction_id_2 in weight_call_args + + @pytest.mark.asyncio + async def test_add_feedback_success_no_relationships( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback successfully creates feedback without relationships.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "This answer was helpful" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ) as mock_add_data, + patch( + 
"cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ) as mock_index_edges, + ): + retriever = UserQAFeedback(last_k=1) + result = await retriever.add_feedback(feedback_text) + + assert result == [feedback_text] + mock_add_data.assert_awaited_once() + # Should not call add_edges or index_graph_edges when no relationships + mock_graph_engine.add_edges.assert_not_awaited() + mock_index_edges.assert_not_awaited() + mock_graph_engine.apply_feedback_weight.assert_not_awaited() + + @pytest.mark.asyncio + async def test_add_feedback_creates_correct_feedback_node( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback creates CogneeUserFeedback with correct attributes.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "This was a negative experience" + mock_feedback_evaluation.evaluation.value = "negative" + mock_feedback_evaluation.score = -3.0 + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ) as mock_add_data, + ): + retriever = UserQAFeedback() + await retriever.add_feedback(feedback_text) + + # Verify add_data_points was called with correct CogneeUserFeedback + call_args = mock_add_data.call_args[1]["data_points"] + assert len(call_args) == 1 + feedback_node = call_args[0] + assert feedback_node.id == uuid5(NAMESPACE_OID, name=feedback_text) + assert feedback_node.feedback == feedback_text + assert feedback_node.sentiment == "negative" + assert feedback_node.score == -3.0 + assert isinstance(feedback_node.belongs_to_set, NodeSet) + assert feedback_node.belongs_to_set.name == "UserQAFeedbacks" + + @pytest.mark.asyncio + 
async def test_add_feedback_calls_llm_with_correct_prompt( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback calls LLM with correct sentiment analysis prompt.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "Great answer!" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ) as mock_llm, + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback() + await retriever.add_feedback(feedback_text) + + mock_llm.assert_awaited_once() + call_kwargs = mock_llm.call_args[1] + assert call_kwargs["text_input"] == feedback_text + assert "sentiment analysis assistant" in call_kwargs["system_prompt"] + assert call_kwargs["response_model"] == UserFeedbackEvaluation + + @pytest.mark.asyncio + async def test_add_feedback_uses_last_k_parameter( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback uses last_k parameter when getting interaction IDs.""" + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[]) + + feedback_text = "Test feedback" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback(last_k=5) + await retriever.add_feedback(feedback_text) + + mock_graph_engine.get_last_user_interaction_ids.assert_awaited_once_with(limit=5) + + @pytest.mark.asyncio + async def 
test_add_feedback_with_single_interaction( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback with single interaction ID.""" + interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000")) + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id]) + + feedback_text = "Test feedback" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback() + result = await retriever.add_feedback(feedback_text) + + assert result == [feedback_text] + # Should create relationship for the interaction + call_args = mock_graph_engine.add_edges.call_args[0][0] + assert len(call_args) == 1 + assert call_args[0][1] == UUID(interaction_id) + + @pytest.mark.asyncio + async def test_add_feedback_applies_weight_correctly( + self, mock_feedback_evaluation, mock_graph_engine + ): + """Test add_feedback applies feedback weight with correct score.""" + interaction_id = str(UUID("550e8400-e29b-41d4-a716-446655440000")) + mock_graph_engine.get_last_user_interaction_ids = AsyncMock(return_value=[interaction_id]) + mock_feedback_evaluation.score = 4.5 + + feedback_text = "Positive feedback" + + with ( + patch( + "cognee.modules.retrieval.user_qa_feedback.LLMGateway.acreate_structured_output", + new_callable=AsyncMock, + return_value=mock_feedback_evaluation, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.get_graph_engine", + return_value=mock_graph_engine, + ), + patch( + "cognee.modules.retrieval.user_qa_feedback.add_data_points", + new_callable=AsyncMock, + 
), + patch( + "cognee.modules.retrieval.user_qa_feedback.index_graph_edges", + new_callable=AsyncMock, + ), + ): + retriever = UserQAFeedback() + await retriever.add_feedback(feedback_text) + + mock_graph_engine.apply_feedback_weight.assert_awaited_once_with( + node_ids=[interaction_id], weight=4.5 + ) diff --git a/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py new file mode 100644 index 000000000..83612c7aa --- /dev/null +++ b/cognee/tests/unit/modules/retrieval/triplet_retriever_test.py @@ -0,0 +1,329 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from cognee.modules.retrieval.triplet_retriever import TripletRetriever +from cognee.modules.retrieval.exceptions.exceptions import NoDataError +from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError + + +@pytest.fixture +def mock_vector_engine(): + """Create a mock vector engine.""" + engine = AsyncMock() + engine.has_collection = AsyncMock(return_value=True) + engine.search = AsyncMock() + return engine + + +@pytest.mark.asyncio +async def test_get_context_success(mock_vector_engine): + """Test successful retrieval of triplet context.""" + mock_result1 = MagicMock() + mock_result1.payload = {"text": "Alice knows Bob"} + mock_result2 = MagicMock() + mock_result2.payload = {"text": "Bob works at Tech Corp"} + + mock_vector_engine.search.return_value = [mock_result1, mock_result2] + + retriever = TripletRetriever(top_k=5) + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "Alice knows Bob\nBob works at Tech Corp" + mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5) + + +@pytest.mark.asyncio +async def test_get_context_no_collection(mock_vector_engine): + """Test that NoDataError is raised when 
Triplet_text collection doesn't exist.""" + mock_vector_engine.has_collection.return_value = False + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="create_triplet_embeddings"): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_empty_results(mock_vector_engine): + """Test that empty string is returned when no triplets are found.""" + mock_vector_engine.search.return_value = [] + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "" + + +@pytest.mark.asyncio +async def test_get_context_collection_not_found_error(mock_vector_engine): + """Test that CollectionNotFoundError is converted to NoDataError.""" + mock_vector_engine.has_collection.side_effect = CollectionNotFoundError("Collection not found") + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(NoDataError, match="No data found"): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_empty_payload_text(mock_vector_engine): + """Test get_context handles missing text in payload.""" + mock_result = MagicMock() + mock_result.payload = {} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + with pytest.raises(KeyError): + await retriever.get_context("test query") + + +@pytest.mark.asyncio +async def test_get_context_single_triplet(mock_vector_engine): + """Test get_context with single triplet result.""" + mock_result = 
MagicMock() + mock_result.payload = {"text": "Single triplet"} + + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ): + context = await retriever.get_context("test query") + + assert context == "Single triplet" + + +@pytest.mark.asyncio +async def test_init_defaults(): + """Test TripletRetriever initialization with defaults.""" + retriever = TripletRetriever() + + assert retriever.user_prompt_path == "context_for_question.txt" + assert retriever.system_prompt_path == "answer_simple_question.txt" + assert retriever.top_k == 5 # Default is 5 + assert retriever.system_prompt is None + + +@pytest.mark.asyncio +async def test_init_custom_params(): + """Test TripletRetriever initialization with custom parameters.""" + retriever = TripletRetriever( + user_prompt_path="custom_user.txt", + system_prompt_path="custom_system.txt", + system_prompt="Custom prompt", + top_k=10, + ) + + assert retriever.user_prompt_path == "custom_user.txt" + assert retriever.system_prompt_path == "custom_system.txt" + assert retriever.system_prompt == "Custom prompt" + assert retriever.top_k == 10 + + +@pytest.mark.asyncio +async def test_get_completion_without_context(mock_vector_engine): + """Test get_completion retrieves context when not provided.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + 
mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_provided_context(mock_vector_engine): + """Test get_completion uses provided context.""" + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", context="Provided context") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + + +@pytest.mark.asyncio +async def test_get_completion_with_session(mock_vector_engine): + """Test get_completion with session caching enabled.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + mock_user = MagicMock() + mock_user.id = "test-user-id" + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.get_conversation_history", + return_value="Previous conversation", + ), + patch( + "cognee.modules.retrieval.triplet_retriever.summarize_text", + return_value="Context summary", + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch( + "cognee.modules.retrieval.triplet_retriever.save_conversation_history", + ) as mock_save, + 
patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = mock_user + + completion = await retriever.get_completion("test query", session_id="test_session") + + assert isinstance(completion, list) + assert len(completion) == 1 + assert completion[0] == "Generated answer" + mock_save.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_completion_with_session_no_user_id(mock_vector_engine): + """Test get_completion with session config but no user ID.""" + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value="Generated answer", + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + patch("cognee.modules.retrieval.triplet_retriever.session_user") as mock_session_user, + ): + mock_config = MagicMock() + mock_config.caching = True + mock_cache_config.return_value = mock_config + mock_session_user.get.return_value = None # No user + + completion = await retriever.get_completion("test query") + + assert isinstance(completion, list) + assert len(completion) == 1 + + +@pytest.mark.asyncio +async def test_get_completion_with_response_model(mock_vector_engine): + """Test get_completion with custom response model.""" + from pydantic import BaseModel + + class TestModel(BaseModel): + answer: str + + mock_result = MagicMock() + mock_result.payload = {"text": "Test triplet"} + 
mock_vector_engine.has_collection.return_value = True + mock_vector_engine.search.return_value = [mock_result] + + retriever = TripletRetriever() + + with ( + patch( + "cognee.modules.retrieval.triplet_retriever.get_vector_engine", + return_value=mock_vector_engine, + ), + patch( + "cognee.modules.retrieval.triplet_retriever.generate_completion", + return_value=TestModel(answer="Test answer"), + ), + patch("cognee.modules.retrieval.triplet_retriever.CacheConfig") as mock_cache_config, + ): + mock_config = MagicMock() + mock_config.caching = False + mock_cache_config.return_value = mock_config + + completion = await retriever.get_completion("test query", response_model=TestModel) + + assert isinstance(completion, list) + assert len(completion) == 1 + assert isinstance(completion[0], TestModel) + + +@pytest.mark.asyncio +async def test_init_none_top_k(): + """Test TripletRetriever initialization with None top_k.""" + retriever = TripletRetriever(top_k=None) + + assert retriever.top_k == 5 From e92d8f57b56823e0a1a4bf5ccf6734cdda01d56f Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Wed, 17 Dec 2025 13:14:14 +0100 Subject: [PATCH 078/176] feat: add comunity test trigger --- .github/workflows/release_test.yml | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml index 76ce3b09d..be57c7fbf 100644 --- a/.github/workflows/release_test.yml +++ b/.github/workflows/release_test.yml @@ -19,14 +19,28 @@ jobs: # uses: ./.github/workflows/load_tests.yml # secrets: inherit - docs-tests: +# docs-tests: +# runs-on: ubuntu-22.04 +# steps: +# - name: Trigger docs tests +# run: | +# curl -L -X POST \ +# -H "Accept: application/vnd.github+json" \ +# -H "Authorization: Bearer ${{ secrets.REPO_DISPATCH_PAT_TOKEN }}" \ +# -H "X-GitHub-Api-Version: 2022-11-28" \ +# https://api.github.com/repos/topoteretes/cognee-docs/dispatches \ +# -d 
'{"event_type":"new-main-release","client_payload":{"caller_repo":"'"${GITHUB_REPOSITORY}"'"}}' + + trigger-community-test-suite: + needs: release-pypi-package + if: ${{ inputs.flavour == 'main' }} runs-on: ubuntu-22.04 steps: - - name: Trigger docs tests + - name: Trigger community tests run: | curl -L -X POST \ -H "Accept: application/vnd.github+json" \ - -H "Authorization: Bearer ${{ secrets.DOCS_REPO_PAT_TOKEN }}" \ + -H "Authorization: Bearer ${{ secrets.REPO_DISPATCH_PAT_TOKEN }}" \ -H "X-GitHub-Api-Version: 2022-11-28" \ - https://api.github.com/repos/topoteretes/cognee-docs/dispatches \ + https://api.github.com/repos/topoteretes/cognee-community/dispatches \ -d '{"event_type":"new-main-release","client_payload":{"caller_repo":"'"${GITHUB_REPOSITORY}"'"}}' From 601f74db4fda3c1bc3603d03bfbe22be7c8d6a24 Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Wed, 17 Dec 2025 13:15:43 +0100 Subject: [PATCH 079/176] test: remove dependency from community trigger --- .github/workflows/release_test.yml | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml index be57c7fbf..dcb709ead 100644 --- a/.github/workflows/release_test.yml +++ b/.github/workflows/release_test.yml @@ -14,6 +14,18 @@ on: - main jobs: + trigger-community-test-suite: + if: ${{ inputs.flavour == 'main' }} + runs-on: ubuntu-22.04 + steps: + - name: Trigger community tests + run: | + curl -L -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.REPO_DISPATCH_PAT_TOKEN }}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + https://api.github.com/repos/topoteretes/cognee-community/dispatches \ + -d '{"event_type":"new-main-release","client_payload":{"caller_repo":"'"${GITHUB_REPOSITORY}"'"}}' # load-tests: # name: Load Tests # uses: ./.github/workflows/load_tests.yml @@ -30,17 +42,3 @@ jobs: # -H "X-GitHub-Api-Version: 2022-11-28" \ # 
https://api.github.com/repos/topoteretes/cognee-docs/dispatches \ # -d '{"event_type":"new-main-release","client_payload":{"caller_repo":"'"${GITHUB_REPOSITORY}"'"}}' - - trigger-community-test-suite: - needs: release-pypi-package - if: ${{ inputs.flavour == 'main' }} - runs-on: ubuntu-22.04 - steps: - - name: Trigger community tests - run: | - curl -L -X POST \ - -H "Accept: application/vnd.github+json" \ - -H "Authorization: Bearer ${{ secrets.REPO_DISPATCH_PAT_TOKEN }}" \ - -H "X-GitHub-Api-Version: 2022-11-28" \ - https://api.github.com/repos/topoteretes/cognee-community/dispatches \ - -d '{"event_type":"new-main-release","client_payload":{"caller_repo":"'"${GITHUB_REPOSITORY}"'"}}' From a5a7ae2564abd90c0bf9b51b3abfc2a24a067a8f Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Wed, 17 Dec 2025 13:16:46 +0100 Subject: [PATCH 080/176] test: remove if --- .github/workflows/release_test.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml index dcb709ead..08595a01e 100644 --- a/.github/workflows/release_test.yml +++ b/.github/workflows/release_test.yml @@ -15,7 +15,6 @@ on: jobs: trigger-community-test-suite: - if: ${{ inputs.flavour == 'main' }} runs-on: ubuntu-22.04 steps: - name: Trigger community tests From 6958b4edd462615e2e973d7cabd369181c030eba Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Wed, 17 Dec 2025 13:50:03 +0100 Subject: [PATCH 081/176] feat: add the triggers to release, after pypi publishing --- .github/workflows/release.yml | 28 ++++++++++++++++++++++++++++ .github/workflows/release_test.yml | 30 ++++-------------------------- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 84601edf7..26ccce1f0 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -136,3 +136,31 @@ jobs: flavour=${{ inputs.flavour }} cache-from: type=registry,ref=cognee/cognee:buildcache 
cache-to: type=registry,ref=cognee/cognee:buildcache,mode=max + + trigger-docs-test-suite: + needs: release-pypi-package + if: ${{ inputs.flavour == 'main' }} + runs-on: ubuntu-22.04 + steps: + - name: Trigger docs tests + run: | + curl -L -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.REPO_DISPATCH_PAT_TOKEN }}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + https://api.github.com/repos/topoteretes/cognee-docs/dispatches \ + -d '{"event_type":"new-main-release","client_payload":{"caller_repo":"'"${GITHUB_REPOSITORY}"'"}}' + + trigger-community-test-suite: + needs: release-pypi-package + if: ${{ inputs.flavour == 'main' }} + runs-on: ubuntu-22.04 + steps: + - name: Trigger community tests + run: | + curl -L -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.REPO_DISPATCH_PAT_TOKEN }}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + https://api.github.com/repos/topoteretes/cognee-community/dispatches \ + -d '{"event_type":"new-main-release","client_payload":{"caller_repo":"'"${GITHUB_REPOSITORY}"'"}}' \ No newline at end of file diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml index 08595a01e..6090a1217 100644 --- a/.github/workflows/release_test.yml +++ b/.github/workflows/release_test.yml @@ -14,30 +14,8 @@ on: - main jobs: - trigger-community-test-suite: - runs-on: ubuntu-22.04 - steps: - - name: Trigger community tests - run: | - curl -L -X POST \ - -H "Accept: application/vnd.github+json" \ - -H "Authorization: Bearer ${{ secrets.REPO_DISPATCH_PAT_TOKEN }}" \ - -H "X-GitHub-Api-Version: 2022-11-28" \ - https://api.github.com/repos/topoteretes/cognee-community/dispatches \ - -d '{"event_type":"new-main-release","client_payload":{"caller_repo":"'"${GITHUB_REPOSITORY}"'"}}' -# load-tests: -# name: Load Tests -# uses: ./.github/workflows/load_tests.yml -# secrets: inherit + load-tests: + name: Load Tests + uses: 
./.github/workflows/load_tests.yml + secrets: inherit -# docs-tests: -# runs-on: ubuntu-22.04 -# steps: -# - name: Trigger docs tests -# run: | -# curl -L -X POST \ -# -H "Accept: application/vnd.github+json" \ -# -H "Authorization: Bearer ${{ secrets.REPO_DISPATCH_PAT_TOKEN }}" \ -# -H "X-GitHub-Api-Version: 2022-11-28" \ -# https://api.github.com/repos/topoteretes/cognee-docs/dispatches \ -# -d '{"event_type":"new-main-release","client_payload":{"caller_repo":"'"${GITHUB_REPOSITORY}"'"}}' From 431a83247fff487357a253cbddb00e779e8bda9b Mon Sep 17 00:00:00 2001 From: Andrej Milicevic Date: Wed, 17 Dec 2025 13:50:43 +0100 Subject: [PATCH 082/176] chore: remove unnecessary 'on push' setting --- .github/workflows/release_test.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml index 6090a1217..b31b431a4 100644 --- a/.github/workflows/release_test.yml +++ b/.github/workflows/release_test.yml @@ -5,9 +5,6 @@ permissions: contents: read on: - push: - branches: - - feature/cog-3213-docs-set-up-guide-script-tests workflow_dispatch: pull_request: branches: From cc7ca45e7315300fc775509a8b3784cfae2ed99a Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Wed, 17 Dec 2025 15:48:24 +0100 Subject: [PATCH 083/176] feat: make vector_distance list based --- .../modules/graph/cognee_graph/CogneeGraph.py | 20 ++++++++++++++++--- .../graph/cognee_graph/CogneeGraphElements.py | 4 ++-- .../graph/cognee_graph_elements_test.py | 4 ++-- .../unit/modules/graph/cognee_graph_test.py | 4 ++++ 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 6233c245f..dd05c8c4f 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -26,11 +26,13 @@ class CogneeGraph(CogneeAbstractGraph): nodes: Dict[str, Node] edges: 
List[Edge] directed: bool + triplet_distance_penalty: float def __init__(self, directed: bool = True): self.nodes = {} self.edges = [] self.directed = directed + self.triplet_distance_penalty = 3.5 def add_node(self, node: Node) -> None: if node.id not in self.nodes: @@ -148,6 +150,8 @@ class CogneeGraph(CogneeAbstractGraph): adapter, memory_fragment_filter ) + self.triplet_distance_penalty = triplet_distance_penalty + import time start_time = time.time() @@ -230,11 +234,21 @@ class CogneeGraph(CogneeAbstractGraph): logger.error(f"Error mapping vector distances to edges: {str(ex)}") raise ex + def _as_distance(self, value: Union[float, List[float], None]) -> float: + """Normalize distance value to float, handling None, lists, and scalars.""" + if value is None: + return self.triplet_distance_penalty + if isinstance(value, list) and value: + return float(value[0]) + if isinstance(value, (int, float)): + return float(value) + return self.triplet_distance_penalty + async def calculate_top_triplet_importances(self, k: int) -> List[Edge]: def score(edge): - n1 = edge.node1.attributes.get("vector_distance", 1) - n2 = edge.node2.attributes.get("vector_distance", 1) - e = edge.attributes.get("vector_distance", 1) + n1 = self._as_distance(edge.node1.attributes.get("vector_distance")) + n2 = self._as_distance(edge.node2.attributes.get("vector_distance")) + e = self._as_distance(edge.attributes.get("vector_distance")) return n1 + n2 + e return heapq.nsmallest(k, self.edges, key=score) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py index 62ef8d9fd..5d8e0df34 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py @@ -30,7 +30,7 @@ class Node: raise InvalidDimensionsError() self.id = node_id self.attributes = attributes if attributes is not None else {} - self.attributes["vector_distance"] = node_penalty + 
self.attributes["vector_distance"] = None self.skeleton_neighbours = [] self.skeleton_edges = [] self.status = np.ones(dimension, dtype=int) @@ -116,7 +116,7 @@ class Edge: self.node1 = node1 self.node2 = node2 self.attributes = attributes if attributes is not None else {} - self.attributes["vector_distance"] = edge_penalty + self.attributes["vector_distance"] = None self.directed = directed self.status = np.ones(dimension, dtype=int) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py index 1d2b79cf9..e59888525 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py @@ -9,7 +9,7 @@ def test_node_initialization(): """Test that a Node is initialized correctly.""" node = Node("node1", {"attr1": "value1"}, dimension=2) assert node.id == "node1" - assert node.attributes == {"attr1": "value1", "vector_distance": 3.5} + assert node.attributes == {"attr1": "value1", "vector_distance": None} assert len(node.status) == 2 assert np.all(node.status == 1) @@ -96,7 +96,7 @@ def test_edge_initialization(): edge = Edge(node1, node2, {"weight": 10}, directed=False, dimension=2) assert edge.node1 == node1 assert edge.node2 == node2 - assert edge.attributes == {"vector_distance": 3.5, "weight": 10} + assert edge.attributes == {"vector_distance": None, "weight": 10} assert edge.directed is False assert len(edge.status) == 2 assert np.all(edge.status == 1) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index edbd8ef9d..d30167262 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -246,6 +246,7 @@ async def test_map_vector_distances_to_graph_nodes(setup_graph): @pytest.mark.asyncio +@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") 
async def test_map_vector_distances_partial_node_coverage(setup_graph): """Test mapping vector distances when only some nodes have results.""" graph = setup_graph @@ -272,6 +273,7 @@ async def test_map_vector_distances_partial_node_coverage(setup_graph): @pytest.mark.asyncio +@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_multiple_categories(setup_graph): """Test mapping vector distances from multiple collection categories.""" graph = setup_graph @@ -331,6 +333,7 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph): @pytest.mark.asyncio +@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_partial_edge_coverage(setup_graph): """Test mapping edge distances when only some edges have results.""" graph = setup_graph @@ -384,6 +387,7 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr @pytest.mark.asyncio +@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_no_edge_matches(setup_graph): """Test edge mapping when no edges match the distance results.""" graph = setup_graph From 69ab8e7edee9e422b33059a55f0d711fd5e4cde6 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Wed, 17 Dec 2025 18:14:57 +0100 Subject: [PATCH 084/176] feat: add multi-query support to graph distance mapping --- .../modules/graph/cognee_graph/CogneeGraph.py | 114 +++++++++++++++--- cognee/tests/test_search_db.py | 29 ++++- .../unit/modules/graph/cognee_graph_test.py | 90 +++++++++++--- 3 files changed, 190 insertions(+), 43 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index dd05c8c4f..4a586e488 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -1,6 
+1,6 @@ import time from cognee.shared.logging_utils import get_logger -from typing import List, Dict, Union, Optional, Type +from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any from cognee.modules.graph.exceptions import ( EntityNotFoundError, @@ -204,31 +204,105 @@ class CogneeGraph(CogneeAbstractGraph): logger.error(f"Error during graph projection: {str(e)}") raise - async def map_vector_distances_to_graph_nodes(self, node_distances) -> None: - mapped_nodes = 0 - for category, scored_results in node_distances.items(): - for scored_result in scored_results: - node_id = str(scored_result.id) - score = scored_result.score - node = self.get_node(node_id) - if node: - node.add_attribute("vector_distance", score) - mapped_nodes += 1 + def _initialize_vector_distance(self, graph_elements, query_list_length=None) -> None: + """Initialize vector_distance as a list of default penalties for all graph elements.""" + query_count = query_list_length or 1 + for element in graph_elements: + element.attributes["vector_distance"] = [self.triplet_distance_penalty] * query_count - async def map_vector_distances_to_graph_edges(self, edge_distances) -> None: + def _normalize_query_input(self, distance_data, query_list_length=None, name="input"): + """Normalize single-query or multi-query input to list of lists, return empty list if empty.""" + if not distance_data: + return [] + normalized = ( + distance_data if isinstance(distance_data[0], (list, tuple)) else [distance_data] + ) + if query_list_length is not None and len(normalized) != query_list_length: + raise ValueError( + f"{name} has {len(normalized)} query lists, but query_list_length is {query_list_length}" + ) + return normalized + + def _apply_vector_distance_updates( + self, + element_distances, + query_index: int, + get_element: Callable[[str], Optional[Union[Node, Edge]]], + get_id_and_score: Callable[[Any], Tuple[Optional[str], Optional[float]]], + ) -> None: + """Apply updates into 
element.attributes["vector_distance"][query_index].""" + for res in element_distances: + key, score = get_id_and_score(res) + if key is None or score is None: + continue + element = get_element(key) + if element is None: + continue + element.attributes["vector_distance"][query_index] = score + + def _get_node_id_and_score(self, res: Any) -> Tuple[str, float]: + """Extract node ID and score from a scored result.""" + return str(res.id), float(res.score) + + def _get_edge_id_and_score(self, res: Any) -> Tuple[Optional[str], Optional[float]]: + """Extract edge key and score from a scored result.""" + payload = getattr(res, "payload", None) + if not payload: + return None, None + text = payload.get("text") + if text is None: + return None, None + return str(text), float(res.score) + + async def map_vector_distances_to_graph_nodes( + self, + node_distances, + query_list_length: Optional[int] = None, + ) -> None: + self._initialize_vector_distance(self.nodes.values(), query_list_length) + + for collection_name, scored_results in node_distances.items(): + per_query_lists = self._normalize_query_input( + scored_results, query_list_length, f"Collection '{collection_name}'" + ) + if not per_query_lists: + continue + + for query_index, scored_list in enumerate(per_query_lists): + self._apply_vector_distance_updates( + element_distances=scored_list, + query_index=query_index, + get_element=self.nodes.get, + get_id_and_score=self._get_node_id_and_score, + ) + + async def map_vector_distances_to_graph_edges( + self, + edge_distances, + query_list_length: Optional[int] = None, + ) -> None: try: - if edge_distances is None: + self._initialize_vector_distance(self.edges, query_list_length) + + normalized_edges = self._normalize_query_input( + edge_distances, query_list_length, "edge_distances" + ) + if not normalized_edges: return - embedding_map = {result.payload["text"]: result.score for result in edge_distances} - + edges_by_key: Dict[str, Edge] = {} for edge in self.edges: - 
edge_key = edge.attributes.get("edge_text") or edge.attributes.get( - "relationship_type" + key = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type") + if key: + edges_by_key[str(key)] = edge + + for query_index, scored_list in enumerate(normalized_edges): + self._apply_vector_distance_updates( + element_distances=scored_list, + query_index=query_index, + get_element=edges_by_key.get, + get_id_and_score=self._get_edge_id_and_score, ) - distance = embedding_map.get(edge_key, None) - if distance is not None: - edge.attributes["vector_distance"] = distance except Exception as ex: logger.error(f"Error mapping vector distances to edges: {str(ex)}") diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index 0916be322..d0b78dfcc 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -350,11 +350,32 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state): assert triplets, f"{name}: Triplets list should not be empty" for edge in triplets: assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances" - distance = edge.attributes.get("vector_distance") - node1_distance = edge.node1.attributes.get("vector_distance") - node2_distance = edge.node2.attributes.get("vector_distance") - assert isinstance(distance, float), f"{name}: vector_distance should be float" + vector_distances = edge.attributes.get("vector_distance", []) + assert isinstance(vector_distances, list) and vector_distances, ( + f"{name}: vector_distance should be a non-empty list" + ) + distance = vector_distances[0] + assert isinstance(distance, float), ( + f"{name}: vector_distance[0] should be float, got {type(distance)}" + ) assert 0 <= distance <= 1 + + node1_distances = edge.node1.attributes.get("vector_distance", []) + node2_distances = edge.node2.attributes.get("vector_distance", []) + assert isinstance(node1_distances, list) and node1_distances, ( + f"{name}: node1 vector_distance should be a 
non-empty list" + ) + assert isinstance(node2_distances, list) and node2_distances, ( + f"{name}: node2 vector_distance should be a non-empty list" + ) + node1_distance = node1_distances[0] + node2_distance = node2_distances[0] + assert isinstance(node1_distance, float), ( + f"{name}: node1 vector_distance[0] should be float, got {type(node1_distance)}" + ) + assert isinstance(node2_distance, float), ( + f"{name}: node2 vector_distance[0] should be float, got {type(node2_distance)}" + ) assert 0 <= node1_distance <= 1 assert 0 <= node2_distance <= 1 diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index d30167262..8babdfe47 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -241,12 +241,11 @@ async def test_map_vector_distances_to_graph_nodes(setup_graph): await graph.map_vector_distances_to_graph_nodes(node_distances) - assert graph.get_node("1").attributes.get("vector_distance") == 0.95 - assert graph.get_node("2").attributes.get("vector_distance") == 0.87 + assert graph.get_node("1").attributes.get("vector_distance") == [0.95] + assert graph.get_node("2").attributes.get("vector_distance") == [0.87] @pytest.mark.asyncio -@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_partial_node_coverage(setup_graph): """Test mapping vector distances when only some nodes have results.""" graph = setup_graph @@ -267,13 +266,12 @@ async def test_map_vector_distances_partial_node_coverage(setup_graph): await graph.map_vector_distances_to_graph_nodes(node_distances) - assert graph.get_node("1").attributes.get("vector_distance") == 0.95 - assert graph.get_node("2").attributes.get("vector_distance") == 0.87 - assert graph.get_node("3").attributes.get("vector_distance") == 3.5 + assert graph.get_node("1").attributes.get("vector_distance") == [0.95] + assert 
graph.get_node("2").attributes.get("vector_distance") == [0.87] + assert graph.get_node("3").attributes.get("vector_distance") == [3.5] @pytest.mark.asyncio -@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_multiple_categories(setup_graph): """Test mapping vector distances from multiple collection categories.""" graph = setup_graph @@ -300,10 +298,36 @@ async def test_map_vector_distances_multiple_categories(setup_graph): await graph.map_vector_distances_to_graph_nodes(node_distances) - assert graph.get_node("1").attributes.get("vector_distance") == 0.95 - assert graph.get_node("2").attributes.get("vector_distance") == 0.87 - assert graph.get_node("3").attributes.get("vector_distance") == 0.92 - assert graph.get_node("4").attributes.get("vector_distance") == 3.5 + assert graph.get_node("1").attributes.get("vector_distance") == [0.95] + assert graph.get_node("2").attributes.get("vector_distance") == [0.87] + assert graph.get_node("3").attributes.get("vector_distance") == [0.92] + assert graph.get_node("4").attributes.get("vector_distance") == [3.5] + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_nodes_multi_query(setup_graph): + """Test mapping vector distances with multiple queries.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + node_distances = { + "Entity_name": [ + [MockScoredResult("1", 0.95)], # query 0 + [MockScoredResult("2", 0.87)], # query 1 + ] + } + + await graph.map_vector_distances_to_graph_nodes(node_distances, query_list_length=2) + + assert graph.get_node("1").attributes.get("vector_distance") == [0.95, 3.5] + assert graph.get_node("2").attributes.get("vector_distance") == [3.5, 0.87] + assert graph.get_node("3").attributes.get("vector_distance") == [3.5, 3.5] @pytest.mark.asyncio @@ -329,11 +353,10 @@ async def 
test_map_vector_distances_to_graph_edges_with_payload(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - assert graph.edges[0].attributes.get("vector_distance") == 0.92 + assert graph.edges[0].attributes.get("vector_distance") == [0.92] @pytest.mark.asyncio -@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_partial_edge_coverage(setup_graph): """Test mapping edge distances when only some edges have results.""" graph = setup_graph @@ -356,8 +379,8 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - assert graph.edges[0].attributes.get("vector_distance") == 0.92 - assert graph.edges[1].attributes.get("vector_distance") == 3.5 + assert graph.edges[0].attributes.get("vector_distance") == [0.92] + assert graph.edges[1].attributes.get("vector_distance") == [3.5] @pytest.mark.asyncio @@ -383,11 +406,10 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - assert graph.edges[0].attributes.get("vector_distance") == 0.85 + assert graph.edges[0].attributes.get("vector_distance") == [0.85] @pytest.mark.asyncio -@pytest.mark.skip(reason="Will be updated in Phase 2 to expect list-based distances") async def test_map_vector_distances_no_edge_matches(setup_graph): """Test edge mapping when no edges match the distance results.""" graph = setup_graph @@ -410,7 +432,7 @@ async def test_map_vector_distances_no_edge_matches(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances) - assert graph.edges[0].attributes.get("vector_distance") == 3.5 + assert graph.edges[0].attributes.get("vector_distance") == [3.5] @pytest.mark.asyncio @@ -423,7 +445,37 @@ async def test_map_vector_distances_none_returns_early(setup_graph): await 
graph.map_vector_distances_to_graph_edges(edge_distances=None) - assert graph.edges[0].attributes.get("vector_distance") == 3.5 + assert graph.edges[0].attributes.get("vector_distance") == [3.5] + + +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph): + """Test mapping edge distances with multiple queries.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + edge1 = Edge(node1, node2, attributes={"edge_text": "A"}) + edge2 = Edge(node2, node3, attributes={"edge_text": "B"}) + graph.add_edge(edge1) + graph.add_edge(edge2) + + edge_distances = [ + [MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0 + [MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1 + ] + + await graph.map_vector_distances_to_graph_edges( + edge_distances=edge_distances, query_list_length=2 + ) + + assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5] + assert graph.edges[1].attributes.get("vector_distance") == [3.5, 0.2] @pytest.mark.asyncio From 931c5f30968fbb43f614fbf339ca81160f017998 Mon Sep 17 00:00:00 2001 From: Christina_Raichel_Francis Date: Wed, 17 Dec 2025 18:02:35 +0000 Subject: [PATCH 085/176] refactor: add test and example script --- .../tasks/memify/extract_usage_frequency.py | 102 +++++++++++++++++- cognee/tests/test_extract_usage_frequency.py | 42 ++++++++ .../python/extract_usage_frequency_example.py | 49 +++++++++ 3 files changed, 189 insertions(+), 4 deletions(-) create mode 100644 cognee/tests/test_extract_usage_frequency.py create mode 100644 examples/python/extract_usage_frequency_example.py diff --git a/cognee/tasks/memify/extract_usage_frequency.py b/cognee/tasks/memify/extract_usage_frequency.py index d6ca3773f..7932a39a4 100644 --- a/cognee/tasks/memify/extract_usage_frequency.py +++ b/cognee/tasks/memify/extract_usage_frequency.py @@ -1,7 +1,101 @@ +# 
cognee/tasks/memify/extract_usage_frequency.py +from typing import List, Dict, Any +from datetime import datetime, timedelta from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph +from cognee.modules.pipelines.tasks.task import Task - -async def extract_subgraph(subgraphs: list[CogneeGraph]): +async def extract_usage_frequency( + subgraphs: List[CogneeGraph], + time_window: timedelta = timedelta(days=7), + min_interaction_threshold: int = 1 +) -> Dict[str, Any]: + """ + Extract usage frequency from CogneeUserInteraction nodes + + :param subgraphs: List of graph subgraphs + :param time_window: Time window to consider for interactions + :param min_interaction_threshold: Minimum interactions to track + :return: Dictionary of usage frequencies + """ + current_time = datetime.now() + node_frequencies = {} + edge_frequencies = {} + for subgraph in subgraphs: - for edge in subgraph.edges: - yield edge + # Filter CogneeUserInteraction nodes within time window + user_interactions = [ + interaction for interaction in subgraph.nodes + if (interaction.get('type') == 'CogneeUserInteraction' and + current_time - datetime.fromisoformat(interaction.get('timestamp', current_time.isoformat())) <= time_window) + ] + + # Count node and edge frequencies + for interaction in user_interactions: + target_node_id = interaction.get('target_node_id') + edge_type = interaction.get('edge_type') + + if target_node_id: + node_frequencies[target_node_id] = node_frequencies.get(target_node_id, 0) + 1 + + if edge_type: + edge_frequencies[edge_type] = edge_frequencies.get(edge_type, 0) + 1 + + # Filter frequencies above threshold + filtered_node_frequencies = { + node_id: freq for node_id, freq in node_frequencies.items() + if freq >= min_interaction_threshold + } + + filtered_edge_frequencies = { + edge_type: freq for edge_type, freq in edge_frequencies.items() + if freq >= min_interaction_threshold + } + + return { + 'node_frequencies': filtered_node_frequencies, + 
'edge_frequencies': filtered_edge_frequencies, + 'last_processed_timestamp': current_time.isoformat() + } + +async def add_frequency_weights( + graph_adapter, + usage_frequencies: Dict[str, Any] +) -> None: + """ + Add frequency weights to graph nodes and edges + + :param graph_adapter: Graph database adapter + :param usage_frequencies: Calculated usage frequencies + """ + # Update node frequencies + for node_id, frequency in usage_frequencies['node_frequencies'].items(): + try: + node = graph_adapter.get_node(node_id) + if node: + node_properties = node.get_properties() or {} + node_properties['frequency_weight'] = frequency + graph_adapter.update_node(node_id, node_properties) + except Exception as e: + print(f"Error updating node {node_id}: {e}") + + # Note: Edge frequency update might require backend-specific implementation + print("Edge frequency update might need backend-specific handling") + +def usage_frequency_pipeline_entry(graph_adapter): + """ + Memify pipeline entry for usage frequency tracking + + :param graph_adapter: Graph database adapter + :return: Usage frequency results + """ + extraction_tasks = [ + Task(extract_usage_frequency, + time_window=timedelta(days=7), + min_interaction_threshold=1) + ] + + enrichment_tasks = [ + Task(add_frequency_weights, task_config={"batch_size": 1}) + ] + + return extraction_tasks, enrichment_tasks \ No newline at end of file diff --git a/cognee/tests/test_extract_usage_frequency.py b/cognee/tests/test_extract_usage_frequency.py new file mode 100644 index 000000000..b75168409 --- /dev/null +++ b/cognee/tests/test_extract_usage_frequency.py @@ -0,0 +1,42 @@ +# cognee/tests/test_usage_frequency.py +import pytest +import asyncio +from datetime import datetime, timedelta +from cognee.tasks.memify.extract_usage_frequency import extract_usage_frequency, add_frequency_weights + +@pytest.mark.asyncio +async def test_extract_usage_frequency(): + # Mock CogneeGraph with user interactions + mock_subgraphs = [{ + 'nodes': [ + 
{ + 'type': 'CogneeUserInteraction', + 'target_node_id': 'node1', + 'edge_type': 'viewed', + 'timestamp': datetime.now().isoformat() + }, + { + 'type': 'CogneeUserInteraction', + 'target_node_id': 'node1', + 'edge_type': 'viewed', + 'timestamp': datetime.now().isoformat() + }, + { + 'type': 'CogneeUserInteraction', + 'target_node_id': 'node2', + 'edge_type': 'referenced', + 'timestamp': datetime.now().isoformat() + } + ] + }] + + # Test frequency extraction + result = await extract_usage_frequency( + mock_subgraphs, + time_window=timedelta(days=1), + min_interaction_threshold=1 + ) + + assert 'node1' in result['node_frequencies'] + assert result['node_frequencies']['node1'] == 2 + assert result['edge_frequencies']['viewed'] == 2 \ No newline at end of file diff --git a/examples/python/extract_usage_frequency_example.py b/examples/python/extract_usage_frequency_example.py new file mode 100644 index 000000000..c73fa4cc2 --- /dev/null +++ b/examples/python/extract_usage_frequency_example.py @@ -0,0 +1,49 @@ +# cognee/examples/usage_frequency_example.py +import asyncio +import cognee +from cognee.api.v1.search import SearchType +from cognee.tasks.memify.extract_usage_frequency import usage_frequency_pipeline_entry + +async def main(): + # Reset cognee state + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + + # Sample conversation + conversation = [ + "Alice discusses machine learning", + "Bob asks about neural networks", + "Alice explains deep learning concepts", + "Bob wants more details about neural networks" + ] + + # Add conversation and cognify + await cognee.add(conversation) + await cognee.cognify() + + # Perform some searches to generate interactions + for query in ["machine learning", "neural networks", "deep learning"]: + await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text=query, + save_interaction=True + ) + + # Run usage frequency tracking + await cognee.memify( + 
*usage_frequency_pipeline_entry(cognee.graph_adapter) + ) + + # Search and display frequency weights + results = await cognee.search( + query_text="Find nodes with frequency weights", + query_type=SearchType.NODE_PROPERTIES, + properties=["frequency_weight"] + ) + + print("Nodes with Frequency Weights:") + for result in results[0]["search_result"][0]: + print(result) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file From 46ff01021a2ca1b3a115ac6279a36ff90925bf3c Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Wed, 17 Dec 2025 19:09:02 +0100 Subject: [PATCH 086/176] feat: add multi-query support to score calculation --- .../modules/graph/cognee_graph/CogneeGraph.py | 40 ++-- .../unit/modules/graph/cognee_graph_test.py | 182 +++++++++++++++++- 2 files changed, 200 insertions(+), 22 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 4a586e488..4838d5bc0 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -308,21 +308,33 @@ class CogneeGraph(CogneeAbstractGraph): logger.error(f"Error mapping vector distances to edges: {str(ex)}") raise ex - def _as_distance(self, value: Union[float, List[float], None]) -> float: - """Normalize distance value to float, handling None, lists, and scalars.""" - if value is None: - return self.triplet_distance_penalty - if isinstance(value, list) and value: - return float(value[0]) - if isinstance(value, (int, float)): - return float(value) - return self.triplet_distance_penalty + def _calculate_query_top_triplet_importances( + self, + k: int, + query_index: int = 0, + ) -> List[Edge]: + """Calculate top k triplet importances for a specific query index.""" - async def calculate_top_triplet_importances(self, k: int) -> List[Edge]: def score(edge): - n1 = self._as_distance(edge.node1.attributes.get("vector_distance")) - n2 = 
self._as_distance(edge.node2.attributes.get("vector_distance")) - e = self._as_distance(edge.attributes.get("vector_distance")) - return n1 + n2 + e + distances = [ + edge.node1.attributes.get("vector_distance"), + edge.node2.attributes.get("vector_distance"), + edge.attributes.get("vector_distance"), + ] + return sum(float(d[query_index]) for d in distances) return heapq.nsmallest(k, self.edges, key=score) + + async def calculate_top_triplet_importances( + self, k: int, query_list_length: Optional[int] = None + ) -> Union[List[Edge], List[List[Edge]]]: + """Calculate top k triplet importances, supporting both single and multi-query modes.""" + query_count = query_list_length or 1 + results = [ + self._calculate_query_top_triplet_importances(k=k, query_index=i) + for i in range(query_count) + ] + + if query_list_length is None: + return results[0] + return results diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 8babdfe47..84e6411e2 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -200,6 +200,37 @@ async def test_project_graph_from_db_empty_graph(setup_graph, mock_adapter): ) +@pytest.mark.asyncio +async def test_project_graph_from_db_stores_triplet_penalty_on_graph(mock_adapter): + """Test that project_graph_from_db stores triplet_distance_penalty on the graph.""" + from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph + + nodes_data = [("1", {"name": "Node1"})] + edges_data = [("1", "1", "SELF", {})] + + mock_adapter.get_graph_data = AsyncMock(return_value=(nodes_data, edges_data)) + + graph = CogneeGraph() + custom_penalty = 5.0 + await graph.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=[], + triplet_distance_penalty=custom_penalty, + ) + + assert graph.triplet_distance_penalty == custom_penalty + + graph2 = CogneeGraph() + 
await graph2.project_graph_from_db( + adapter=mock_adapter, + node_properties_to_project=["name"], + edge_properties_to_project=[], + ) + + assert graph2.triplet_distance_penalty == 3.5 + + @pytest.mark.asyncio async def test_project_graph_from_db_missing_nodes(setup_graph, mock_adapter): """Test that edges referencing missing nodes raise error.""" @@ -478,6 +509,36 @@ async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph): assert graph.edges[1].attributes.get("vector_distance") == [3.5, 0.2] +@pytest.mark.asyncio +async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(setup_graph): + """Test that unmapped indices in multi-query mode stay at default penalty.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + edge1 = Edge(node1, node2, attributes={"edge_text": "A"}) + edge2 = Edge(node2, node3, attributes={"edge_text": "B"}) + graph.add_edge(edge1) + graph.add_edge(edge2) + + edge_distances = [ + [MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped + [], # query 1: no edges mapped + ] + + await graph.map_vector_distances_to_graph_edges( + edge_distances=edge_distances, query_list_length=2 + ) + + assert graph.edges[0].attributes.get("vector_distance") == [0.1, 3.5] + assert graph.edges[1].attributes.get("vector_distance") == [3.5, 3.5] + + @pytest.mark.asyncio async def test_calculate_top_triplet_importances(setup_graph): """Test calculating top triplet importances by score.""" @@ -488,10 +549,10 @@ async def test_calculate_top_triplet_importances(setup_graph): node3 = Node("3") node4 = Node("4") - node1.add_attribute("vector_distance", 0.9) - node2.add_attribute("vector_distance", 0.8) - node3.add_attribute("vector_distance", 0.7) - node4.add_attribute("vector_distance", 0.6) + node1.add_attribute("vector_distance", [0.9]) + node2.add_attribute("vector_distance", [0.8]) + 
node3.add_attribute("vector_distance", [0.7]) + node4.add_attribute("vector_distance", [0.6]) graph.add_node(node1) graph.add_node(node2) @@ -502,9 +563,9 @@ async def test_calculate_top_triplet_importances(setup_graph): edge2 = Edge(node2, node3) edge3 = Edge(node3, node4) - edge1.add_attribute("vector_distance", 0.85) - edge2.add_attribute("vector_distance", 0.75) - edge3.add_attribute("vector_distance", 0.65) + edge1.add_attribute("vector_distance", [0.85]) + edge2.add_attribute("vector_distance", [0.75]) + edge3.add_attribute("vector_distance", [0.65]) graph.add_edge(edge1) graph.add_edge(edge2) @@ -520,7 +581,7 @@ async def test_calculate_top_triplet_importances(setup_graph): @pytest.mark.asyncio async def test_calculate_top_triplet_importances_default_distances(setup_graph): - """Test calculating importances when nodes/edges have no vector distances.""" + """Test calculating importances when nodes/edges have default vector distances.""" graph = setup_graph node1 = Node("1") @@ -531,7 +592,112 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph): edge = Edge(node1, node2) graph.add_edge(edge) + await graph.map_vector_distances_to_graph_nodes({}) + await graph.map_vector_distances_to_graph_edges(None) + top_triplets = await graph.calculate_top_triplet_importances(k=1) assert len(top_triplets) == 1 assert top_triplets[0] == edge + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_single_query_via_helper(setup_graph): + """Test calculating top triplet importances for a single query index.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + node1.add_attribute("vector_distance", [0.1]) + node2.add_attribute("vector_distance", [0.2]) + node3.add_attribute("vector_distance", [0.3]) + + edge1 = Edge(node1, node2) + edge2 = Edge(node2, node3) + graph.add_edge(edge1) + graph.add_edge(edge2) + + 
edge1.add_attribute("vector_distance", [0.3]) + edge2.add_attribute("vector_distance", [0.4]) + + results = await graph.calculate_top_triplet_importances(k=1, query_list_length=1) + assert len(results) == 1 + assert len(results[0]) == 1 + assert results[0][0] == edge1 + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_multi_query(setup_graph): + """Test calculating top triplet importances with multiple queries.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + node3 = Node("3") + graph.add_node(node1) + graph.add_node(node2) + graph.add_node(node3) + + edge_a = Edge(node1, node2) + edge_b = Edge(node2, node3) + graph.add_edge(edge_a) + graph.add_edge(edge_b) + + node1.add_attribute("vector_distance", [0.1, 0.9]) + node2.add_attribute("vector_distance", [0.1, 0.9]) + node3.add_attribute("vector_distance", [0.9, 0.1]) + edge_a.add_attribute("vector_distance", [0.1, 0.9]) + edge_b.add_attribute("vector_distance", [0.9, 0.1]) + + results = await graph.calculate_top_triplet_importances(k=1, query_list_length=2) + + assert len(results) == 2 + assert results[0][0] == edge_a + assert results[1][0] == edge_b + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_raises_on_short_list(setup_graph): + """Test that scoring raises ValueError when list is too short for query_index.""" + graph = setup_graph + + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + node1.add_attribute("vector_distance", [0.1]) + node2.add_attribute("vector_distance", [0.2]) + + edge = Edge(node1, node2) + edge.add_attribute("vector_distance", [0.3]) + graph.add_edge(edge) + + with pytest.raises(IndexError): + await graph.calculate_top_triplet_importances(k=1, query_list_length=2) + + +@pytest.mark.asyncio +async def test_calculate_top_triplet_importances_raises_on_missing_attribute(setup_graph): + """Test that scoring raises error when vector_distance is missing.""" + graph = setup_graph + + node1 = 
Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + del node1.attributes["vector_distance"] + del node2.attributes["vector_distance"] + + edge = Edge(node1, node2) + del edge.attributes["vector_distance"] + graph.add_edge(edge) + + with pytest.raises((KeyError, TypeError)): + await graph.calculate_top_triplet_importances(k=1, query_list_length=1) From 6e5e79f434a755b0692bb6956b571128e9d9db4c Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Wed, 17 Dec 2025 21:07:23 +0100 Subject: [PATCH 087/176] fix: Resolve connection issue with postgres when special characters are present --- .../relational/create_relational_engine.py | 14 ++++++++++---- .../databases/vector/create_vector_engine.py | 11 +++++++++-- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/cognee/infrastructure/databases/relational/create_relational_engine.py b/cognee/infrastructure/databases/relational/create_relational_engine.py index 8813dfcb2..ea2b35c75 100644 --- a/cognee/infrastructure/databases/relational/create_relational_engine.py +++ b/cognee/infrastructure/databases/relational/create_relational_engine.py @@ -1,6 +1,6 @@ +from sqlalchemy import URL from .sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter from functools import lru_cache -from urllib.parse import quote_plus @lru_cache @@ -44,10 +44,16 @@ def create_relational_engine( # Test if asyncpg is available import asyncpg - encoded_username = quote_plus(db_username) - encoded_password = quote_plus(db_password) + # Handle special characters in username and password like # or @ + connection_string = URL.create( + "postgresql+asyncpg", + username=db_username, + password=db_password, + host=db_host, + port=int(db_port), + database=db_name, + ) - connection_string = f"postgresql+asyncpg://{encoded_username}:{encoded_password}@{db_host}:{db_port}/{db_name}" except ImportError: raise ImportError( "PostgreSQL dependencies are not installed. 
Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality." diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index d1cf855d7..47a2e2582 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -1,3 +1,5 @@ +from sqlalchemy import URL + from .supported_databases import supported_databases from .embeddings import get_embedding_engine @@ -61,8 +63,13 @@ def create_vector_engine( if not (db_host and db_port and db_name and db_username and db_password): raise EnvironmentError("Missing requred pgvector credentials!") - connection_string: str = ( - f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}" + connection_string = URL.create( + "postgresql+asyncpg", + username=db_username, + password=db_password, + host=db_host, + port=int(db_port), + database=db_name, ) try: From c1ea7a8cc235067358cf68006bd87b40ad7830b7 Mon Sep 17 00:00:00 2001 From: lxobr <122801072+lxobr@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:52:35 +0100 Subject: [PATCH 088/176] fix: improve graph distance mapping --- .../modules/graph/cognee_graph/CogneeGraph.py | 169 +++++++++--------- .../graph/cognee_graph/CogneeGraphElements.py | 46 +++++ cognee/tests/test_search_db.py | 15 +- .../graph/cognee_graph_elements_test.py | 40 +++++ .../unit/modules/graph/cognee_graph_test.py | 35 +++- 5 files changed, 210 insertions(+), 95 deletions(-) diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index 4838d5bc0..bc29bb828 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -25,12 +25,14 @@ class CogneeGraph(CogneeAbstractGraph): nodes: Dict[str, Node] edges: List[Edge] + edges_by_distance_key: 
Dict[str, List[Edge]] directed: bool triplet_distance_penalty: float def __init__(self, directed: bool = True): self.nodes = {} self.edges = [] + self.edges_by_distance_key = {} self.directed = directed self.triplet_distance_penalty = 3.5 @@ -44,6 +46,12 @@ class CogneeGraph(CogneeAbstractGraph): self.edges.append(edge) edge.node1.add_skeleton_edge(edge) edge.node2.add_skeleton_edge(edge) + key = edge.get_distance_key() + if not key: + return + if key not in self.edges_by_distance_key: + self.edges_by_distance_key[key] = [] + self.edges_by_distance_key[key].append(edge) def get_node(self, node_id: str) -> Node: return self.nodes.get(node_id, None) @@ -58,6 +66,29 @@ class CogneeGraph(CogneeAbstractGraph): def get_edges(self) -> List[Edge]: return self.edges + def reset_distances(self, collection: Iterable[Union[Node, Edge]], query_count: int) -> None: + """Reset vector distances for a collection of nodes or edges.""" + for item in collection: + item.reset_vector_distances(query_count, self.triplet_distance_penalty) + + def _normalize_query_distance_lists( + self, distances: List, query_list_length: Optional[int] = None, name: str = "distances" + ) -> List: + """Normalize shape: flat list -> single-query; nested list -> multi-query.""" + if not distances: + return [] + first_item = distances[0] + if isinstance(first_item, (list, tuple)): + per_query_lists = distances + else: + per_query_lists = [distances] + if query_list_length is not None and len(per_query_lists) != query_list_length: + raise ValueError( + f"{name} has {len(per_query_lists)} query lists, " + f"but query_list_length is {query_list_length}" + ) + return per_query_lists + async def _get_nodeset_subgraph( self, adapter, @@ -204,109 +235,81 @@ class CogneeGraph(CogneeAbstractGraph): logger.error(f"Error during graph projection: {str(e)}") raise - def _initialize_vector_distance(self, graph_elements, query_list_length=None) -> None: - """Initialize vector_distance as a list of default penalties for all 
graph elements.""" - query_count = query_list_length or 1 - for element in graph_elements: - element.attributes["vector_distance"] = [self.triplet_distance_penalty] * query_count - - def _normalize_query_input(self, distance_data, query_list_length=None, name="input"): - """Normalize single-query or multi-query input to list of lists, return empty list if empty.""" - if not distance_data: - return [] - normalized = ( - distance_data if isinstance(distance_data[0], (list, tuple)) else [distance_data] - ) - if query_list_length is not None and len(normalized) != query_list_length: - raise ValueError( - f"{name} has {len(normalized)} query lists, but query_list_length is {query_list_length}" - ) - return normalized - - def _apply_vector_distance_updates( - self, - element_distances, - query_index: int, - get_element: Callable[[str], Optional[Union[Node, Edge]]], - get_id_and_score: Callable[[Any], Tuple[Optional[str], Optional[float]]], - ) -> None: - """Apply updates into element.attributes["vector_distance"][query_index].""" - for res in element_distances: - key, score = get_id_and_score(res) - if key is None or score is None: - continue - element = get_element(key) - if element is None: - continue - element.attributes["vector_distance"][query_index] = score - - def _get_node_id_and_score(self, res: Any) -> Tuple[str, float]: - """Extract node ID and score from a scored result.""" - return str(res.id), float(res.score) - - def _get_edge_id_and_score(self, res: Any) -> Tuple[Optional[str], Optional[float]]: - """Extract edge key and score from a scored result.""" - payload = getattr(res, "payload", None) - if not payload: - return None, None - text = payload.get("text") - if text is None: - return None, None - return str(text), float(res.score) - async def map_vector_distances_to_graph_nodes( self, node_distances, query_list_length: Optional[int] = None, ) -> None: - self._initialize_vector_distance(self.nodes.values(), query_list_length) + """Map vector distances to 
nodes, supporting single- and multi-query input shapes.""" + if not node_distances: + return None + + query_count = query_list_length or 1 + + # Reset all node distances for this search + self.reset_distances(self.nodes.values(), query_count) for collection_name, scored_results in node_distances.items(): - per_query_lists = self._normalize_query_input( - scored_results, query_list_length, f"Collection '{collection_name}'" - ) - if not per_query_lists: + if not scored_results: continue + per_query_lists = self._normalize_query_distance_lists( + scored_results, query_list_length, f"Collection '{collection_name}'" + ) + for query_index, scored_list in enumerate(per_query_lists): - self._apply_vector_distance_updates( - element_distances=scored_list, - query_index=query_index, - get_element=self.nodes.get, - get_id_and_score=self._get_node_id_and_score, - ) + for result in scored_list: + node_id = str(getattr(result, "id", None)) + if not node_id: + continue + node = self.get_node(node_id) + if node is None: + continue + score = float(getattr(result, "score", self.triplet_distance_penalty)) + node.update_distance_for_query( + query_index=query_index, + score=score, + query_count=query_count, + default_penalty=self.triplet_distance_penalty, + ) async def map_vector_distances_to_graph_edges( self, edge_distances, query_list_length: Optional[int] = None, ) -> None: - try: - self._initialize_vector_distance(self.edges, query_list_length) + """Map vector distances to graph edges, supporting single- and multi-query input shapes.""" + if not edge_distances: + return None - normalized_edges = self._normalize_query_input( - edge_distances, query_list_length, "edge_distances" - ) - if not normalized_edges: - return + query_count = query_list_length or 1 - edges_by_key: Dict[str, Edge] = {} - for edge in self.edges: - key = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type") - if key: - edges_by_key[str(key)] = edge + # Reset all edge distances for this 
search + self.reset_distances(self.edges, query_count) - for query_index, scored_list in enumerate(normalized_edges): - self._apply_vector_distance_updates( - element_distances=scored_list, - query_index=query_index, - get_element=edges_by_key.get, - get_id_and_score=self._get_edge_id_and_score, - ) + per_query_edge_lists = self._normalize_query_distance_lists( + edge_distances, query_list_length, "edge_distances" + ) - except Exception as ex: - logger.error(f"Error mapping vector distances to edges: {str(ex)}") - raise ex + # For each query, apply distances to all matching edges + for query_index, scored_list in enumerate(per_query_edge_lists): + for result in scored_list: + payload = getattr(result, "payload", None) + if not isinstance(payload, dict): + continue + text = payload.get("text") + if not text: + continue + matching_edges = self.edges_by_distance_key.get(str(text)) + if not matching_edges: + continue + for edge in matching_edges: + edge.update_distance_for_query( + query_index=query_index, + score=float(getattr(result, "score", self.triplet_distance_penalty)), + query_count=query_count, + default_penalty=self.triplet_distance_penalty, + ) def _calculate_query_top_triplet_importances( self, diff --git a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py index 5d8e0df34..c9226b6a1 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraphElements.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraphElements.py @@ -35,6 +35,26 @@ class Node: self.skeleton_edges = [] self.status = np.ones(dimension, dtype=int) + def reset_vector_distances(self, query_count: int, default_penalty: float) -> None: + self.attributes["vector_distance"] = [default_penalty] * query_count + + def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]: + distances = self.attributes.get("vector_distance") + if not isinstance(distances, list) or len(distances) != query_count: + 
distances = [default_penalty] * query_count + self.attributes["vector_distance"] = distances + return distances + + def update_distance_for_query( + self, + query_index: int, + score: float, + query_count: int, + default_penalty: float, + ) -> None: + distances = self.ensure_vector_distance_list(query_count, default_penalty) + distances[query_index] = score + def add_skeleton_neighbor(self, neighbor: "Node") -> None: if neighbor not in self.skeleton_neighbours: self.skeleton_neighbours.append(neighbor) @@ -120,6 +140,32 @@ class Edge: self.directed = directed self.status = np.ones(dimension, dtype=int) + def get_distance_key(self) -> Optional[str]: + key = self.attributes.get("edge_text") or self.attributes.get("relationship_type") + if key is None: + return None + return str(key) + + def reset_vector_distances(self, query_count: int, default_penalty: float) -> None: + self.attributes["vector_distance"] = [default_penalty] * query_count + + def ensure_vector_distance_list(self, query_count: int, default_penalty: float) -> List[float]: + distances = self.attributes.get("vector_distance") + if not isinstance(distances, list) or len(distances) != query_count: + distances = [default_penalty] * query_count + self.attributes["vector_distance"] = distances + return distances + + def update_distance_for_query( + self, + query_index: int, + score: float, + query_count: int, + default_penalty: float, + ) -> None: + distances = self.ensure_vector_distance_list(query_count, default_penalty) + distances[query_index] = score + def is_edge_alive_in_dimension(self, dimension: int) -> bool: if dimension < 0 or dimension >= len(self.status): raise DimensionOutOfRangeError(dimension=dimension, max_index=len(self.status) - 1) diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index d0b78dfcc..c5cd0061e 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -350,7 +350,10 @@ async def 
test_e2e_retriever_triplets_have_vector_distances(e2e_state): assert triplets, f"{name}: Triplets list should not be empty" for edge in triplets: assert isinstance(edge, Edge), f"{name}: Elements should be Edge instances" - vector_distances = edge.attributes.get("vector_distance", []) + vector_distances = edge.attributes.get("vector_distance") + assert vector_distances is not None, ( + f"{name}: vector_distance should be set when retrievers return results" + ) assert isinstance(vector_distances, list) and vector_distances, ( f"{name}: vector_distance should be a non-empty list" ) @@ -360,8 +363,14 @@ async def test_e2e_retriever_triplets_have_vector_distances(e2e_state): ) assert 0 <= distance <= 1 - node1_distances = edge.node1.attributes.get("vector_distance", []) - node2_distances = edge.node2.attributes.get("vector_distance", []) + node1_distances = edge.node1.attributes.get("vector_distance") + node2_distances = edge.node2.attributes.get("vector_distance") + assert node1_distances is not None, ( + f"{name}: node1 vector_distance should be set when retrievers return results" + ) + assert node2_distances is not None, ( + f"{name}: node2 vector_distance should be set when retrievers return results" + ) assert isinstance(node1_distances, list) and node1_distances, ( f"{name}: node1 vector_distance should be a non-empty list" ) diff --git a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py index e59888525..809cde4cd 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_elements_test.py @@ -86,6 +86,46 @@ def test_node_hash(): assert hash(node) == hash("node1") +def test_node_vector_distance_stays_none(): + """Test that vector_distance remains None when no distances are passed.""" + node = Node("node1") + assert node.attributes.get("vector_distance") is None + + # Verify it stays None even after other operations + 
node.add_attribute("other_attr", "value") + assert node.attributes.get("vector_distance") is None + + +def test_node_vector_distance_with_custom_attributes(): + """Test that vector_distance is None even when node has custom attributes.""" + node = Node("node1", {"custom": "value", "another": 42}) + assert node.attributes.get("vector_distance") is None + assert node.attributes["custom"] == "value" + assert node.attributes["another"] == 42 + + +def test_edge_vector_distance_stays_none(): + """Test that vector_distance remains None when no distances are passed.""" + node1 = Node("node1") + node2 = Node("node2") + edge = Edge(node1, node2) + assert edge.attributes.get("vector_distance") is None + + # Verify it stays None even after other operations + edge.add_attribute("other_attr", "value") + assert edge.attributes.get("vector_distance") is None + + +def test_edge_vector_distance_with_custom_attributes(): + """Test that vector_distance is None even when edge has custom attributes.""" + node1 = Node("node1") + node2 = Node("node2") + edge = Edge(node1, node2, {"weight": 5, "type": "test"}) + assert edge.attributes.get("vector_distance") is None + assert edge.attributes["weight"] == 5 + assert edge.attributes["type"] == "test" + + ### Tests for Edge ### diff --git a/cognee/tests/unit/modules/graph/cognee_graph_test.py b/cognee/tests/unit/modules/graph/cognee_graph_test.py index 84e6411e2..e4ff0251e 100644 --- a/cognee/tests/unit/modules/graph/cognee_graph_test.py +++ b/cognee/tests/unit/modules/graph/cognee_graph_test.py @@ -468,7 +468,7 @@ async def test_map_vector_distances_no_edge_matches(setup_graph): @pytest.mark.asyncio async def test_map_vector_distances_none_returns_early(setup_graph): - """Test that edge_distances=None returns early without error.""" + """Test that edge_distances=None returns early without error and vector_distance stays None.""" graph = setup_graph graph.add_node(Node("1")) graph.add_node(Node("2")) @@ -476,7 +476,22 @@ async def 
test_map_vector_distances_none_returns_early(setup_graph): await graph.map_vector_distances_to_graph_edges(edge_distances=None) - assert graph.edges[0].attributes.get("vector_distance") == [3.5] + assert graph.edges[0].attributes.get("vector_distance") is None + + +@pytest.mark.asyncio +async def test_map_vector_distances_empty_nodes_returns_early(setup_graph): + """Test that node_distances={} returns early without error and vector_distance stays None.""" + graph = setup_graph + node1 = Node("1") + node2 = Node("2") + graph.add_node(node1) + graph.add_node(node2) + + await graph.map_vector_distances_to_graph_nodes({}) + + assert node1.attributes.get("vector_distance") is None + assert node2.attributes.get("vector_distance") is None @pytest.mark.asyncio @@ -581,7 +596,7 @@ async def test_calculate_top_triplet_importances(setup_graph): @pytest.mark.asyncio async def test_calculate_top_triplet_importances_default_distances(setup_graph): - """Test calculating importances when nodes/edges have default vector distances.""" + """Test that vector_distance stays None when no distances are passed and calculate_top_triplet_importances handles it.""" graph = setup_graph node1 = Node("1") @@ -592,13 +607,15 @@ async def test_calculate_top_triplet_importances_default_distances(setup_graph): edge = Edge(node1, node2) graph.add_edge(edge) - await graph.map_vector_distances_to_graph_nodes({}) - await graph.map_vector_distances_to_graph_edges(None) + # Verify vector_distance is None when no distances are passed + assert node1.attributes.get("vector_distance") is None + assert node2.attributes.get("vector_distance") is None + assert edge.attributes.get("vector_distance") is None - top_triplets = await graph.calculate_top_triplet_importances(k=1) - - assert len(top_triplets) == 1 - assert top_triplets[0] == edge + # When no distances are set, calculate_top_triplet_importances should handle None + # by either raising an error or skipping edges with None distances + with 
pytest.raises(TypeError, match="'NoneType' object is not subscriptable"): + await graph.calculate_top_triplet_importances(k=1) @pytest.mark.asyncio From f93d414e94681d046bbd3e1320ab6d1be924e848 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Thu, 18 Dec 2025 15:28:45 +0100 Subject: [PATCH 089/176] feat: simplify the current tutorial and add cognee basics tutorial --- ...nable_delete_for_old_tutorial_notebooks.py | 52 + cognee-frontend/package-lock.json | 1252 ++++++++++++++++- cognee-frontend/package.json | 3 +- .../src/app/dashboard/Dashboard.tsx | 15 +- .../dashboard/InstanceDatasetsAccordion.tsx | 2 + .../src/modules/ingestion/useDatasets.ts | 1 + .../src/modules/instances/cloudFetch.ts | 59 + .../src/modules/instances/localFetch.ts | 27 + .../src/modules/instances/types.ts | 4 + .../src/modules/notebooks/createNotebook.ts | 11 + .../src/modules/notebooks/deleteNotebook.ts | 7 + .../src/modules/notebooks/getNotebooks.ts | 10 + .../src/modules/notebooks/runNotebookCell.ts | 14 + .../src/modules/notebooks/saveNotebook.ts | 11 + .../src/modules/notebooks/useNotebooks.ts | 75 +- .../ui/elements/Notebook/MarkdownPreview.tsx | 77 + .../src/ui/elements/Notebook/Notebook.tsx | 397 +++--- .../elements/Notebook/NotebookCellHeader.tsx | 9 +- cognee-frontend/src/ui/elements/TextArea.tsx | 136 +- .../src/utils/handleServerErrors.ts | 8 +- cognee/modules/notebooks/methods/__init__.py | 1 + .../notebooks/methods/create_notebook.py | 34 - .../methods/create_tutorial_notebooks.py | 191 +++ .../notebooks/methods/get_notebooks.py | 20 +- .../tutorials/cognee-basics/cell-1.md | 3 + .../tutorials/cognee-basics/cell-2.md | 10 + .../tutorials/cognee-basics/cell-3.md | 7 + .../tutorials/cognee-basics/cell-4.py | 26 + .../tutorials/cognee-basics/cell-5.py | 1 + .../tutorials/cognee-basics/cell-6.py | 22 + .../tutorials/cognee-basics/config.json | 4 + .../python-development-with-cognee/cell-1.md | 3 + .../python-development-with-cognee/cell-10.md | 3 + 
.../python-development-with-cognee/cell-11.md | 3 + .../python-development-with-cognee/cell-12.py | 3 + .../python-development-with-cognee/cell-13.md | 7 + .../python-development-with-cognee/cell-14.py | 6 + .../python-development-with-cognee/cell-15.md | 3 + .../python-development-with-cognee/cell-16.py | 7 + .../python-development-with-cognee/cell-2.md | 9 + .../python-development-with-cognee/cell-3.md | 7 + .../python-development-with-cognee/cell-4.md | 9 + .../python-development-with-cognee/cell-5.md | 5 + .../python-development-with-cognee/cell-6.py | 13 + .../python-development-with-cognee/cell-7.md | 3 + .../python-development-with-cognee/cell-8.md | 3 + .../python-development-with-cognee/cell-9.py | 32 + .../config.json | 4 + .../data/copilot_conversations.json | 108 ++ .../data/guido_contributions.json | 976 +++++++++++++ .../data/my_developer_rules.md | 79 ++ .../data/pep_style_guide.md | 75 + .../data/zen_principles.md | 75 + cognee/modules/users/methods/create_user.py | 9 - .../users/test_tutorial_notebook_creation.py | 921 +++++++----- 55 files changed, 4114 insertions(+), 738 deletions(-) create mode 100644 alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py create mode 100644 cognee-frontend/src/modules/instances/cloudFetch.ts create mode 100644 cognee-frontend/src/modules/instances/localFetch.ts create mode 100644 cognee-frontend/src/modules/instances/types.ts create mode 100644 cognee-frontend/src/modules/notebooks/createNotebook.ts create mode 100644 cognee-frontend/src/modules/notebooks/deleteNotebook.ts create mode 100644 cognee-frontend/src/modules/notebooks/getNotebooks.ts create mode 100644 cognee-frontend/src/modules/notebooks/runNotebookCell.ts create mode 100644 cognee-frontend/src/modules/notebooks/saveNotebook.ts create mode 100644 cognee-frontend/src/ui/elements/Notebook/MarkdownPreview.tsx create mode 100644 cognee/modules/notebooks/methods/create_tutorial_notebooks.py create mode 100644 
cognee/modules/notebooks/tutorials/cognee-basics/cell-1.md create mode 100644 cognee/modules/notebooks/tutorials/cognee-basics/cell-2.md create mode 100644 cognee/modules/notebooks/tutorials/cognee-basics/cell-3.md create mode 100644 cognee/modules/notebooks/tutorials/cognee-basics/cell-4.py create mode 100644 cognee/modules/notebooks/tutorials/cognee-basics/cell-5.py create mode 100644 cognee/modules/notebooks/tutorials/cognee-basics/cell-6.py create mode 100644 cognee/modules/notebooks/tutorials/cognee-basics/config.json create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-1.md create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-10.md create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-11.md create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-12.py create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-13.md create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-14.py create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-15.md create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-16.py create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-2.md create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-3.md create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-4.md create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-5.md create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-6.py create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-7.md create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-8.md create mode 100644 
cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/config.json create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/data/copilot_conversations.json create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/data/guido_contributions.json create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/data/my_developer_rules.md create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/data/pep_style_guide.md create mode 100644 cognee/modules/notebooks/tutorials/python-development-with-cognee/data/zen_principles.md diff --git a/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py b/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py new file mode 100644 index 000000000..f6a965c35 --- /dev/null +++ b/alembic/versions/1a58b986e6e1_enable_delete_for_old_tutorial_notebooks.py @@ -0,0 +1,52 @@ +"""Enable delete for old tutorial notebooks + +Revision ID: 1a58b986e6e1 +Revises: 46a6ce2bd2b2 +Create Date: 2025-12-17 11:04:44.414259 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = "1a58b986e6e1" +down_revision: Union[str, None] = "46a6ce2bd2b2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def change_tutorial_deletable_flag(deletable: bool) -> None: + bind = op.get_bind() + inspector = sa.inspect(bind) + + if "notebooks" not in inspector.get_table_names(): + return + + columns = {col["name"] for col in inspector.get_columns("notebooks")} + required_columns = {"name", "deletable"} + if not required_columns.issubset(columns): + return + + notebooks = sa.table( + "notebooks", + sa.Column("name", sa.String()), + sa.Column("deletable", sa.Boolean()), + ) + + tutorial_name = "Python Development with Cognee Tutorial 🧠" + + bind.execute( + notebooks.update().where(notebooks.c.name == tutorial_name).values(deletable=deletable) + ) + + +def upgrade() -> None: + change_tutorial_deletable_flag(True) + + +def downgrade() -> None: + change_tutorial_deletable_flag(False) diff --git a/cognee-frontend/package-lock.json b/cognee-frontend/package-lock.json index 29826027a..abd9cd6c3 100644 --- a/cognee-frontend/package-lock.json +++ b/cognee-frontend/package-lock.json @@ -12,10 +12,11 @@ "classnames": "^2.5.1", "culori": "^4.0.1", "d3-force-3d": "^3.0.6", - "next": "16.0.4", + "next": "^16.0.10", "react": "^19.2.0", "react-dom": "^19.2.0", "react-force-graph-2d": "^1.27.1", + "react-markdown": "^10.1.0", "uuid": "^9.0.1" }, "devDependencies": { @@ -1074,9 +1075,9 @@ } }, "node_modules/@next/env": { - "version": "16.0.4", - "resolved": "https://registry.npmjs.org/@next/env/-/env-16.0.4.tgz", - "integrity": "sha512-FDPaVoB1kYhtOz6Le0Jn2QV7RZJ3Ngxzqri7YX4yu3Ini+l5lciR7nA9eNDpKTmDm7LWZtxSju+/CQnwRBn2pA==", + "version": "16.0.10", + "resolved": "https://registry.npmjs.org/@next/env/-/env-16.0.10.tgz", + "integrity": "sha512-8tuaQkyDVgeONQ1MeT9Mkk8pQmZapMKFh5B+OrFUlG3rVmYTXcXlBetBgTurKXGaIZvkoqRT9JL5K3phXcgang==", "license": "MIT" }, "node_modules/@next/eslint-plugin-next": { @@ 
-1090,9 +1091,9 @@ } }, "node_modules/@next/swc-darwin-arm64": { - "version": "16.0.4", - "resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.0.4.tgz", - "integrity": "sha512-TN0cfB4HT2YyEio9fLwZY33J+s+vMIgC84gQCOLZOYusW7ptgjIn8RwxQt0BUpoo9XRRVVWEHLld0uhyux1ZcA==", + "version": "16.0.10", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.0.10.tgz", + "integrity": "sha512-4XgdKtdVsaflErz+B5XeG0T5PeXKDdruDf3CRpnhN+8UebNa5N2H58+3GDgpn/9GBurrQ1uWW768FfscwYkJRg==", "cpu": [ "arm64" ], @@ -1106,9 +1107,9 @@ } }, "node_modules/@next/swc-darwin-x64": { - "version": "16.0.4", - "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.0.4.tgz", - "integrity": "sha512-XsfI23jvimCaA7e+9f3yMCoVjrny2D11G6H8NCcgv+Ina/TQhKPXB9P4q0WjTuEoyZmcNvPdrZ+XtTh3uPfH7Q==", + "version": "16.0.10", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.0.10.tgz", + "integrity": "sha512-spbEObMvRKkQ3CkYVOME+ocPDFo5UqHb8EMTS78/0mQ+O1nqE8toHJVioZo4TvebATxgA8XMTHHrScPrn68OGw==", "cpu": [ "x64" ], @@ -1122,9 +1123,9 @@ } }, "node_modules/@next/swc-linux-arm64-gnu": { - "version": "16.0.4", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.0.4.tgz", - "integrity": "sha512-uo8X7qHDy4YdJUhaoJDMAbL8VT5Ed3lijip2DdBHIB4tfKAvB1XBih6INH2L4qIi4jA0Qq1J0ErxcOocBmUSwg==", + "version": "16.0.10", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.0.10.tgz", + "integrity": "sha512-uQtWE3X0iGB8apTIskOMi2w/MKONrPOUCi5yLO+v3O8Mb5c7K4Q5KD1jvTpTF5gJKa3VH/ijKjKUq9O9UhwOYw==", "cpu": [ "arm64" ], @@ -1138,9 +1139,9 @@ } }, "node_modules/@next/swc-linux-arm64-musl": { - "version": "16.0.4", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.0.4.tgz", - "integrity": "sha512-pvR/AjNIAxsIz0PCNcZYpH+WmNIKNLcL4XYEfo+ArDi7GsxKWFO5BvVBLXbhti8Coyv3DE983NsitzUsGH5yTw==", + 
"version": "16.0.10", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.0.10.tgz", + "integrity": "sha512-llA+hiDTrYvyWI21Z0L1GiXwjQaanPVQQwru5peOgtooeJ8qx3tlqRV2P7uH2pKQaUfHxI/WVarvI5oYgGxaTw==", "cpu": [ "arm64" ], @@ -1154,9 +1155,9 @@ } }, "node_modules/@next/swc-linux-x64-gnu": { - "version": "16.0.4", - "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.0.4.tgz", - "integrity": "sha512-2hebpsd5MRRtgqmT7Jj/Wze+wG+ZEXUK2KFFL4IlZ0amEEFADo4ywsifJNeFTQGsamH3/aXkKWymDvgEi+pc2Q==", + "version": "16.0.10", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.0.10.tgz", + "integrity": "sha512-AK2q5H0+a9nsXbeZ3FZdMtbtu9jxW4R/NgzZ6+lrTm3d6Zb7jYrWcgjcpM1k8uuqlSy4xIyPR2YiuUr+wXsavA==", "cpu": [ "x64" ], @@ -1170,9 +1171,9 @@ } }, "node_modules/@next/swc-linux-x64-musl": { - "version": "16.0.4", - "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.0.4.tgz", - "integrity": "sha512-pzRXf0LZZ8zMljH78j8SeLncg9ifIOp3ugAFka+Bq8qMzw6hPXOc7wydY7ardIELlczzzreahyTpwsim/WL3Sg==", + "version": "16.0.10", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.0.10.tgz", + "integrity": "sha512-1TDG9PDKivNw5550S111gsO4RGennLVl9cipPhtkXIFVwo31YZ73nEbLjNC8qG3SgTz/QZyYyaFYMeY4BKZR/g==", "cpu": [ "x64" ], @@ -1186,9 +1187,9 @@ } }, "node_modules/@next/swc-win32-arm64-msvc": { - "version": "16.0.4", - "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.0.4.tgz", - "integrity": "sha512-7G/yJVzum52B5HOqqbQYX9bJHkN+c4YyZ2AIvEssMHQlbAWOn3iIJjD4sM6ihWsBxuljiTKJovEYlD1K8lCUHw==", + "version": "16.0.10", + "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.0.10.tgz", + "integrity": "sha512-aEZIS4Hh32xdJQbHz121pyuVZniSNoqDVx1yIr2hy+ZwJGipeqnMZBJHyMxv2tiuAXGx6/xpTcQJ6btIiBjgmg==", "cpu": [ "arm64" ], @@ -1202,9 +1203,9 @@ } }, 
"node_modules/@next/swc-win32-x64-msvc": { - "version": "16.0.4", - "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.0.4.tgz", - "integrity": "sha512-0Vy4g8SSeVkuU89g2OFHqGKM4rxsQtihGfenjx2tRckPrge5+gtFnRWGAAwvGXr0ty3twQvcnYjEyOrLHJ4JWA==", + "version": "16.0.10", + "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.0.10.tgz", + "integrity": "sha512-E+njfCoFLb01RAFEnGZn6ERoOqhK1Gl3Lfz1Kjnj0Ulfu7oJbuMyvBKNj/bw8XZnenHDASlygTjZICQW+rYW1Q==", "cpu": [ "x64" ], @@ -1585,13 +1586,39 @@ "dev": true, "license": "MIT" }, + "node_modules/@types/debug": { + "version": "4.1.12", + "resolved": "https://registry.npmjs.org/@types/debug/-/debug-4.1.12.tgz", + "integrity": "sha512-vIChWdVG3LG1SMxEvI/AK+FWJthlrqlTu7fbrlywTkkaONwk/UAGaULXRlf8vkzFBLVm0zkMdCquhL5aOjhXPQ==", + "license": "MIT", + "dependencies": { + "@types/ms": "*" + } + }, "node_modules/@types/estree": { "version": "1.0.8", "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", "integrity": "sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==", - "dev": true, "license": "MIT" }, + "node_modules/@types/estree-jsx": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@types/estree-jsx/-/estree-jsx-1.0.5.tgz", + "integrity": "sha512-52CcUVNFyfb1A2ALocQw/Dd1BQFNmSdkuC3BkZ6iqhdMfQz7JWOFRuJFloOzjk+6WijU56m9oKXFAXc7o3Towg==", + "license": "MIT", + "dependencies": { + "@types/estree": "*" + } + }, + "node_modules/@types/hast": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@types/hast/-/hast-3.0.4.tgz", + "integrity": "sha512-WPs+bbQw5aCj+x6laNGWLH3wviHtoCv/P3+otBhbOhJgG8qtpdAMlTCxLtsTWA7LH1Oh/bFCHsBn0TPS5m30EQ==", + "license": "MIT", + "dependencies": { + "@types/unist": "*" + } + }, "node_modules/@types/json-schema": { "version": "7.0.15", "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", @@ -1606,6 +1633,21 @@ 
"dev": true, "license": "MIT" }, + "node_modules/@types/mdast": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/@types/mdast/-/mdast-4.0.4.tgz", + "integrity": "sha512-kGaNbPh1k7AFzgpud/gMdvIm5xuECykRR+JnWKQno9TAXVa6WIVCGTPvYGekIDL4uwCZQSYbUxNBSb1aUo79oA==", + "license": "MIT", + "dependencies": { + "@types/unist": "*" + } + }, + "node_modules/@types/ms": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/@types/ms/-/ms-2.1.0.tgz", + "integrity": "sha512-GsCCIZDE/p3i96vtEqx+7dBUGXrc7zeSK3wwPHIaRThS+9OhWIXRqzs4d6k1SVU8g91DrNRWxWUGhp5KXQb2VA==", + "license": "MIT" + }, "node_modules/@types/node": { "version": "20.19.25", "resolved": "https://registry.npmjs.org/@types/node/-/node-20.19.25.tgz", @@ -1620,7 +1662,6 @@ "version": "19.2.7", "resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.7.tgz", "integrity": "sha512-MWtvHrGZLFttgeEj28VXHxpmwYbor/ATPYbBfSFZEIRK0ecCFLl2Qo55z52Hss+UV9CRN7trSeq1zbgx7YDWWg==", - "dev": true, "license": "MIT", "peer": true, "dependencies": { @@ -1637,6 +1678,12 @@ "@types/react": "^19.2.0" } }, + "node_modules/@types/unist": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/@types/unist/-/unist-3.0.3.tgz", + "integrity": "sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==", + "license": "MIT" + }, "node_modules/@types/uuid": { "version": "9.0.8", "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.8.tgz", @@ -1915,6 +1962,12 @@ "url": "https://opencollective.com/typescript-eslint" } }, + "node_modules/@ungap/structured-clone": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/@ungap/structured-clone/-/structured-clone-1.3.0.tgz", + "integrity": "sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==", + "license": "ISC" + }, "node_modules/@unrs/resolver-binding-android-arm-eabi": { "version": "1.11.1", "resolved": 
"https://registry.npmjs.org/@unrs/resolver-binding-android-arm-eabi/-/resolver-binding-android-arm-eabi-1.11.1.tgz", @@ -2480,6 +2533,16 @@ "node": ">= 0.4" } }, + "node_modules/bail": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/bail/-/bail-2.0.2.tgz", + "integrity": "sha512-0xO6mYd7JB2YesxDKplafRpsiOzPt9V02ddPCLbY1xYGPOX24NTyN50qnUxgCPcSoYMhKpAuBTjQoRZCAkUDRw==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/balanced-match": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", @@ -2658,6 +2721,16 @@ "node": ">=12" } }, + "node_modules/ccount": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/ccount/-/ccount-2.0.1.tgz", + "integrity": "sha512-eyrF0jiFpY+3drT6383f1qhkbGsLSifNAjA61IUjZjmLCWjItY6LB9ft9YhoDgwfmclB2zhu51Lc7+95b8NRAg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/chalk": { "version": "4.1.2", "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", @@ -2675,6 +2748,46 @@ "url": "https://github.com/chalk/chalk?sponsor=1" } }, + "node_modules/character-entities": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/character-entities/-/character-entities-2.0.2.tgz", + "integrity": "sha512-shx7oQ0Awen/BRIdkjkvz54PnEEI/EjwXDSIZp86/KKdbafHh1Df/RYGBhn4hbe2+uKC9FnT5UCEdyPz3ai9hQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/character-entities-html4": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/character-entities-html4/-/character-entities-html4-2.1.0.tgz", + "integrity": "sha512-1v7fgQRj6hnSwFpq1Eu0ynr/CDEw0rXo2B61qXrLNdHZmPKgb7fqS1a2JwF0rISo9q77jDI8VMEHoApn8qDoZA==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + 
"node_modules/character-entities-legacy": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/character-entities-legacy/-/character-entities-legacy-3.0.0.tgz", + "integrity": "sha512-RpPp0asT/6ufRm//AJVwpViZbGM/MkjQFxJccQRHmISF/22NBtsHqAWmL+/pmkPWoIUJdWyeVleTl1wydHATVQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/character-reference-invalid": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/character-reference-invalid/-/character-reference-invalid-2.0.1.tgz", + "integrity": "sha512-iBZ4F4wRbyORVsu0jPV7gXkOsGYjGHPmAyv+HiHG8gi5PtC9KI2j1+v8/tlibRvjoWX027ypmG/n0HtO5t7unw==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/classnames": { "version": "2.5.1", "resolved": "https://registry.npmjs.org/classnames/-/classnames-2.5.1.tgz", @@ -2707,6 +2820,16 @@ "dev": true, "license": "MIT" }, + "node_modules/comma-separated-tokens": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/comma-separated-tokens/-/comma-separated-tokens-2.0.3.tgz", + "integrity": "sha512-Fu4hJdvzeylCfQPp9SGWidpzrMs7tTrlu6Vb8XGaRGck8QSNZJJp538Wrb60Lax4fPwR64ViY468OIUTbRlGZg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/concat-map": { "version": "0.0.1", "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", @@ -2740,7 +2863,6 @@ "version": "3.2.3", "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.2.3.tgz", "integrity": "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==", - "dev": true, "license": "MIT" }, "node_modules/culori": { @@ -3034,7 +3156,6 @@ "version": "4.4.3", "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.3.tgz", "integrity": 
"sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==", - "dev": true, "license": "MIT", "dependencies": { "ms": "^2.1.3" @@ -3048,6 +3169,19 @@ } } }, + "node_modules/decode-named-character-reference": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/decode-named-character-reference/-/decode-named-character-reference-1.2.0.tgz", + "integrity": "sha512-c6fcElNV6ShtZXmsgNgFFV5tVX2PaV4g+MOAkb8eXHvn6sryJBrZa9r0zV6+dtTyoCKxtDy5tyQ5ZwQuidtd+Q==", + "license": "MIT", + "dependencies": { + "character-entities": "^2.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/deep-is": { "version": "0.1.4", "resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz", @@ -3110,6 +3244,19 @@ "node": ">=8" } }, + "node_modules/devlop": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/devlop/-/devlop-1.1.0.tgz", + "integrity": "sha512-RWmIqhcFf1lRYBvNmr7qTNuyCt/7/ns2jbpp1+PalgE/rDQcBT0fioSMUpJ93irlUhC5hrg4cYqe6U+0ImW0rA==", + "license": "MIT", + "dependencies": { + "dequal": "^2.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/doctrine": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/doctrine/-/doctrine-2.1.0.tgz", @@ -3574,6 +3721,7 @@ "integrity": "sha512-whOE1HFo/qJDyX4SnXzP4N6zOWn79WhnCUY/iDR0mPfQZO8wcYE4JClzI2oZrhBnnMUCBCHZhO6VQyoBU95mZA==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@rtsao/scc": "^1.1.0", "array-includes": "^3.1.9", @@ -3797,6 +3945,16 @@ "node": ">=4.0" } }, + "node_modules/estree-util-is-identifier-name": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/estree-util-is-identifier-name/-/estree-util-is-identifier-name-3.0.0.tgz", + "integrity": "sha512-hFtqIDZTIUZ9BXLb8y4pYGyk6+wekIivNVTcmvk8NoOh+VeRn5y6cEHzbURrWbfp1fIqdVipilzj+lfaadNZmg==", + "license": "MIT", + "funding": { + "type": 
"opencollective", + "url": "https://opencollective.com/unified" + } + }, "node_modules/esutils": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", @@ -3807,6 +3965,12 @@ "node": ">=0.10.0" } }, + "node_modules/extend": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz", + "integrity": "sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g==", + "license": "MIT" + }, "node_modules/fast-deep-equal": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", @@ -4283,6 +4447,46 @@ "node": ">= 0.4" } }, + "node_modules/hast-util-to-jsx-runtime": { + "version": "2.3.6", + "resolved": "https://registry.npmjs.org/hast-util-to-jsx-runtime/-/hast-util-to-jsx-runtime-2.3.6.tgz", + "integrity": "sha512-zl6s8LwNyo1P9uw+XJGvZtdFF1GdAkOg8ujOw+4Pyb76874fLps4ueHXDhXWdk6YHQ6OgUtinliG7RsYvCbbBg==", + "license": "MIT", + "dependencies": { + "@types/estree": "^1.0.0", + "@types/hast": "^3.0.0", + "@types/unist": "^3.0.0", + "comma-separated-tokens": "^2.0.0", + "devlop": "^1.0.0", + "estree-util-is-identifier-name": "^3.0.0", + "hast-util-whitespace": "^3.0.0", + "mdast-util-mdx-expression": "^2.0.0", + "mdast-util-mdx-jsx": "^3.0.0", + "mdast-util-mdxjs-esm": "^2.0.0", + "property-information": "^7.0.0", + "space-separated-tokens": "^2.0.0", + "style-to-js": "^1.0.0", + "unist-util-position": "^5.0.0", + "vfile-message": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/hast-util-whitespace": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/hast-util-whitespace/-/hast-util-whitespace-3.0.0.tgz", + "integrity": "sha512-88JUN06ipLwsnv+dVn+OIYOvAuvBMy/Qoi6O7mQHxdPXpjy+Cd6xRkWwux7DKO+4sYILtLBRIKgsdpS2gQc7qw==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0" + }, + "funding": { + "type": "opencollective", + 
"url": "https://opencollective.com/unified" + } + }, "node_modules/hermes-estree": { "version": "0.25.1", "resolved": "https://registry.npmjs.org/hermes-estree/-/hermes-estree-0.25.1.tgz", @@ -4300,6 +4504,16 @@ "hermes-estree": "0.25.1" } }, + "node_modules/html-url-attributes": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/html-url-attributes/-/html-url-attributes-3.0.1.tgz", + "integrity": "sha512-ol6UPyBWqsrO6EJySPz2O7ZSr856WDrEzM5zMqp+FJJLGMW35cLYmmZnl0vztAZxRUoNZJFTCohfjuIJ8I4QBQ==", + "license": "MIT", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, "node_modules/ignore": { "version": "5.3.2", "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz", @@ -4346,6 +4560,12 @@ "node": ">=12" } }, + "node_modules/inline-style-parser": { + "version": "0.2.7", + "resolved": "https://registry.npmjs.org/inline-style-parser/-/inline-style-parser-0.2.7.tgz", + "integrity": "sha512-Nb2ctOyNR8DqQoR0OwRG95uNWIC0C1lCgf5Naz5H6Ji72KZ8OcFZLz2P5sNgwlyoJ8Yif11oMuYs5pBQa86csA==", + "license": "MIT" + }, "node_modules/internal-slot": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/internal-slot/-/internal-slot-1.1.0.tgz", @@ -4370,6 +4590,30 @@ "node": ">=12" } }, + "node_modules/is-alphabetical": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-alphabetical/-/is-alphabetical-2.0.1.tgz", + "integrity": "sha512-FWyyY60MeTNyeSRpkM2Iry0G9hpr7/9kD40mD/cGQEuilcZYS4okz8SN2Q6rLCJ8gbCt6fN+rC+6tMGS99LaxQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/is-alphanumerical": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-alphanumerical/-/is-alphanumerical-2.0.1.tgz", + "integrity": "sha512-hmbYhX/9MUMF5uh7tOXyK/n0ZvWpad5caBA17GsC6vyuCqaWliRG5K1qS9inmUhEMaOBIW7/whAnSwveW/LtZw==", + "license": "MIT", + "dependencies": { + "is-alphabetical": "^2.0.0", + "is-decimal": "^2.0.0" + }, 
+ "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/is-array-buffer": { "version": "3.0.5", "resolved": "https://registry.npmjs.org/is-array-buffer/-/is-array-buffer-3.0.5.tgz", @@ -4528,6 +4772,16 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/is-decimal": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-decimal/-/is-decimal-2.0.1.tgz", + "integrity": "sha512-AAB9hiomQs5DXWcRB1rqsxGUstbRroFOPPVAomNk/3XHR5JyEZChOyTWe2oayKnsSsr/kcGqF+z6yuH6HHpN0A==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/is-extglob": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", @@ -4587,6 +4841,16 @@ "node": ">=0.10.0" } }, + "node_modules/is-hexadecimal": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-hexadecimal/-/is-hexadecimal-2.0.1.tgz", + "integrity": "sha512-DgZQp241c8oO6cA1SbTEWiXeoxV42vlcJxgH+B3hi1AiqqKruZR3ZGF8In3fj4+/y/7rHvlOZLZtgJ/4ttYGZg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/is-map": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/is-map/-/is-map-2.0.3.tgz", @@ -4640,6 +4904,18 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/is-plain-obj": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/is-plain-obj/-/is-plain-obj-4.1.0.tgz", + "integrity": "sha512-+Pgi+vMuUNkJyExiMBt5IlFoMyKnr5zhJ4Uspz58WOhBF5QoIZkFyNHIbBAtHwzVAgk5RtndVNsDRN61/mmDqg==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/is-regex": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/is-regex/-/is-regex-1.2.1.tgz", @@ -5273,6 +5549,16 @@ "dev": true, "license": "MIT" }, + "node_modules/longest-streak": { + "version": "3.1.0", + 
"resolved": "https://registry.npmjs.org/longest-streak/-/longest-streak-3.1.0.tgz", + "integrity": "sha512-9Ri+o0JYgehTaVBBDoMqIl8GXtbWg711O3srftcHhZ0dqnETqLaoIK0x17fUw9rFSlK/0NlsKe0Ahhyl5pXE2g==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/loose-envify": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", @@ -5315,6 +5601,159 @@ "node": ">= 0.4" } }, + "node_modules/mdast-util-from-markdown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/mdast-util-from-markdown/-/mdast-util-from-markdown-2.0.2.tgz", + "integrity": "sha512-uZhTV/8NBuw0WHkPTrCqDOl0zVe1BIng5ZtHoDk49ME1qqcjYmmLmOf0gELgcRMxN4w2iuIeVso5/6QymSrgmA==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "@types/unist": "^3.0.0", + "decode-named-character-reference": "^1.0.0", + "devlop": "^1.0.0", + "mdast-util-to-string": "^4.0.0", + "micromark": "^4.0.0", + "micromark-util-decode-numeric-character-reference": "^2.0.0", + "micromark-util-decode-string": "^2.0.0", + "micromark-util-normalize-identifier": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0", + "unist-util-stringify-position": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-mdx-expression": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/mdast-util-mdx-expression/-/mdast-util-mdx-expression-2.0.1.tgz", + "integrity": "sha512-J6f+9hUp+ldTZqKRSg7Vw5V6MqjATc+3E4gf3CFNcuZNWD8XdyI6zQ8GqH7f8169MM6P7hMBRDVGnn7oHB9kXQ==", + "license": "MIT", + "dependencies": { + "@types/estree-jsx": "^1.0.0", + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "devlop": "^1.0.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, 
+ "node_modules/mdast-util-mdx-jsx": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/mdast-util-mdx-jsx/-/mdast-util-mdx-jsx-3.2.0.tgz", + "integrity": "sha512-lj/z8v0r6ZtsN/cGNNtemmmfoLAFZnjMbNyLzBafjzikOM+glrjNHPlf6lQDOTccj9n5b0PPihEBbhneMyGs1Q==", + "license": "MIT", + "dependencies": { + "@types/estree-jsx": "^1.0.0", + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "@types/unist": "^3.0.0", + "ccount": "^2.0.0", + "devlop": "^1.1.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0", + "parse-entities": "^4.0.0", + "stringify-entities": "^4.0.0", + "unist-util-stringify-position": "^4.0.0", + "vfile-message": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-mdxjs-esm": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/mdast-util-mdxjs-esm/-/mdast-util-mdxjs-esm-2.0.1.tgz", + "integrity": "sha512-EcmOpxsZ96CvlP03NghtH1EsLtr0n9Tm4lPUJUBccV9RwUOneqSycg19n5HGzCf+10LozMRSObtVr3ee1WoHtg==", + "license": "MIT", + "dependencies": { + "@types/estree-jsx": "^1.0.0", + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "devlop": "^1.0.0", + "mdast-util-from-markdown": "^2.0.0", + "mdast-util-to-markdown": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-phrasing": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/mdast-util-phrasing/-/mdast-util-phrasing-4.1.0.tgz", + "integrity": "sha512-TqICwyvJJpBwvGAMZjj4J2n0X8QWp21b9l0o7eXyVJ25YNWYbJDVIyD1bZXE6WtV6RmKJVYmQAKWa0zWOABz2w==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "unist-util-is": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-to-hast": { + "version": "13.2.1", + "resolved": 
"https://registry.npmjs.org/mdast-util-to-hast/-/mdast-util-to-hast-13.2.1.tgz", + "integrity": "sha512-cctsq2wp5vTsLIcaymblUriiTcZd0CwWtCbLvrOzYCDZoWyMNV8sZ7krj09FSnsiJi3WVsHLM4k6Dq/yaPyCXA==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "@ungap/structured-clone": "^1.0.0", + "devlop": "^1.0.0", + "micromark-util-sanitize-uri": "^2.0.0", + "trim-lines": "^3.0.0", + "unist-util-position": "^5.0.0", + "unist-util-visit": "^5.0.0", + "vfile": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-to-markdown": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/mdast-util-to-markdown/-/mdast-util-to-markdown-2.1.2.tgz", + "integrity": "sha512-xj68wMTvGXVOKonmog6LwyJKrYXZPvlwabaryTjLh9LuvovB/KAH+kvi8Gjj+7rJjsFi23nkUxRQv1KqSroMqA==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "@types/unist": "^3.0.0", + "longest-streak": "^3.0.0", + "mdast-util-phrasing": "^4.0.0", + "mdast-util-to-string": "^4.0.0", + "micromark-util-classify-character": "^2.0.0", + "micromark-util-decode-string": "^2.0.0", + "unist-util-visit": "^5.0.0", + "zwitch": "^2.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/mdast-util-to-string": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/mdast-util-to-string/-/mdast-util-to-string-4.0.0.tgz", + "integrity": "sha512-0H44vDimn51F0YwvxSJSm0eCDOJTRlmN0R1yBh4HLj9wiV1Dn0QoXGbvFAWj2hSItVTlCmBF1hqKlIyUBVFLPg==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, "node_modules/merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", @@ -5325,6 +5764,448 @@ "node": ">= 8" } }, + "node_modules/micromark": { + "version": "4.0.2", + 
"resolved": "https://registry.npmjs.org/micromark/-/micromark-4.0.2.tgz", + "integrity": "sha512-zpe98Q6kvavpCr1NPVSCMebCKfD7CA2NqZ+rykeNhONIJBpc1tFKt9hucLGwha3jNTNI8lHpctWJWoimVF4PfA==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "@types/debug": "^4.0.0", + "debug": "^4.0.0", + "decode-named-character-reference": "^1.0.0", + "devlop": "^1.0.0", + "micromark-core-commonmark": "^2.0.0", + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-chunked": "^2.0.0", + "micromark-util-combine-extensions": "^2.0.0", + "micromark-util-decode-numeric-character-reference": "^2.0.0", + "micromark-util-encode": "^2.0.0", + "micromark-util-normalize-identifier": "^2.0.0", + "micromark-util-resolve-all": "^2.0.0", + "micromark-util-sanitize-uri": "^2.0.0", + "micromark-util-subtokenize": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-core-commonmark": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/micromark-core-commonmark/-/micromark-core-commonmark-2.0.3.tgz", + "integrity": "sha512-RDBrHEMSxVFLg6xvnXmb1Ayr2WzLAWjeSATAoxwKYJV94TeNavgoIdA0a9ytzDSVzBy2YKFK+emCPOEibLeCrg==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "decode-named-character-reference": "^1.0.0", + "devlop": "^1.0.0", + "micromark-factory-destination": "^2.0.0", + "micromark-factory-label": "^2.0.0", + "micromark-factory-space": "^2.0.0", + "micromark-factory-title": "^2.0.0", + "micromark-factory-whitespace": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-chunked": "^2.0.0", + 
"micromark-util-classify-character": "^2.0.0", + "micromark-util-html-tag-name": "^2.0.0", + "micromark-util-normalize-identifier": "^2.0.0", + "micromark-util-resolve-all": "^2.0.0", + "micromark-util-subtokenize": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-factory-destination": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-factory-destination/-/micromark-factory-destination-2.0.1.tgz", + "integrity": "sha512-Xe6rDdJlkmbFRExpTOmRj9N3MaWmbAgdpSrBQvCFqhezUn4AHqJHbaEnfbVYYiexVSs//tqOdY/DxhjdCiJnIA==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-factory-label": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-factory-label/-/micromark-factory-label-2.0.1.tgz", + "integrity": "sha512-VFMekyQExqIW7xIChcXn4ok29YE3rnuyveW3wZQWWqF4Nv9Wk5rgJ99KzPvHjkmPXF93FXIbBp6YdW3t71/7Vg==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "devlop": "^1.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-factory-space": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-factory-space/-/micromark-factory-space-2.0.1.tgz", + "integrity": "sha512-zRkxjtBxxLd2Sc0d+fbnEunsTj46SWXgXciZmHq0kDYGnck/ZSGj9/wULTV95uoeYiK5hRXP2mJ98Uo4cq/LQg==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + 
}, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-character": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-factory-title": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-factory-title/-/micromark-factory-title-2.0.1.tgz", + "integrity": "sha512-5bZ+3CjhAd9eChYTHsjy6TGxpOFSKgKKJPJxr293jTbfry2KDoWkhBb6TcPVB4NmzaPhMs1Frm9AZH7OD4Cjzw==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-factory-whitespace": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-factory-whitespace/-/micromark-factory-whitespace-2.0.1.tgz", + "integrity": "sha512-Ob0nuZ3PKt/n0hORHyvoD9uZhr+Za8sFoP+OnMcnWK5lngSzALgQYKMr9RJVOWLqQYuyn6ulqGWSXdwf6F80lQ==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-factory-space": "^2.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-util-character": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/micromark-util-character/-/micromark-util-character-2.1.1.tgz", + "integrity": "sha512-wv8tdUTJ3thSFFFJKtpYKOYiGP2+v96Hvk4Tu8KpCAsTMs6yi+nVmGh1syvSCsaxz45J6Jbw+9DD6g97+NV67Q==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + 
"url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-util-chunked": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-chunked/-/micromark-util-chunked-2.0.1.tgz", + "integrity": "sha512-QUNFEOPELfmvv+4xiNg2sRYeS/P84pTW0TCgP5zc9FpXetHY0ab7SxKyAQCNCc1eK0459uoLI1y5oO5Vc1dbhA==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-symbol": "^2.0.0" + } + }, + "node_modules/micromark-util-classify-character": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-classify-character/-/micromark-util-classify-character-2.0.1.tgz", + "integrity": "sha512-K0kHzM6afW/MbeWYWLjoHQv1sgg2Q9EccHEDzSkxiP/EaagNzCm7T/WMKZ3rjMbvIpvBiZgwR3dKMygtA4mG1Q==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-character": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-util-combine-extensions": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-combine-extensions/-/micromark-util-combine-extensions-2.0.1.tgz", + "integrity": "sha512-OnAnH8Ujmy59JcyZw8JSbK9cGpdVY44NKgSM7E9Eh7DiLS2E9RNQf0dONaGDzEG9yjEl5hcqeIsj4hfRkLH/Bg==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-chunked": "^2.0.0", + "micromark-util-types": 
"^2.0.0" + } + }, + "node_modules/micromark-util-decode-numeric-character-reference": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/micromark-util-decode-numeric-character-reference/-/micromark-util-decode-numeric-character-reference-2.0.2.tgz", + "integrity": "sha512-ccUbYk6CwVdkmCQMyr64dXz42EfHGkPQlBj5p7YVGzq8I7CtjXZJrubAYezf7Rp+bjPseiROqe7G6foFd+lEuw==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-symbol": "^2.0.0" + } + }, + "node_modules/micromark-util-decode-string": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-decode-string/-/micromark-util-decode-string-2.0.1.tgz", + "integrity": "sha512-nDV/77Fj6eH1ynwscYTOsbK7rR//Uj0bZXBwJZRfaLEJ1iGBR6kIfNmlNqaqJf649EP0F3NWNdeJi03elllNUQ==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "decode-named-character-reference": "^1.0.0", + "micromark-util-character": "^2.0.0", + "micromark-util-decode-numeric-character-reference": "^2.0.0", + "micromark-util-symbol": "^2.0.0" + } + }, + "node_modules/micromark-util-encode": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-encode/-/micromark-util-encode-2.0.1.tgz", + "integrity": "sha512-c3cVx2y4KqUnwopcO9b/SCdo2O67LwJJ/UyqGfbigahfegL9myoEFoDYZgkT7f36T0bLrM9hZTAaAyH+PCAXjw==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT" + }, + "node_modules/micromark-util-html-tag-name": { + "version": "2.0.1", + "resolved": 
"https://registry.npmjs.org/micromark-util-html-tag-name/-/micromark-util-html-tag-name-2.0.1.tgz", + "integrity": "sha512-2cNEiYDhCWKI+Gs9T0Tiysk136SnR13hhO8yW6BGNyhOC4qYFnwF1nKfD3HFAIXA5c45RrIG1ub11GiXeYd1xA==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT" + }, + "node_modules/micromark-util-normalize-identifier": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-normalize-identifier/-/micromark-util-normalize-identifier-2.0.1.tgz", + "integrity": "sha512-sxPqmo70LyARJs0w2UclACPUUEqltCkJ6PhKdMIDuJ3gSf/Q+/GIe3WKl0Ijb/GyH9lOpUkRAO2wp0GVkLvS9Q==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-symbol": "^2.0.0" + } + }, + "node_modules/micromark-util-resolve-all": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-resolve-all/-/micromark-util-resolve-all-2.0.1.tgz", + "integrity": "sha512-VdQyxFWFT2/FGJgwQnJYbe1jjQoNTS4RjglmSjTUlpUMa95Htx9NHeYW4rGDJzbjvCsl9eLjMQwGeElsqmzcHg==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-util-sanitize-uri": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-sanitize-uri/-/micromark-util-sanitize-uri-2.0.1.tgz", + "integrity": "sha512-9N9IomZ/YuGGZZmQec1MbgxtlgougxTodVwDzzEouPKo3qFWvymFHWcnDi2vzV1ff6kas9ucW+o3yzJK9YB1AQ==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": 
"OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "micromark-util-character": "^2.0.0", + "micromark-util-encode": "^2.0.0", + "micromark-util-symbol": "^2.0.0" + } + }, + "node_modules/micromark-util-subtokenize": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/micromark-util-subtokenize/-/micromark-util-subtokenize-2.1.0.tgz", + "integrity": "sha512-XQLu552iSctvnEcgXw6+Sx75GflAPNED1qx7eBJ+wydBb2KCbRZe+NwvIEEMM83uml1+2WSXpBAcp9IUCgCYWA==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT", + "dependencies": { + "devlop": "^1.0.0", + "micromark-util-chunked": "^2.0.0", + "micromark-util-symbol": "^2.0.0", + "micromark-util-types": "^2.0.0" + } + }, + "node_modules/micromark-util-symbol": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/micromark-util-symbol/-/micromark-util-symbol-2.0.1.tgz", + "integrity": "sha512-vs5t8Apaud9N28kgCrRUdEed4UJ+wWNvicHLPxCa9ENlYuAY31M0ETy5y1vA33YoNPDFTghEbnh6efaE8h4x0Q==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT" + }, + "node_modules/micromark-util-types": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/micromark-util-types/-/micromark-util-types-2.0.2.tgz", + "integrity": "sha512-Yw0ECSpJoViF1qTU4DC6NwtC4aWGt1EkzaQB8KPPyCRR8z9TWeV0HbEFGTO+ZY1wB22zmxnJqhPyTpOVCpeHTA==", + "funding": [ + { + "type": "GitHub Sponsors", + "url": "https://github.com/sponsors/unifiedjs" + }, + { + "type": "OpenCollective", + "url": "https://opencollective.com/unified" + } + ], + "license": "MIT" + }, "node_modules/micromatch": { "version": "4.0.8", "resolved": 
"https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", @@ -5366,7 +6247,6 @@ "version": "2.1.3", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", - "dev": true, "license": "MIT" }, "node_modules/nanoid": { @@ -5411,13 +6291,13 @@ "license": "MIT" }, "node_modules/next": { - "version": "16.0.4", - "resolved": "https://registry.npmjs.org/next/-/next-16.0.4.tgz", - "integrity": "sha512-vICcxKusY8qW7QFOzTvnRL1ejz2ClTqDKtm1AcUjm2mPv/lVAdgpGNsftsPRIDJOXOjRQO68i1dM8Lp8GZnqoA==", + "version": "16.0.10", + "resolved": "https://registry.npmjs.org/next/-/next-16.0.10.tgz", + "integrity": "sha512-RtWh5PUgI+vxlV3HdR+IfWA1UUHu0+Ram/JBO4vWB54cVPentCD0e+lxyAYEsDTqGGMg7qpjhKh6dc6aW7W/sA==", "license": "MIT", "peer": true, "dependencies": { - "@next/env": "16.0.4", + "@next/env": "16.0.10", "@swc/helpers": "0.5.15", "caniuse-lite": "^1.0.30001579", "postcss": "8.4.31", @@ -5430,14 +6310,14 @@ "node": ">=20.9.0" }, "optionalDependencies": { - "@next/swc-darwin-arm64": "16.0.4", - "@next/swc-darwin-x64": "16.0.4", - "@next/swc-linux-arm64-gnu": "16.0.4", - "@next/swc-linux-arm64-musl": "16.0.4", - "@next/swc-linux-x64-gnu": "16.0.4", - "@next/swc-linux-x64-musl": "16.0.4", - "@next/swc-win32-arm64-msvc": "16.0.4", - "@next/swc-win32-x64-msvc": "16.0.4", + "@next/swc-darwin-arm64": "16.0.10", + "@next/swc-darwin-x64": "16.0.10", + "@next/swc-linux-arm64-gnu": "16.0.10", + "@next/swc-linux-arm64-musl": "16.0.10", + "@next/swc-linux-x64-gnu": "16.0.10", + "@next/swc-linux-x64-musl": "16.0.10", + "@next/swc-win32-arm64-msvc": "16.0.10", + "@next/swc-win32-x64-msvc": "16.0.10", "sharp": "^0.34.4" }, "peerDependencies": { @@ -5723,6 +6603,31 @@ "node": ">=6" } }, + "node_modules/parse-entities": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/parse-entities/-/parse-entities-4.0.2.tgz", + "integrity": 
"sha512-GG2AQYWoLgL877gQIKeRPGO1xF9+eG1ujIb5soS5gPvLQ1y2o8FL90w2QWNdf9I361Mpp7726c+lj3U0qK1uGw==", + "license": "MIT", + "dependencies": { + "@types/unist": "^2.0.0", + "character-entities-legacy": "^3.0.0", + "character-reference-invalid": "^2.0.0", + "decode-named-character-reference": "^1.0.0", + "is-alphanumerical": "^2.0.0", + "is-decimal": "^2.0.0", + "is-hexadecimal": "^2.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/parse-entities/node_modules/@types/unist": { + "version": "2.0.11", + "resolved": "https://registry.npmjs.org/@types/unist/-/unist-2.0.11.tgz", + "integrity": "sha512-CmBKiL6NNo/OqgmMn95Fk9Whlp2mtvIv+KNpQKN2F4SjvrEesubTRWGYSg+BnWZOnlCaSTU1sMpsBOzgbYhnsA==", + "license": "MIT" + }, "node_modules/path-exists": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", @@ -5839,6 +6744,16 @@ "react-is": "^16.13.1" } }, + "node_modules/property-information": { + "version": "7.1.0", + "resolved": "https://registry.npmjs.org/property-information/-/property-information-7.1.0.tgz", + "integrity": "sha512-TwEZ+X+yCJmYfL7TPUOcvBZ4QfoT5YenQiJuX//0th53DE6w0xxLEtfK3iyryQFddXuvkIk51EEgrJQ0WJkOmQ==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/punycode": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", @@ -5931,6 +6846,33 @@ "react": ">=16.13.1" } }, + "node_modules/react-markdown": { + "version": "10.1.0", + "resolved": "https://registry.npmjs.org/react-markdown/-/react-markdown-10.1.0.tgz", + "integrity": "sha512-qKxVopLT/TyA6BX3Ue5NwabOsAzm0Q7kAPwq6L+wWDwisYs7R8vZ0nRXqq6rkueboxpkjvLGU9fWifiX/ZZFxQ==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "devlop": "^1.0.0", + "hast-util-to-jsx-runtime": "^2.0.0", + "html-url-attributes": "^3.0.0", + "mdast-util-to-hast": "^13.0.0", + 
"remark-parse": "^11.0.0", + "remark-rehype": "^11.0.0", + "unified": "^11.0.0", + "unist-util-visit": "^5.0.0", + "vfile": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + }, + "peerDependencies": { + "@types/react": ">=18", + "react": ">=18" + } + }, "node_modules/reflect.getprototypeof": { "version": "1.0.10", "resolved": "https://registry.npmjs.org/reflect.getprototypeof/-/reflect.getprototypeof-1.0.10.tgz", @@ -5975,6 +6917,39 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/remark-parse": { + "version": "11.0.0", + "resolved": "https://registry.npmjs.org/remark-parse/-/remark-parse-11.0.0.tgz", + "integrity": "sha512-FCxlKLNGknS5ba/1lmpYijMUzX2esxW5xQqjWxw2eHFfS2MSdaHVINFmhjo+qN1WhZhNimq0dZATN9pH0IDrpA==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "mdast-util-from-markdown": "^2.0.0", + "micromark-util-types": "^2.0.0", + "unified": "^11.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/remark-rehype": { + "version": "11.1.2", + "resolved": "https://registry.npmjs.org/remark-rehype/-/remark-rehype-11.1.2.tgz", + "integrity": "sha512-Dh7l57ianaEoIpzbp0PC9UKAdCSVklD8E5Rpw7ETfbTl3FqcOOgq5q2LVDhgGCkaBv7p24JXikPdvhhmHvKMsw==", + "license": "MIT", + "dependencies": { + "@types/hast": "^3.0.0", + "@types/mdast": "^4.0.0", + "mdast-util-to-hast": "^13.0.0", + "unified": "^11.0.0", + "vfile": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, "node_modules/resolve": { "version": "1.22.11", "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.11.tgz", @@ -6337,6 +7312,16 @@ "node": ">=0.10.0" } }, + "node_modules/space-separated-tokens": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/space-separated-tokens/-/space-separated-tokens-2.0.2.tgz", + "integrity": 
"sha512-PEGlAwrG8yXGXRjW32fGbg66JAlOAwbObuqVoJpv/mRgoWDQfgH1wDPvtzWyUSNAXBGSk8h755YDbbcEy3SH2Q==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/stable-hash": { "version": "0.0.5", "resolved": "https://registry.npmjs.org/stable-hash/-/stable-hash-0.0.5.tgz", @@ -6471,6 +7456,20 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/stringify-entities": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/stringify-entities/-/stringify-entities-4.0.4.tgz", + "integrity": "sha512-IwfBptatlO+QCJUo19AqvrPNqlVMpW9YEL2LIVY+Rpv2qsjCGxaDLNRgeGsQWJhfItebuJhsGSLjaBbNSQ+ieg==", + "license": "MIT", + "dependencies": { + "character-entities-html4": "^2.0.0", + "character-entities-legacy": "^3.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/strip-bom": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/strip-bom/-/strip-bom-3.0.0.tgz", @@ -6494,6 +7493,24 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/style-to-js": { + "version": "1.1.21", + "resolved": "https://registry.npmjs.org/style-to-js/-/style-to-js-1.1.21.tgz", + "integrity": "sha512-RjQetxJrrUJLQPHbLku6U/ocGtzyjbJMP9lCNK7Ag0CNh690nSH8woqWH9u16nMjYBAok+i7JO1NP2pOy8IsPQ==", + "license": "MIT", + "dependencies": { + "style-to-object": "1.0.14" + } + }, + "node_modules/style-to-object": { + "version": "1.0.14", + "resolved": "https://registry.npmjs.org/style-to-object/-/style-to-object-1.0.14.tgz", + "integrity": "sha512-LIN7rULI0jBscWQYaSswptyderlarFkjQ+t79nzty8tcIAceVomEVlLzH5VP4Cmsv6MtKhs7qaAiwlcp+Mgaxw==", + "license": "MIT", + "dependencies": { + "inline-style-parser": "0.2.7" + } + }, "node_modules/styled-jsx": { "version": "5.1.6", "resolved": "https://registry.npmjs.org/styled-jsx/-/styled-jsx-5.1.6.tgz", @@ -6645,6 +7662,26 @@ "node": ">=8.0" } }, + "node_modules/trim-lines": { + "version": "3.0.1", + 
"resolved": "https://registry.npmjs.org/trim-lines/-/trim-lines-3.0.1.tgz", + "integrity": "sha512-kRj8B+YHZCc9kQYdWfJB2/oUl9rA99qbowYYBtr4ui4mZyAQ2JpvVBd/6U2YloATfqBhBTSMhTpgBHtU0Mf3Rg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, + "node_modules/trough": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/trough/-/trough-2.2.0.tgz", + "integrity": "sha512-tmMpK00BjZiUyVyvrBK7knerNgmgvcV/KLVyuma/SC+TQN167GrMRciANTz09+k3zW8L8t60jWO1GpfkZdjTaw==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } + }, "node_modules/ts-api-utils": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-2.1.0.tgz", @@ -6846,6 +7883,93 @@ "dev": true, "license": "MIT" }, + "node_modules/unified": { + "version": "11.0.5", + "resolved": "https://registry.npmjs.org/unified/-/unified-11.0.5.tgz", + "integrity": "sha512-xKvGhPWw3k84Qjh8bI3ZeJjqnyadK+GEFtazSfZv/rKeTkTjOJho6mFqh2SM96iIcZokxiOpg78GazTSg8+KHA==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0", + "bail": "^2.0.0", + "devlop": "^1.0.0", + "extend": "^3.0.0", + "is-plain-obj": "^4.0.0", + "trough": "^2.0.0", + "vfile": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/unist-util-is": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/unist-util-is/-/unist-util-is-6.0.1.tgz", + "integrity": "sha512-LsiILbtBETkDz8I9p1dQ0uyRUWuaQzd/cuEeS1hoRSyW5E5XGmTzlwY1OrNzzakGowI9Dr/I8HVaw4hTtnxy8g==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/unist-util-position": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/unist-util-position/-/unist-util-position-5.0.0.tgz", + "integrity": 
"sha512-fucsC7HjXvkB5R3kTCO7kUjRdrS0BJt3M/FPxmHMBOm8JQi2BsHAHFsy27E0EolP8rp0NzXsJ+jNPyDWvOJZPA==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/unist-util-stringify-position": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/unist-util-stringify-position/-/unist-util-stringify-position-4.0.0.tgz", + "integrity": "sha512-0ASV06AAoKCDkS2+xw5RXJywruurpbC4JZSm7nr7MOt1ojAzvyyaO+UxZf18j8FCF6kmzCZKcAgN/yu2gm2XgQ==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/unist-util-visit": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/unist-util-visit/-/unist-util-visit-5.0.0.tgz", + "integrity": "sha512-MR04uvD+07cwl/yhVuVWAtw+3GOR/knlL55Nd/wAdblk27GCVt3lqpTivy/tkJcZoNPzTwS1Y+KMojlLDhoTzg==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0", + "unist-util-is": "^6.0.0", + "unist-util-visit-parents": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/unist-util-visit-parents": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/unist-util-visit-parents/-/unist-util-visit-parents-6.0.2.tgz", + "integrity": "sha512-goh1s1TBrqSqukSc8wrjwWhL0hiJxgA8m4kFxGlQ+8FYQ3C/m11FcTs4YYem7V664AhHVvgoQLk890Ssdsr2IQ==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0", + "unist-util-is": "^6.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, "node_modules/unrs-resolver": { "version": "1.11.1", "resolved": "https://registry.npmjs.org/unrs-resolver/-/unrs-resolver-1.11.1.tgz", @@ -6944,6 +8068,34 @@ "uuid": "dist/bin/uuid" } }, + "node_modules/vfile": { + "version": "6.0.3", + "resolved": 
"https://registry.npmjs.org/vfile/-/vfile-6.0.3.tgz", + "integrity": "sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0", + "vfile-message": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, + "node_modules/vfile-message": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/vfile-message/-/vfile-message-4.0.3.tgz", + "integrity": "sha512-QTHzsGd1EhbZs4AsQ20JX1rC3cOlt/IWJruk893DfLRr57lcnOeMaWG4K0JrRta4mIJZKth2Au3mM3u03/JWKw==", + "license": "MIT", + "dependencies": { + "@types/unist": "^3.0.0", + "unist-util-stringify-position": "^4.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, "node_modules/which": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", @@ -7102,6 +8254,16 @@ "peerDependencies": { "zod": "^3.25.0 || ^4.0.0" } + }, + "node_modules/zwitch": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/zwitch/-/zwitch-2.0.4.tgz", + "integrity": "sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/wooorm" + } } } } diff --git a/cognee-frontend/package.json b/cognee-frontend/package.json index 4195945fd..5ac6d6787 100644 --- a/cognee-frontend/package.json +++ b/cognee-frontend/package.json @@ -13,10 +13,11 @@ "classnames": "^2.5.1", "culori": "^4.0.1", "d3-force-3d": "^3.0.6", - "next": "16.0.4", + "next": "^16.0.10", "react": "^19.2.0", "react-dom": "^19.2.0", "react-force-graph-2d": "^1.27.1", + "react-markdown": "^10.1.0", "uuid": "^9.0.1" }, "devDependencies": { diff --git a/cognee-frontend/src/app/dashboard/Dashboard.tsx b/cognee-frontend/src/app/dashboard/Dashboard.tsx index 75e3d7518..d333c0bf1 100644 --- 
a/cognee-frontend/src/app/dashboard/Dashboard.tsx +++ b/cognee-frontend/src/app/dashboard/Dashboard.tsx @@ -15,6 +15,8 @@ import AddDataToCognee from "./AddDataToCognee"; import NotebooksAccordion from "./NotebooksAccordion"; import CogneeInstancesAccordion from "./CogneeInstancesAccordion"; import InstanceDatasetsAccordion from "./InstanceDatasetsAccordion"; +import cloudFetch from "@/modules/instances/cloudFetch"; +import localFetch from "@/modules/instances/localFetch"; interface DashboardProps { user?: { @@ -26,6 +28,17 @@ interface DashboardProps { accessToken: string; } +const cogneeInstances = { + cloudCognee: { + name: "CloudCognee", + fetch: cloudFetch, + }, + localCognee: { + name: "LocalCognee", + fetch: localFetch, + } +}; + export default function Dashboard({ accessToken }: DashboardProps) { fetch.setAccessToken(accessToken); const { user } = useAuthenticatedUser(); @@ -38,7 +51,7 @@ export default function Dashboard({ accessToken }: DashboardProps) { updateNotebook, saveNotebook, removeNotebook, - } = useNotebooks(); + } = useNotebooks(cogneeInstances.localCognee); useEffect(() => { if (!notebooks.length) { diff --git a/cognee-frontend/src/app/dashboard/InstanceDatasetsAccordion.tsx b/cognee-frontend/src/app/dashboard/InstanceDatasetsAccordion.tsx index f094f7caf..3bfa3c057 100644 --- a/cognee-frontend/src/app/dashboard/InstanceDatasetsAccordion.tsx +++ b/cognee-frontend/src/app/dashboard/InstanceDatasetsAccordion.tsx @@ -3,6 +3,7 @@ import { useCallback, useEffect } from "react"; import { fetch, isCloudEnvironment, useBoolean } from "@/utils"; import { checkCloudConnection } from "@/modules/cloud"; +import { setApiKey } from "@/modules/instances/cloudFetch"; import { CaretIcon, CloseIcon, CloudIcon, LocalCogneeIcon } from "@/ui/Icons"; import { CTAButton, GhostButton, IconButton, Input, Modal } from "@/ui/elements"; @@ -24,6 +25,7 @@ export default function InstanceDatasetsAccordion({ onDatasetsChange }: Instance const checkConnectionToCloudCognee = 
useCallback((apiKey?: string) => { if (apiKey) { fetch.setApiKey(apiKey); + setApiKey(apiKey); } return checkCloudConnection() .then(setCloudCogneeConnected) diff --git a/cognee-frontend/src/modules/ingestion/useDatasets.ts b/cognee-frontend/src/modules/ingestion/useDatasets.ts index e25d9f932..ab8006a9f 100644 --- a/cognee-frontend/src/modules/ingestion/useDatasets.ts +++ b/cognee-frontend/src/modules/ingestion/useDatasets.ts @@ -95,6 +95,7 @@ function useDatasets(useCloud = false) { }) .catch((error) => { console.error('Error fetching datasets:', error); + throw error; }); }, [useCloud]); diff --git a/cognee-frontend/src/modules/instances/cloudFetch.ts b/cognee-frontend/src/modules/instances/cloudFetch.ts new file mode 100644 index 000000000..6806feac8 --- /dev/null +++ b/cognee-frontend/src/modules/instances/cloudFetch.ts @@ -0,0 +1,59 @@ +import handleServerErrors from "@/utils/handleServerErrors"; + +// let numberOfRetries = 0; + +const cloudApiUrl = process.env.NEXT_PUBLIC_CLOUD_API_URL || "http://localhost:8001"; + +let apiKey: string | null = process.env.NEXT_PUBLIC_COGWIT_API_KEY || null; + +export function setApiKey(newApiKey: string) { + apiKey = newApiKey; +}; + +export default async function cloudFetch(url: URL | RequestInfo, options: RequestInit = {}): Promise { + // function retry(lastError: Response) { + // if (numberOfRetries >= 1) { + // return Promise.reject(lastError); + // } + + // numberOfRetries += 1; + + // return global.fetch("/auth/token") + // .then(() => { + // return fetch(url, options); + // }); + // } + + const authHeaders = { + "Authorization": `X-Api-Key ${apiKey}`, + }; + + return global.fetch( + cloudApiUrl + "/api" + (typeof url === "string" ? 
url : url.toString()).replace("/v1", ""), + { + ...options, + headers: { + ...options.headers, + ...authHeaders, + } as HeadersInit, + credentials: "include", + }, + ) + .then((response) => handleServerErrors(response, null, true)) + .catch((error) => { + if (error.message === "NEXT_REDIRECT") { + throw error; + } + + if (error.detail === undefined) { + return Promise.reject( + new Error("No connection to the server.") + ); + } + + return Promise.reject(error); + }); + // .finally(() => { + // numberOfRetries = 0; + // }); +} diff --git a/cognee-frontend/src/modules/instances/localFetch.ts b/cognee-frontend/src/modules/instances/localFetch.ts new file mode 100644 index 000000000..3cc16fda3 --- /dev/null +++ b/cognee-frontend/src/modules/instances/localFetch.ts @@ -0,0 +1,27 @@ +import handleServerErrors from "@/utils/handleServerErrors"; + +const localApiUrl = process.env.NEXT_PUBLIC_LOCAL_API_URL || "http://localhost:8000"; + +export default async function localFetch(url: URL | RequestInfo, options: RequestInit = {}): Promise { + return global.fetch( + localApiUrl + "/api" + url, + { + ...options, + credentials: "include", + }, + ) + .then((response) => handleServerErrors(response, null, false)) + .catch((error) => { + if (error.message === "NEXT_REDIRECT") { + throw error; + } + + if (error.detail === undefined) { + return Promise.reject( + new Error("No connection to the server.") + ); + } + + return Promise.reject(error); + }); +} diff --git a/cognee-frontend/src/modules/instances/types.ts b/cognee-frontend/src/modules/instances/types.ts new file mode 100644 index 000000000..5becdf53f --- /dev/null +++ b/cognee-frontend/src/modules/instances/types.ts @@ -0,0 +1,4 @@ +export interface CogneeInstance { + name: string; + fetch: typeof global.fetch; +} diff --git a/cognee-frontend/src/modules/notebooks/createNotebook.ts b/cognee-frontend/src/modules/notebooks/createNotebook.ts new file mode 100644 index 000000000..f45a57b5e --- /dev/null +++ 
b/cognee-frontend/src/modules/notebooks/createNotebook.ts @@ -0,0 +1,11 @@ +import { CogneeInstance } from "@/modules/instances/types"; + +export default function createNotebook(notebookName: string, instance: CogneeInstance) { + return instance.fetch("/v1/notebooks/", { + body: JSON.stringify({ name: notebookName }), + method: "POST", + headers: { + "Content-Type": "application/json", + }, + }).then((response: Response) => response.json()); +} diff --git a/cognee-frontend/src/modules/notebooks/deleteNotebook.ts b/cognee-frontend/src/modules/notebooks/deleteNotebook.ts new file mode 100644 index 000000000..2718526b5 --- /dev/null +++ b/cognee-frontend/src/modules/notebooks/deleteNotebook.ts @@ -0,0 +1,7 @@ +import { CogneeInstance } from "@/modules/instances/types"; + +export default function deleteNotebook(notebookId: string, instance: CogneeInstance) { + return instance.fetch(`/v1/notebooks/${notebookId}`, { + method: "DELETE", + }); +} diff --git a/cognee-frontend/src/modules/notebooks/getNotebooks.ts b/cognee-frontend/src/modules/notebooks/getNotebooks.ts new file mode 100644 index 000000000..b4f329ede --- /dev/null +++ b/cognee-frontend/src/modules/notebooks/getNotebooks.ts @@ -0,0 +1,10 @@ +import { CogneeInstance } from "@/modules/instances/types"; + +export default function getNotebooks(instance: CogneeInstance) { + return instance.fetch("/v1/notebooks/", { + method: "GET", + headers: { + "Content-Type": "application/json", + }, + }).then((response) => response.json()); +} diff --git a/cognee-frontend/src/modules/notebooks/runNotebookCell.ts b/cognee-frontend/src/modules/notebooks/runNotebookCell.ts new file mode 100644 index 000000000..26df16d46 --- /dev/null +++ b/cognee-frontend/src/modules/notebooks/runNotebookCell.ts @@ -0,0 +1,14 @@ +import { Cell } from "@/ui/elements/Notebook/types"; +import { CogneeInstance } from "@/modules/instances/types"; + +export default function runNotebookCell(notebookId: string, cell: Cell, instance: CogneeInstance) { + 
return instance.fetch(`/v1/notebooks/${notebookId}/${cell.id}/run`, { + body: JSON.stringify({ + content: cell.content, + }), + method: "POST", + headers: { + "Content-Type": "application/json", + }, + }).then((response: Response) => response.json()); +} diff --git a/cognee-frontend/src/modules/notebooks/saveNotebook.ts b/cognee-frontend/src/modules/notebooks/saveNotebook.ts new file mode 100644 index 000000000..8ad9188ba --- /dev/null +++ b/cognee-frontend/src/modules/notebooks/saveNotebook.ts @@ -0,0 +1,11 @@ +import { CogneeInstance } from "@/modules/instances/types"; + +export default function saveNotebook(notebookId: string, notebookData: object, instance: CogneeInstance) { + return instance.fetch(`/v1/notebooks/${notebookId}`, { + body: JSON.stringify(notebookData), + method: "PUT", + headers: { + "Content-Type": "application/json", + }, + }).then((response) => response.json()); +} diff --git a/cognee-frontend/src/modules/notebooks/useNotebooks.ts b/cognee-frontend/src/modules/notebooks/useNotebooks.ts index e427f85ee..57da95b81 100644 --- a/cognee-frontend/src/modules/notebooks/useNotebooks.ts +++ b/cognee-frontend/src/modules/notebooks/useNotebooks.ts @@ -1,20 +1,18 @@ import { useCallback, useState } from "react"; -import { fetch, isCloudEnvironment } from "@/utils"; import { Cell, Notebook } from "@/ui/elements/Notebook/types"; +import { CogneeInstance } from "@/modules/instances/types"; +import createNotebook from "./createNotebook"; +import deleteNotebook from "./deleteNotebook"; +import getNotebooks from "./getNotebooks"; +import runNotebookCell from "./runNotebookCell"; +import { default as persistNotebook } from "./saveNotebook"; -function useNotebooks() { +function useNotebooks(instance: CogneeInstance) { const [notebooks, setNotebooks] = useState([]); const addNotebook = useCallback((notebookName: string) => { - return fetch("/v1/notebooks", { - body: JSON.stringify({ name: notebookName }), - method: "POST", - headers: { - "Content-Type": 
"application/json", - }, - }, isCloudEnvironment()) - .then((response) => response.json()) - .then((notebook) => { + return createNotebook(notebookName, instance) + .then((notebook: Notebook) => { setNotebooks((notebooks) => [ ...notebooks, notebook, @@ -22,36 +20,29 @@ function useNotebooks() { return notebook; }); - }, []); + }, [instance]); const removeNotebook = useCallback((notebookId: string) => { - return fetch(`/v1/notebooks/${notebookId}`, { - method: "DELETE", - }, isCloudEnvironment()) + return deleteNotebook(notebookId, instance) .then(() => { setNotebooks((notebooks) => notebooks.filter((notebook) => notebook.id !== notebookId) ); }); - }, []); + }, [instance]); const fetchNotebooks = useCallback(() => { - return fetch("/v1/notebooks", { - headers: { - "Content-Type": "application/json", - }, - }, isCloudEnvironment()) - .then((response) => response.json()) + return getNotebooks(instance) .then((notebooks) => { setNotebooks(notebooks); return notebooks; }) .catch((error) => { - console.error("Error fetching notebooks:", error); + console.error("Error fetching notebooks:", error.detail); throw error }); - }, []); + }, [instance]); const updateNotebook = useCallback((updatedNotebook: Notebook) => { setNotebooks((existingNotebooks) => @@ -64,20 +55,13 @@ function useNotebooks() { }, []); const saveNotebook = useCallback((notebook: Notebook) => { - return fetch(`/v1/notebooks/${notebook.id}`, { - body: JSON.stringify({ - name: notebook.name, - cells: notebook.cells, - }), - method: "PUT", - headers: { - "Content-Type": "application/json", - }, - }, isCloudEnvironment()) - .then((response) => response.json()) - }, []); + return persistNotebook(notebook.id, { + name: notebook.name, + cells: notebook.cells, + }, instance); + }, [instance]); - const runCell = useCallback((notebook: Notebook, cell: Cell, cogneeInstance: string) => { + const runCell = useCallback((notebook: Notebook, cell: Cell) => { setNotebooks((existingNotebooks) => 
existingNotebooks.map((existingNotebook) => existingNotebook.id === notebook.id ? { @@ -89,20 +73,11 @@ function useNotebooks() { error: undefined, } : existingCell ), - } : notebook + } : existingNotebook ) ); - return fetch(`/v1/notebooks/${notebook.id}/${cell.id}/run`, { - body: JSON.stringify({ - content: cell.content, - }), - method: "POST", - headers: { - "Content-Type": "application/json", - }, - }, cogneeInstance === "cloud") - .then((response) => response.json()) + return runNotebookCell(notebook.id, cell, instance) .then((response) => { setNotebooks((existingNotebooks) => existingNotebooks.map((existingNotebook) => @@ -115,11 +90,11 @@ function useNotebooks() { error: response.error, } : existingCell ), - } : notebook + } : existingNotebook ) ); }); - }, []); + }, [instance]); return { notebooks, diff --git a/cognee-frontend/src/ui/elements/Notebook/MarkdownPreview.tsx b/cognee-frontend/src/ui/elements/Notebook/MarkdownPreview.tsx new file mode 100644 index 000000000..6ea69bfc6 --- /dev/null +++ b/cognee-frontend/src/ui/elements/Notebook/MarkdownPreview.tsx @@ -0,0 +1,77 @@ +import { memo } from "react"; +import ReactMarkdown from "react-markdown"; + +interface MarkdownPreviewProps { + content: string; + className?: string; +} + +function MarkdownPreview({ content, className = "" }: MarkdownPreviewProps) { + return ( +
+

{children}

, + h2: ({ children }) =>

{children}

, + h3: ({ children }) =>

{children}

, + h4: ({ children }) =>

{children}

, + h5: ({ children }) =>
{children}
, + h6: ({ children }) =>
{children}
, + p: ({ children }) =>

{children}

, + ul: ({ children }) =>
    {children}
, + ol: ({ children }) =>
    {children}
, + li: ({ children }) =>
  • {children}
  • , + blockquote: ({ children }) => ( +
    {children}
    + ), + code: ({ className, children, ...props }) => { + const isInline = !className; + return isInline ? ( + + {children} + + ) : ( + + {children} + + ); + }, + pre: ({ children }) => ( +
    +              {children}
    +            
    + ), + a: ({ href, children }) => ( + + {children} + + ), + strong: ({ children }) => {children}, + em: ({ children }) => {children}, + hr: () =>
    , + table: ({ children }) => ( +
    + {children}
    +
    + ), + thead: ({ children }) => {children}, + tbody: ({ children }) => {children}, + tr: ({ children }) => {children}, + th: ({ children }) => ( + + {children} + + ), + td: ({ children }) => ( + {children} + ), + }} + > + {content} +
    +
    + ); +} + +export default memo(MarkdownPreview); + diff --git a/cognee-frontend/src/ui/elements/Notebook/Notebook.tsx b/cognee-frontend/src/ui/elements/Notebook/Notebook.tsx index 69556552b..b6b935229 100644 --- a/cognee-frontend/src/ui/elements/Notebook/Notebook.tsx +++ b/cognee-frontend/src/ui/elements/Notebook/Notebook.tsx @@ -2,15 +2,17 @@ import { v4 as uuid4 } from "uuid"; import classNames from "classnames"; -import { Fragment, MouseEvent, RefObject, useCallback, useEffect, useRef, useState } from "react"; +import { Fragment, MouseEvent, MutableRefObject, useCallback, useEffect, useRef, useState, memo } from "react"; import { useModal } from "@/ui/elements/Modal"; import { CaretIcon, CloseIcon, PlusIcon } from "@/ui/Icons"; -import { IconButton, PopupMenu, TextArea, Modal, GhostButton, CTAButton } from "@/ui/elements"; +import PopupMenu from "@/ui/elements/PopupMenu"; +import { IconButton, TextArea, Modal, GhostButton, CTAButton } from "@/ui/elements"; import { GraphControlsAPI } from "@/app/(graph)/GraphControls"; import GraphVisualization, { GraphVisualizationAPI } from "@/app/(graph)/GraphVisualization"; import NotebookCellHeader from "./NotebookCellHeader"; +import MarkdownPreview from "./MarkdownPreview"; import { Cell, Notebook as NotebookType } from "./types"; interface NotebookProps { @@ -19,7 +21,186 @@ interface NotebookProps { updateNotebook: (updatedNotebook: NotebookType) => void; } +interface NotebookCellProps { + cell: Cell; + index: number; + isOpen: boolean; + isMarkdownEditMode: boolean; + onToggleOpen: () => void; + onToggleMarkdownEdit: () => void; + onContentChange: (value: string) => void; + onCellRun: (cell: Cell, cogneeInstance: string) => Promise; + onCellRename: (cell: Cell) => void; + onCellRemove: (cell: Cell) => void; + onCellUp: (cell: Cell) => void; + onCellDown: (cell: Cell) => void; + onCellAdd: (afterCellIndex: number, cellType: "markdown" | "code") => void; +} + +const NotebookCell = memo(function NotebookCell({ + cell, 
+ index, + isOpen, + isMarkdownEditMode, + onToggleOpen, + onToggleMarkdownEdit, + onContentChange, + onCellRun, + onCellRename, + onCellRemove, + onCellUp, + onCellDown, + onCellAdd, +}: NotebookCellProps) { + return ( + +
    +
    + {cell.type === "code" ? ( + <> +
    + + + +
    + + + + {isOpen && ( + <> +