chore: Cog 2354 add logging (#1115)

<!-- .github/pull_request_template.md -->

## Description
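<!-- Provide a clear description of the changes in this PR -->
Adds logging across the graph adapters, graph projection, and retrievers:

- Neo4jAdapter: the logger is no longer pinned to the `ERROR` level; graph data retrieval and `get_nodeset_subgraph` now log node/edge counts with elapsed time and log errors before re-raising.
- CogneeGraph: uses a named logger, replaces `print` calls with `logger.error`, logs projection statistics and timing, and catches `Exception` instead of only `ValueError`/`TypeError`.
- Retrievers (Chunks, Code, Completion, Insights, Summaries, GraphCompletion, NaturalLanguage): use named loggers from `cognee.shared.logging_utils.get_logger` and log query start, result counts, and failures.

Behavior changes beyond logging:

- Filtered graph edges are now built from `properties["source_node_id"]` / `properties["target_node_id"]` instead of `record["source"]` / `record["target"]`.
- `CompletionRetriever.get_completion` returns the completion directly instead of a one-element list.
- `CodeRetriever.get_context` returns an empty list early when the vector search yields no results.
- Edges with an unknown or missing relationship type are now skipped without printing a message.
- `NaturalLanguageRetriever` no longer wraps context retrieval in a log-and-re-raise `try`/`except`.
- Edge labels in the formatted graph data are cast to `str`.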

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
This commit is contained in:
Vasilije 2025-07-24 13:27:27 +02:00 committed by GitHub
parent d6727a1b4a
commit 1885ab9e88
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 271 additions and 108 deletions

View file

@ -104,6 +104,7 @@ class KuzuAdapter(GraphDBInterface):
max_db_size=4096 * 1024 * 1024, max_db_size=4096 * 1024 * 1024,
) )
self.db.init_database() self.db.init_database()
self.connection = Connection(self.db) self.connection = Connection(self.db)
# Create node table with essential fields and timestamp # Create node table with essential fields and timestamp

View file

@ -33,7 +33,7 @@ from .neo4j_metrics_utils import (
from .deadlock_retry import deadlock_retry from .deadlock_retry import deadlock_retry
logger = get_logger("Neo4jAdapter", level=ERROR) logger = get_logger("Neo4jAdapter")
BASE_LABEL = "__Node__" BASE_LABEL = "__Node__"
@ -870,34 +870,52 @@ class Neo4jAdapter(GraphDBInterface):
A tuple containing two lists: nodes and edges with their properties. A tuple containing two lists: nodes and edges with their properties.
""" """
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties" import time
result = await self.query(query) start_time = time.time()
nodes = [ try:
( # Retrieve nodes
record["properties"]["id"], query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
record["properties"], result = await self.query(query)
nodes = []
for record in result:
nodes.append(
(
record["properties"]["id"],
record["properties"],
)
)
# Retrieve edges
query = """
MATCH (n)-[r]->(m)
RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
"""
result = await self.query(query)
edges = []
for record in result:
edges.append(
(
record["properties"]["source_node_id"],
record["properties"]["target_node_id"],
record["type"],
record["properties"],
)
)
retrieval_time = time.time() - start_time
logger.info(
f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
) )
for record in result
]
query = """ return (nodes, edges)
MATCH (n)-[r]->(m)
RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
"""
result = await self.query(query)
edges = [
(
record["properties"]["source_node_id"],
record["properties"]["target_node_id"],
record["type"],
record["properties"],
)
for record in result
]
return (nodes, edges) except Exception as e:
logger.error(f"Error during graph data retrieval: {str(e)}")
raise
async def get_nodeset_subgraph( async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str] self, node_type: Type[Any], node_name: List[str]
@ -918,50 +936,71 @@ class Neo4jAdapter(GraphDBInterface):
- Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]}: A tuple - Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]}: A tuple
containing nodes and edges in the requested subgraph. containing nodes and edges in the requested subgraph.
""" """
label = node_type.__name__ import time
query = f""" start_time = time.time()
UNWIND $names AS wantedName
MATCH (n:`{label}`)
WHERE n.name = wantedName
WITH collect(DISTINCT n) AS primary
UNWIND primary AS p
OPTIONAL MATCH (p)--(nbr)
WITH primary, collect(DISTINCT nbr) AS nbrs
WITH primary + nbrs AS nodelist
UNWIND nodelist AS node
WITH collect(DISTINCT node) AS nodes
MATCH (a)-[r]-(b)
WHERE a IN nodes AND b IN nodes
WITH nodes, collect(DISTINCT r) AS rels
RETURN
[n IN nodes |
{{ id: n.id,
properties: properties(n) }}] AS rawNodes,
[r IN rels |
{{ type: type(r),
properties: properties(r) }}] AS rawRels
"""
result = await self.query(query, {"names": node_name}) try:
if not result: label = node_type.__name__
return [], []
raw_nodes = result[0]["rawNodes"] query = f"""
raw_rels = result[0]["rawRels"] UNWIND $names AS wantedName
MATCH (n:`{label}`)
WHERE n.name = wantedName
WITH collect(DISTINCT n) AS primary
UNWIND primary AS p
OPTIONAL MATCH (p)--(nbr)
WITH primary, collect(DISTINCT nbr) AS nbrs
WITH primary + nbrs AS nodelist
UNWIND nodelist AS node
WITH collect(DISTINCT node) AS nodes
MATCH (a)-[r]-(b)
WHERE a IN nodes AND b IN nodes
WITH nodes, collect(DISTINCT r) AS rels
RETURN
[n IN nodes |
{{ id: n.id,
properties: properties(n) }}] AS rawNodes,
[r IN rels |
{{ type: type(r),
properties: properties(r) }}] AS rawRels
"""
nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes] result = await self.query(query, {"names": node_name})
edges = [
( if not result:
r["properties"]["source_node_id"], return [], []
r["properties"]["target_node_id"],
r["type"], raw_nodes = result[0]["rawNodes"]
r["properties"], raw_rels = result[0]["rawRels"]
# Process nodes
nodes = []
for n in raw_nodes:
nodes.append((n["properties"]["id"], n["properties"]))
# Process edges
edges = []
for r in raw_rels:
edges.append(
(
r["properties"]["source_node_id"],
r["properties"]["target_node_id"],
r["type"],
r["properties"],
)
)
retrieval_time = time.time() - start_time
logger.info(
f"Retrieved {len(nodes)} nodes and {len(edges)} edges for {node_type.__name__} in {retrieval_time:.2f} seconds"
) )
for r in raw_rels
]
return nodes, edges return nodes, edges
except Exception as e:
logger.error(f"Error during nodeset subgraph retrieval: {str(e)}")
raise
async def get_filtered_graph_data(self, attribute_filters): async def get_filtered_graph_data(self, attribute_filters):
""" """
@ -1011,8 +1050,8 @@ class Neo4jAdapter(GraphDBInterface):
edges = [ edges = [
( (
record["source"], record["properties"]["source_node_id"],
record["target"], record["properties"]["target_node_id"],
record["type"], record["type"],
record["properties"], record["properties"],
) )

View file

@ -8,7 +8,7 @@ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
import heapq import heapq
logger = get_logger() logger = get_logger("CogneeGraph")
class CogneeGraph(CogneeAbstractGraph): class CogneeGraph(CogneeAbstractGraph):
@ -66,7 +66,13 @@ class CogneeGraph(CogneeAbstractGraph):
) -> None: ) -> None:
if node_dimension < 1 or edge_dimension < 1: if node_dimension < 1 or edge_dimension < 1:
raise InvalidValueError(message="Dimensions must be positive integers") raise InvalidValueError(message="Dimensions must be positive integers")
try: try:
import time
start_time = time.time()
# Determine projection strategy
if node_type is not None and node_name is not None: if node_type is not None and node_name is not None:
nodes_data, edges_data = await adapter.get_nodeset_subgraph( nodes_data, edges_data = await adapter.get_nodeset_subgraph(
node_type=node_type, node_name=node_name node_type=node_type, node_name=node_name
@ -83,16 +89,17 @@ class CogneeGraph(CogneeAbstractGraph):
nodes_data, edges_data = await adapter.get_filtered_graph_data( nodes_data, edges_data = await adapter.get_filtered_graph_data(
attribute_filters=memory_fragment_filter attribute_filters=memory_fragment_filter
) )
if not nodes_data or not edges_data: if not nodes_data or not edges_data:
raise EntityNotFoundError( raise EntityNotFoundError(
message="Empty filtered graph projected from the database." message="Empty filtered graph projected from the database."
) )
# Process nodes
for node_id, properties in nodes_data: for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project} node_attributes = {key: properties.get(key) for key in node_properties_to_project}
self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension)) self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
# Process edges
for source_id, target_id, relationship_type, properties in edges_data: for source_id, target_id, relationship_type, properties in edges_data:
source_node = self.get_node(str(source_id)) source_node = self.get_node(str(source_id))
target_node = self.get_node(str(target_id)) target_node = self.get_node(str(target_id))
@ -113,17 +120,23 @@ class CogneeGraph(CogneeAbstractGraph):
source_node.add_skeleton_edge(edge) source_node.add_skeleton_edge(edge)
target_node.add_skeleton_edge(edge) target_node.add_skeleton_edge(edge)
else: else:
raise EntityNotFoundError( raise EntityNotFoundError(
message=f"Edge references nonexistent nodes: {source_id} -> {target_id}" message=f"Edge references nonexistent nodes: {source_id} -> {target_id}"
) )
except (ValueError, TypeError) as e: # Final statistics
print(f"Error projecting graph: {e}") projection_time = time.time() - start_time
raise e logger.info(
f"Graph projection completed: {len(self.nodes)} nodes, {len(self.edges)} edges in {projection_time:.2f}s"
)
except Exception as e:
logger.error(f"Error during graph projection: {str(e)}")
raise
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None: async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
mapped_nodes = 0
for category, scored_results in node_distances.items(): for category, scored_results in node_distances.items():
for scored_result in scored_results: for scored_result in scored_results:
node_id = str(scored_result.id) node_id = str(scored_result.id)
@ -131,6 +144,7 @@ class CogneeGraph(CogneeAbstractGraph):
node = self.get_node(node_id) node = self.get_node(node_id)
if node: if node:
node.add_attribute("vector_distance", score) node.add_attribute("vector_distance", score)
mapped_nodes += 1
async def map_vector_distances_to_graph_edges( async def map_vector_distances_to_graph_edges(
self, vector_engine, query_vector, edge_distances self, vector_engine, query_vector, edge_distances
@ -150,18 +164,16 @@ class CogneeGraph(CogneeAbstractGraph):
for edge in self.edges: for edge in self.edges:
relationship_type = edge.attributes.get("relationship_type") relationship_type = edge.attributes.get("relationship_type")
if not relationship_type or relationship_type not in embedding_map: if relationship_type and relationship_type in embedding_map:
print(f"Edge {edge} has an unknown or missing relationship type.") edge.attributes["vector_distance"] = embedding_map[relationship_type]
continue
edge.attributes["vector_distance"] = embedding_map[relationship_type]
except Exception as ex: except Exception as ex:
print(f"Error mapping vector distances to edges: {ex}") logger.error(f"Error mapping vector distances to edges: {str(ex)}")
raise ex raise ex
async def calculate_top_triplet_importances(self, k: int) -> List: async def calculate_top_triplet_importances(self, k: int) -> List:
min_heap = [] min_heap = []
for i, edge in enumerate(self.edges): for i, edge in enumerate(self.edges):
source_node = self.get_node(edge.node1.id) source_node = self.get_node(edge.node1.id)
target_node = self.get_node(edge.node2.id) target_node = self.get_node(edge.node2.id)

View file

@ -33,7 +33,7 @@ async def get_formatted_graph_data(dataset_id: UUID, user_id: UUID):
lambda edge: { lambda edge: {
"source": str(edge[0]), "source": str(edge[0]),
"target": str(edge[1]), "target": str(edge[1]),
"label": edge[2], "label": str(edge[2]),
}, },
edges, edges,
) )

View file

@ -1,10 +1,13 @@
from typing import Any, Optional from typing import Any, Optional
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
logger = get_logger("ChunksRetriever")
class ChunksRetriever(BaseRetriever): class ChunksRetriever(BaseRetriever):
""" """
@ -41,14 +44,22 @@ class ChunksRetriever(BaseRetriever):
- Any: A list of document chunk payloads retrieved from the search. - Any: A list of document chunk payloads retrieved from the search.
""" """
logger.info(
f"Starting chunk retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
try: try:
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k) found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
logger.info(f"Found {len(found_chunks)} chunks from vector search")
except CollectionNotFoundError as error: except CollectionNotFoundError as error:
logger.error("DocumentChunk_text collection not found in vector database")
raise NoDataError("No data found in the system, please add data first.") from error raise NoDataError("No data found in the system, please add data first.") from error
return [result.payload for result in found_chunks] chunk_payloads = [result.payload for result in found_chunks]
logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
return chunk_payloads
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
""" """
@ -70,6 +81,17 @@ class ChunksRetriever(BaseRetriever):
- Any: The context used for the completion or the retrieved context if none was - Any: The context used for the completion or the retrieved context if none was
provided. provided.
""" """
logger.info(
f"Starting completion generation for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
if context is None: if context is None:
logger.debug("No context provided, retrieving context from vector database")
context = await self.get_context(query) context = await self.get_context(query)
else:
logger.debug("Using provided context")
logger.info(
f"Returning context with {len(context) if isinstance(context, list) else 1} item(s)"
)
return context return context

View file

@ -3,12 +3,15 @@ import asyncio
import aiofiles import aiofiles
from pydantic import BaseModel from pydantic import BaseModel
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
logger = get_logger("CodeRetriever")
class CodeRetriever(BaseRetriever): class CodeRetriever(BaseRetriever):
"""Retriever for handling code-based searches.""" """Retriever for handling code-based searches."""
@ -35,26 +38,43 @@ class CodeRetriever(BaseRetriever):
async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo": async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
"""Process the query using LLM to extract file names and source code parts.""" """Process the query using LLM to extract file names and source code parts."""
logger.debug(
f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
system_prompt = read_query_prompt("codegraph_retriever_system.txt") system_prompt = read_query_prompt("codegraph_retriever_system.txt")
llm_client = get_llm_client() llm_client = get_llm_client()
try: try:
return await llm_client.acreate_structured_output( result = await llm_client.acreate_structured_output(
text_input=query, text_input=query,
system_prompt=system_prompt, system_prompt=system_prompt,
response_model=self.CodeQueryInfo, response_model=self.CodeQueryInfo,
) )
logger.info(
f"LLM extracted {len(result.filenames)} filenames and {len(result.sourcecode)} chars of source code"
)
return result
except Exception as e: except Exception as e:
logger.error(f"Failed to retrieve structured output from LLM: {str(e)}")
raise RuntimeError("Failed to retrieve structured output from LLM") from e raise RuntimeError("Failed to retrieve structured output from LLM") from e
async def get_context(self, query: str) -> Any: async def get_context(self, query: str) -> Any:
"""Find relevant code files based on the query.""" """Find relevant code files based on the query."""
logger.info(
f"Starting code retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
if not query or not isinstance(query, str): if not query or not isinstance(query, str):
logger.error("Invalid query: must be a non-empty string")
raise ValueError("The query must be a non-empty string.") raise ValueError("The query must be a non-empty string.")
try: try:
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
logger.debug("Successfully initialized vector and graph engines")
except Exception as e: except Exception as e:
logger.error(f"Database initialization error: {str(e)}")
raise RuntimeError("Database initialization error in code_graph_retriever, ") from e raise RuntimeError("Database initialization error in code_graph_retriever, ") from e
files_and_codeparts = await self._process_query(query) files_and_codeparts = await self._process_query(query)
@ -63,52 +83,80 @@ class CodeRetriever(BaseRetriever):
similar_codepieces = [] similar_codepieces = []
if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode: if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
logger.info("No specific files/code extracted from query, performing general search")
for collection in self.file_name_collections: for collection in self.file_name_collections:
logger.debug(f"Searching {collection} collection with general query")
search_results_file = await vector_engine.search( search_results_file = await vector_engine.search(
collection, query, limit=self.top_k collection, query, limit=self.top_k
) )
logger.debug(f"Found {len(search_results_file)} results in {collection}")
for res in search_results_file: for res in search_results_file:
similar_filenames.append( similar_filenames.append(
{"id": res.id, "score": res.score, "payload": res.payload} {"id": res.id, "score": res.score, "payload": res.payload}
) )
for collection in self.classes_and_functions_collections: for collection in self.classes_and_functions_collections:
logger.debug(f"Searching {collection} collection with general query")
search_results_code = await vector_engine.search( search_results_code = await vector_engine.search(
collection, query, limit=self.top_k collection, query, limit=self.top_k
) )
logger.debug(f"Found {len(search_results_code)} results in {collection}")
for res in search_results_code: for res in search_results_code:
similar_codepieces.append( similar_codepieces.append(
{"id": res.id, "score": res.score, "payload": res.payload} {"id": res.id, "score": res.score, "payload": res.payload}
) )
else: else:
logger.info(
f"Using extracted filenames ({len(files_and_codeparts.filenames)}) and source code for targeted search"
)
for collection in self.file_name_collections: for collection in self.file_name_collections:
for file_from_query in files_and_codeparts.filenames: for file_from_query in files_and_codeparts.filenames:
logger.debug(f"Searching {collection} for specific file: {file_from_query}")
search_results_file = await vector_engine.search( search_results_file = await vector_engine.search(
collection, file_from_query, limit=self.top_k collection, file_from_query, limit=self.top_k
) )
logger.debug(
f"Found {len(search_results_file)} results for file {file_from_query}"
)
for res in search_results_file: for res in search_results_file:
similar_filenames.append( similar_filenames.append(
{"id": res.id, "score": res.score, "payload": res.payload} {"id": res.id, "score": res.score, "payload": res.payload}
) )
for collection in self.classes_and_functions_collections: for collection in self.classes_and_functions_collections:
logger.debug(f"Searching {collection} with extracted source code")
search_results_code = await vector_engine.search( search_results_code = await vector_engine.search(
collection, files_and_codeparts.sourcecode, limit=self.top_k collection, files_and_codeparts.sourcecode, limit=self.top_k
) )
logger.debug(f"Found {len(search_results_code)} results for source code search")
for res in search_results_code: for res in search_results_code:
similar_codepieces.append( similar_codepieces.append(
{"id": res.id, "score": res.score, "payload": res.payload} {"id": res.id, "score": res.score, "payload": res.payload}
) )
total_items = len(similar_filenames) + len(similar_codepieces)
logger.info(
f"Total search results: {total_items} items ({len(similar_filenames)} filenames, {len(similar_codepieces)} code pieces)"
)
if total_items == 0:
logger.warning("No search results found, returning empty list")
return []
logger.debug("Getting graph connections for all search results")
relevant_triplets = await asyncio.gather( relevant_triplets = await asyncio.gather(
*[ *[
graph_engine.get_connections(similar_piece["id"]) graph_engine.get_connections(similar_piece["id"])
for similar_piece in similar_filenames + similar_codepieces for similar_piece in similar_filenames + similar_codepieces
] ]
) )
logger.info(f"Retrieved graph connections for {len(relevant_triplets)} items")
paths = set() paths = set()
for sublist in relevant_triplets: for i, sublist in enumerate(relevant_triplets):
logger.debug(f"Processing connections for item {i}: {len(sublist)} connections")
for tpl in sublist: for tpl in sublist:
if isinstance(tpl, tuple) and len(tpl) >= 3: if isinstance(tpl, tuple) and len(tpl) >= 3:
if "file_path" in tpl[0]: if "file_path" in tpl[0]:
@ -116,23 +164,31 @@ class CodeRetriever(BaseRetriever):
if "file_path" in tpl[2]: if "file_path" in tpl[2]:
paths.add(tpl[2]["file_path"]) paths.add(tpl[2]["file_path"])
logger.info(f"Found {len(paths)} unique file paths to read")
retrieved_files = {} retrieved_files = {}
read_tasks = [] read_tasks = []
for file_path in paths: for file_path in paths:
async def read_file(fp): async def read_file(fp):
try: try:
logger.debug(f"Reading file: {fp}")
async with aiofiles.open(fp, "r", encoding="utf-8") as f: async with aiofiles.open(fp, "r", encoding="utf-8") as f:
retrieved_files[fp] = await f.read() content = await f.read()
retrieved_files[fp] = content
logger.debug(f"Successfully read {len(content)} characters from {fp}")
except Exception as e: except Exception as e:
print(f"Error reading {fp}: {e}") logger.error(f"Error reading {fp}: {e}")
retrieved_files[fp] = "" retrieved_files[fp] = ""
read_tasks.append(read_file(file_path)) read_tasks.append(read_file(file_path))
await asyncio.gather(*read_tasks) await asyncio.gather(*read_tasks)
logger.info(
f"Successfully read {len([f for f in retrieved_files.values() if f])} files (out of {len(paths)} total)"
)
return [ result = [
{ {
"name": file_path, "name": file_path,
"description": file_path, "description": file_path,
@ -141,6 +197,9 @@ class CodeRetriever(BaseRetriever):
for file_path in paths for file_path in paths
] ]
logger.info(f"Returning {len(result)} code file contexts")
return result
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""Returns the code files context.""" """Returns the code files context."""
if context is None: if context is None:

View file

@ -1,11 +1,14 @@
from typing import Any, Optional from typing import Any, Optional
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.completion import generate_completion from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
logger = get_logger("CompletionRetriever")
class CompletionRetriever(BaseRetriever): class CompletionRetriever(BaseRetriever):
""" """
@ -56,8 +59,10 @@ class CompletionRetriever(BaseRetriever):
# Combine all chunks text returned from vector search (number of chunks is determined by top_k # Combine all chunks text returned from vector search (number of chunks is determined by top_k
chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks] chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
return "\n".join(chunks_payload) combined_context = "\n".join(chunks_payload)
return combined_context
except CollectionNotFoundError as error: except CollectionNotFoundError as error:
logger.error("DocumentChunk_text collection not found")
raise NoDataError("No data found in the system, please add data first.") from error raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
@ -70,22 +75,19 @@ class CompletionRetriever(BaseRetriever):
Parameters: Parameters:
----------- -----------
- query (str): The input query for which the completion is generated. - query (str): The query string to be used for generating a completion.
- context (Optional[Any]): Optional context to use for generating the completion; if - context (Optional[Any]): Optional pre-fetched context to use for generating the
not provided, it will be retrieved using get_context. (default None) completion; if None, it retrieves the context for the query. (default None)
Returns: Returns:
-------- --------
- Any: A list containing the generated completion from the LLM. - Any: The generated completion based on the provided query and context.
""" """
if context is None: if context is None:
context = await self.get_context(query) context = await self.get_context(query)
completion = await generate_completion( completion = await generate_completion(
query=query, query, context, self.user_prompt_path, self.system_prompt_path
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
) )
return [completion] return completion

View file

@ -10,7 +10,7 @@ from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
logger = get_logger() logger = get_logger("GraphCompletionRetriever")
class GraphCompletionRetriever(BaseRetriever): class GraphCompletionRetriever(BaseRetriever):

View file

@ -1,12 +1,15 @@
import asyncio import asyncio
from typing import Any, Optional from typing import Any, Optional
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
logger = get_logger("InsightsRetriever")
class InsightsRetriever(BaseRetriever): class InsightsRetriever(BaseRetriever):
""" """
@ -63,6 +66,7 @@ class InsightsRetriever(BaseRetriever):
vector_engine.search("EntityType_name", query_text=query, limit=self.top_k), vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
) )
except CollectionNotFoundError as error: except CollectionNotFoundError as error:
logger.error("Entity collections not found")
raise NoDataError("No data found in the system, please add data first.") from error raise NoDataError("No data found in the system, please add data first.") from error
results = [*results[0], *results[1]] results = [*results[0], *results[1]]

View file

@ -1,5 +1,5 @@
from typing import Any, Optional from typing import Any, Optional
import logging from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
@ -8,7 +8,7 @@ from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions import SearchTypeNotSupported from cognee.modules.retrieval.exceptions import SearchTypeNotSupported
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
logger = logging.getLogger("NaturalLanguageRetriever") logger = get_logger("NaturalLanguageRetriever")
class NaturalLanguageRetriever(BaseRetriever): class NaturalLanguageRetriever(BaseRetriever):
@ -123,16 +123,12 @@ class NaturalLanguageRetriever(BaseRetriever):
- Optional[Any]: Returns the context retrieved from the graph database based on the - Optional[Any]: Returns the context retrieved from the graph database based on the
query. query.
""" """
try: graph_engine = await get_graph_engine()
graph_engine = await get_graph_engine()
if isinstance(graph_engine, (NetworkXAdapter)): if isinstance(graph_engine, (NetworkXAdapter)):
raise SearchTypeNotSupported("Natural language search type not supported.") raise SearchTypeNotSupported("Natural language search type not supported.")
return await self._execute_cypher_query(query, graph_engine) return await self._execute_cypher_query(query, graph_engine)
except Exception as e:
logger.error("Failed to execute natural language search retrieval: %s", str(e))
raise e
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
""" """

View file

@ -1,10 +1,13 @@
from typing import Any, Optional from typing import Any, Optional
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
logger = get_logger("SummariesRetriever")
class SummariesRetriever(BaseRetriever): class SummariesRetriever(BaseRetriever):
""" """
@ -40,16 +43,24 @@ class SummariesRetriever(BaseRetriever):
- Any: A list of payloads from the retrieved summaries. - Any: A list of payloads from the retrieved summaries.
""" """
logger.info(
f"Starting summary retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
try: try:
summaries_results = await vector_engine.search( summaries_results = await vector_engine.search(
"TextSummary_text", query, limit=self.top_k "TextSummary_text", query, limit=self.top_k
) )
logger.info(f"Found {len(summaries_results)} summaries from vector search")
except CollectionNotFoundError as error: except CollectionNotFoundError as error:
logger.error("TextSummary_text collection not found in vector database")
raise NoDataError("No data found in the system, please add data first.") from error raise NoDataError("No data found in the system, please add data first.") from error
return [summary.payload for summary in summaries_results] summary_payloads = [summary.payload for summary in summaries_results]
logger.info(f"Returning {len(summary_payloads)} summary payloads")
return summary_payloads
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
""" """
@ -70,6 +81,17 @@ class SummariesRetriever(BaseRetriever):
- Any: The generated completion context, which is either provided or retrieved. - Any: The generated completion context, which is either provided or retrieved.
""" """
logger.info(
f"Starting completion generation for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
if context is None: if context is None:
logger.debug("No context provided, retrieving context from vector database")
context = await self.get_context(query) context = await self.get_context(query)
else:
logger.debug("Using provided context")
logger.info(
f"Returning context with {len(context) if isinstance(context, list) else 1} item(s)"
)
return context return context

View file

@ -59,13 +59,13 @@ async def get_memory_fragment(
node_name: Optional[List[str]] = None, node_name: Optional[List[str]] = None,
) -> CogneeGraph: ) -> CogneeGraph:
"""Creates and initializes a CogneeGraph memory fragment with optional property projections.""" """Creates and initializes a CogneeGraph memory fragment with optional property projections."""
graph_engine = await get_graph_engine()
memory_fragment = CogneeGraph()
if properties_to_project is None: if properties_to_project is None:
properties_to_project = ["id", "description", "name", "type", "text"] properties_to_project = ["id", "description", "name", "type", "text"]
try: try:
graph_engine = await get_graph_engine()
memory_fragment = CogneeGraph()
await memory_fragment.project_graph_from_db( await memory_fragment.project_graph_from_db(
graph_engine, graph_engine,
node_properties_to_project=properties_to_project, node_properties_to_project=properties_to_project,
@ -73,7 +73,13 @@ async def get_memory_fragment(
node_type=node_type, node_type=node_type,
node_name=node_name, node_name=node_name,
) )
except EntityNotFoundError: except EntityNotFoundError:
# This is expected behavior - continue with empty fragment
pass
except Exception as e:
logger.error(f"Error during memory fragment creation: {str(e)}")
# Still return the fragment even if projection failed
pass pass
return memory_fragment return memory_fragment