chore: Cog 2354 add logging (#1115)
## Description

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Parent: d6727a1b4a
Commit: 1885ab9e88

12 changed files with 271 additions and 108 deletions
```diff
@@ -104,6 +104,7 @@ class KuzuAdapter(GraphDBInterface):
                 max_db_size=4096 * 1024 * 1024,
             )

             self.db.init_database()
             self.connection = Connection(self.db)
             # Create node table with essential fields and timestamp
```
```diff
@@ -33,7 +33,7 @@ from .neo4j_metrics_utils import (
 from .deadlock_retry import deadlock_retry

-logger = get_logger("Neo4jAdapter", level=ERROR)
+logger = get_logger("Neo4jAdapter")

 BASE_LABEL = "__Node__"
```
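Note (not part of the diff): the adapters and retrievers in this change consistently move to named loggers from `cognee.shared.logging_utils.get_logger`. A minimal sketch of the idea using only the standard library, assuming `get_logger` behaves like `logging.getLogger` with an optional level:

```python
# Illustrative sketch only; the real helper lives in cognee.shared.logging_utils.
import logging


def get_named_logger(name: str, level: int = logging.INFO) -> logging.Logger:
    # A named logger lets output be filtered per component
    # (e.g. "Neo4jAdapter", "ChunksRetriever") rather than per module path.
    logger = logging.getLogger(name)
    logger.setLevel(level)
    return logger


logger = get_named_logger("Neo4jAdapter")
logger.info("adapter initialized")
```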
```diff
@@ -870,34 +870,52 @@ class Neo4jAdapter(GraphDBInterface):
             A tuple containing two lists: nodes and edges with their properties.
         """
-        query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
-
-        result = await self.query(query)
-
-        nodes = [
-            (
-                record["properties"]["id"],
-                record["properties"],
-            )
-            for record in result
-        ]
-
-        query = """
-        MATCH (n)-[r]->(m)
-        RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
-        """
-        result = await self.query(query)
-        edges = [
-            (
-                record["properties"]["source_node_id"],
-                record["properties"]["target_node_id"],
-                record["type"],
-                record["properties"],
-            )
-            for record in result
-        ]
-
-        return (nodes, edges)
+        import time
+
+        start_time = time.time()
+
+        try:
+            # Retrieve nodes
+            query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
+            result = await self.query(query)
+
+            nodes = []
+            for record in result:
+                nodes.append(
+                    (
+                        record["properties"]["id"],
+                        record["properties"],
+                    )
+                )
+
+            # Retrieve edges
+            query = """
+            MATCH (n)-[r]->(m)
+            RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
+            """
+            result = await self.query(query)
+
+            edges = []
+            for record in result:
+                edges.append(
+                    (
+                        record["properties"]["source_node_id"],
+                        record["properties"]["target_node_id"],
+                        record["type"],
+                        record["properties"],
+                    )
+                )
+
+            retrieval_time = time.time() - start_time
+            logger.info(
+                f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
+            )
+
+            return (nodes, edges)
+        except Exception as e:
+            logger.error(f"Error during graph data retrieval: {str(e)}")
+            raise

     async def get_nodeset_subgraph(
         self, node_type: Type[Any], node_name: List[str]
```
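The same `import time` / `start_time` / `logger.info` bookkeeping recurs in several methods touched by this PR. A hedged alternative, not part of the change, is to factor the timing into a reusable decorator; the name and placement below are hypothetical:

```python
# Hypothetical helper, not part of this PR: wraps an async method and logs
# how long it took, mirroring the inline time.time() bookkeeping above.
import functools
import logging
import time

logger = logging.getLogger("Neo4jAdapter")


def log_duration(operation: str):
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            start_time = time.time()
            result = await func(*args, **kwargs)
            logger.info(f"{operation} completed in {time.time() - start_time:.2f} seconds")
            return result

        return wrapper

    return decorator


# Usage sketch:
# @log_duration("graph data retrieval")
# async def get_graph_data(self): ...
```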
```diff
@@ -918,50 +936,71 @@ class Neo4jAdapter(GraphDBInterface):
             - Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]}: A tuple
               containing nodes and edges in the requested subgraph.
         """
-        label = node_type.__name__
-
-        query = f"""
-        UNWIND $names AS wantedName
-        MATCH (n:`{label}`)
-        WHERE n.name = wantedName
-        WITH collect(DISTINCT n) AS primary
-        UNWIND primary AS p
-        OPTIONAL MATCH (p)--(nbr)
-        WITH primary, collect(DISTINCT nbr) AS nbrs
-        WITH primary + nbrs AS nodelist
-        UNWIND nodelist AS node
-        WITH collect(DISTINCT node) AS nodes
-        MATCH (a)-[r]-(b)
-        WHERE a IN nodes AND b IN nodes
-        WITH nodes, collect(DISTINCT r) AS rels
-        RETURN
-          [n IN nodes |
-             {{ id: n.id,
-               properties: properties(n) }}] AS rawNodes,
-          [r IN rels |
-             {{ type: type(r),
-               properties: properties(r) }}] AS rawRels
-        """
-
-        result = await self.query(query, {"names": node_name})
-        if not result:
-            return [], []
-
-        raw_nodes = result[0]["rawNodes"]
-        raw_rels = result[0]["rawRels"]
-
-        nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes]
-        edges = [
-            (
-                r["properties"]["source_node_id"],
-                r["properties"]["target_node_id"],
-                r["type"],
-                r["properties"],
-            )
-            for r in raw_rels
-        ]
-
-        return nodes, edges
+        import time
+
+        start_time = time.time()
+
+        try:
+            label = node_type.__name__
+
+            query = f"""
+            UNWIND $names AS wantedName
+            MATCH (n:`{label}`)
+            WHERE n.name = wantedName
+            WITH collect(DISTINCT n) AS primary
+            UNWIND primary AS p
+            OPTIONAL MATCH (p)--(nbr)
+            WITH primary, collect(DISTINCT nbr) AS nbrs
+            WITH primary + nbrs AS nodelist
+            UNWIND nodelist AS node
+            WITH collect(DISTINCT node) AS nodes
+            MATCH (a)-[r]-(b)
+            WHERE a IN nodes AND b IN nodes
+            WITH nodes, collect(DISTINCT r) AS rels
+            RETURN
+              [n IN nodes |
+                 {{ id: n.id,
+                   properties: properties(n) }}] AS rawNodes,
+              [r IN rels |
+                 {{ type: type(r),
+                   properties: properties(r) }}] AS rawRels
+            """
+
+            result = await self.query(query, {"names": node_name})
+
+            if not result:
+                return [], []
+
+            raw_nodes = result[0]["rawNodes"]
+            raw_rels = result[0]["rawRels"]
+
+            # Process nodes
+            nodes = []
+            for n in raw_nodes:
+                nodes.append((n["properties"]["id"], n["properties"]))
+
+            # Process edges
+            edges = []
+            for r in raw_rels:
+                edges.append(
+                    (
+                        r["properties"]["source_node_id"],
+                        r["properties"]["target_node_id"],
+                        r["type"],
+                        r["properties"],
+                    )
+                )
+
+            retrieval_time = time.time() - start_time
+            logger.info(
+                f"Retrieved {len(nodes)} nodes and {len(edges)} edges for {node_type.__name__} in {retrieval_time:.2f} seconds"
+            )
+
+            return nodes, edges
+
+        except Exception as e:
+            logger.error(f"Error during nodeset subgraph retrieval: {str(e)}")
+            raise

     async def get_filtered_graph_data(self, attribute_filters):
         """
```
```diff
@@ -1011,8 +1050,8 @@ class Neo4jAdapter(GraphDBInterface):
         edges = [
             (
-                record["source"],
-                record["target"],
+                record["properties"]["source_node_id"],
+                record["properties"]["target_node_id"],
                 record["type"],
                 record["properties"],
             )
```
```diff
@@ -8,7 +8,7 @@ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
 from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
 import heapq

-logger = get_logger()
+logger = get_logger("CogneeGraph")


 class CogneeGraph(CogneeAbstractGraph):
```
```diff
@@ -66,7 +66,13 @@ class CogneeGraph(CogneeAbstractGraph):
     ) -> None:
         if node_dimension < 1 or edge_dimension < 1:
             raise InvalidValueError(message="Dimensions must be positive integers")

         try:
+            import time
+
+            start_time = time.time()
+
+            # Determine projection strategy
             if node_type is not None and node_name is not None:
                 nodes_data, edges_data = await adapter.get_nodeset_subgraph(
                     node_type=node_type, node_name=node_name
```
```diff
@@ -83,16 +89,17 @@ class CogneeGraph(CogneeAbstractGraph):
                 nodes_data, edges_data = await adapter.get_filtered_graph_data(
                     attribute_filters=memory_fragment_filter
                 )

                 if not nodes_data or not edges_data:
                     raise EntityNotFoundError(
                         message="Empty filtered graph projected from the database."
                     )

             # Process nodes
             for node_id, properties in nodes_data:
                 node_attributes = {key: properties.get(key) for key in node_properties_to_project}
                 self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))

             # Process edges
             for source_id, target_id, relationship_type, properties in edges_data:
                 source_node = self.get_node(str(source_id))
                 target_node = self.get_node(str(target_id))
```
```diff
@@ -113,17 +120,23 @@ class CogneeGraph(CogneeAbstractGraph):
                     source_node.add_skeleton_edge(edge)
                     target_node.add_skeleton_edge(edge)

                 else:
                     raise EntityNotFoundError(
                         message=f"Edge references nonexistent nodes: {source_id} -> {target_id}"
                     )

-        except (ValueError, TypeError) as e:
-            print(f"Error projecting graph: {e}")
-            raise e
+            # Final statistics
+            projection_time = time.time() - start_time
+            logger.info(
+                f"Graph projection completed: {len(self.nodes)} nodes, {len(self.edges)} edges in {projection_time:.2f}s"
+            )
+
+        except Exception as e:
+            logger.error(f"Error during graph projection: {str(e)}")
+            raise

     async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
+        mapped_nodes = 0
         for category, scored_results in node_distances.items():
             for scored_result in scored_results:
                 node_id = str(scored_result.id)
```
```diff
@@ -131,6 +144,7 @@ class CogneeGraph(CogneeAbstractGraph):
                 node = self.get_node(node_id)
                 if node:
                     node.add_attribute("vector_distance", score)
+                    mapped_nodes += 1

     async def map_vector_distances_to_graph_edges(
         self, vector_engine, query_vector, edge_distances
```
```diff
@@ -150,18 +164,16 @@ class CogneeGraph(CogneeAbstractGraph):

             for edge in self.edges:
                 relationship_type = edge.attributes.get("relationship_type")
-                if not relationship_type or relationship_type not in embedding_map:
-                    print(f"Edge {edge} has an unknown or missing relationship type.")
-                    continue
-
-                edge.attributes["vector_distance"] = embedding_map[relationship_type]
+                if relationship_type and relationship_type in embedding_map:
+                    edge.attributes["vector_distance"] = embedding_map[relationship_type]

         except Exception as ex:
-            print(f"Error mapping vector distances to edges: {ex}")
+            logger.error(f"Error mapping vector distances to edges: {str(ex)}")
             raise ex

     async def calculate_top_triplet_importances(self, k: int) -> List:
         min_heap = []

         for i, edge in enumerate(self.edges):
             source_node = self.get_node(edge.node1.id)
             target_node = self.get_node(edge.node2.id)
```
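For readers unfamiliar with the `min_heap` used by `calculate_top_triplet_importances`: keeping a bounded min-heap is the standard way to select the top-k items in a single pass. A generic sketch of that technique (names hypothetical, not the CogneeGraph implementation):

```python
# Generic top-k selection with a bounded min-heap; illustrative only.
import heapq
from typing import Iterable, List, Tuple


def top_k_by_score(items: Iterable[Tuple[float, str]], k: int) -> List[Tuple[float, str]]:
    min_heap: List[Tuple[float, str]] = []
    for score, payload in items:
        if len(min_heap) < k:
            heapq.heappush(min_heap, (score, payload))
        elif score > min_heap[0][0]:
            # Replace the smallest retained score with the better one.
            heapq.heapreplace(min_heap, (score, payload))
    return sorted(min_heap, reverse=True)


print(top_k_by_score([(0.2, "a"), (0.9, "b"), (0.5, "c")], k=2))  # [(0.9, 'b'), (0.5, 'c')]
```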
```diff
@@ -33,7 +33,7 @@ async def get_formatted_graph_data(dataset_id: UUID, user_id: UUID):
         lambda edge: {
             "source": str(edge[0]),
             "target": str(edge[1]),
-            "label": edge[2],
+            "label": str(edge[2]),
         },
         edges,
     )
```
```diff
@@ -1,10 +1,13 @@
 from typing import Any, Optional

+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError

+logger = get_logger("ChunksRetriever")


 class ChunksRetriever(BaseRetriever):
     """
```
```diff
@@ -41,14 +44,22 @@ class ChunksRetriever(BaseRetriever):

             - Any: A list of document chunk payloads retrieved from the search.
         """
+        logger.info(
+            f"Starting chunk retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         vector_engine = get_vector_engine()

         try:
             found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
+            logger.info(f"Found {len(found_chunks)} chunks from vector search")
         except CollectionNotFoundError as error:
+            logger.error("DocumentChunk_text collection not found in vector database")
             raise NoDataError("No data found in the system, please add data first.") from error

-        return [result.payload for result in found_chunks]
+        chunk_payloads = [result.payload for result in found_chunks]
+        logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
+        return chunk_payloads

     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """
```
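The CollectionNotFoundError-to-NoDataError translation plus an error log is a pattern shared by several retrievers in this PR. A minimal, self-contained sketch of the idea; the exception classes below are stand-ins defined locally for illustration, not imports of cognee's:

```python
# Illustrative only: translate a low-level "collection missing" error into a
# user-facing "no data" error, logging the technical detail before re-raising.
import logging

logger = logging.getLogger("ChunksRetriever")


class CollectionNotFoundError(Exception):
    pass


class NoDataError(Exception):
    pass


async def search_or_explain(vector_engine, collection: str, query: str, limit: int):
    try:
        return await vector_engine.search(collection, query, limit=limit)
    except CollectionNotFoundError as error:
        logger.error("%s collection not found in vector database", collection)
        raise NoDataError("No data found in the system, please add data first.") from error
```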
```diff
@@ -70,6 +81,17 @@ class ChunksRetriever(BaseRetriever):
             - Any: The context used for the completion or the retrieved context if none was
               provided.
         """
+        logger.info(
+            f"Starting completion generation for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         if context is None:
+            logger.debug("No context provided, retrieving context from vector database")
             context = await self.get_context(query)
+        else:
+            logger.debug("Using provided context")
+
+        logger.info(
+            f"Returning context with {len(context) if isinstance(context, list) else 1} item(s)"
+        )
         return context
```
```diff
@@ -3,12 +3,15 @@ import asyncio
 import aiofiles
 from pydantic import BaseModel

+from cognee.shared.logging_utils import get_logger
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
 from cognee.infrastructure.llm.prompts import read_query_prompt

+logger = get_logger("CodeRetriever")


 class CodeRetriever(BaseRetriever):
     """Retriever for handling code-based searches."""
```
```diff
@@ -35,26 +38,43 @@ class CodeRetriever(BaseRetriever):

     async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
         """Process the query using LLM to extract file names and source code parts."""
+        logger.debug(
+            f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         system_prompt = read_query_prompt("codegraph_retriever_system.txt")
         llm_client = get_llm_client()

         try:
-            return await llm_client.acreate_structured_output(
+            result = await llm_client.acreate_structured_output(
                 text_input=query,
                 system_prompt=system_prompt,
                 response_model=self.CodeQueryInfo,
             )
+            logger.info(
+                f"LLM extracted {len(result.filenames)} filenames and {len(result.sourcecode)} chars of source code"
+            )
+            return result
         except Exception as e:
+            logger.error(f"Failed to retrieve structured output from LLM: {str(e)}")
             raise RuntimeError("Failed to retrieve structured output from LLM") from e

     async def get_context(self, query: str) -> Any:
         """Find relevant code files based on the query."""
+        logger.info(
+            f"Starting code retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         if not query or not isinstance(query, str):
+            logger.error("Invalid query: must be a non-empty string")
             raise ValueError("The query must be a non-empty string.")

         try:
             vector_engine = get_vector_engine()
             graph_engine = await get_graph_engine()
+            logger.debug("Successfully initialized vector and graph engines")
         except Exception as e:
+            logger.error(f"Database initialization error: {str(e)}")
             raise RuntimeError("Database initialization error in code_graph_retriever, ") from e

         files_and_codeparts = await self._process_query(query)
```
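Several log lines in this PR repeat the same f-string idiom to truncate long queries (`query[:100]` plus a conditional ellipsis). A small hypothetical helper, not part of the change, that captures the idiom:

```python
# Hypothetical helper, shown only to explain the repeated truncation idiom.
def preview(text: str, limit: int = 100) -> str:
    """Return the first `limit` characters of text, adding '...' when truncated."""
    return f"{text[:limit]}{'...' if len(text) > limit else ''}"


# logger.info(f"Starting code retrieval for query: '{preview(query)}'")
print(preview("a" * 150))  # 100 chars followed by '...'
```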
```diff
@@ -63,52 +83,80 @@ class CodeRetriever(BaseRetriever):
         similar_codepieces = []

         if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
+            logger.info("No specific files/code extracted from query, performing general search")
+
             for collection in self.file_name_collections:
+                logger.debug(f"Searching {collection} collection with general query")
                 search_results_file = await vector_engine.search(
                     collection, query, limit=self.top_k
                 )
+                logger.debug(f"Found {len(search_results_file)} results in {collection}")
                 for res in search_results_file:
                     similar_filenames.append(
                         {"id": res.id, "score": res.score, "payload": res.payload}
                     )

             for collection in self.classes_and_functions_collections:
+                logger.debug(f"Searching {collection} collection with general query")
                 search_results_code = await vector_engine.search(
                     collection, query, limit=self.top_k
                 )
+                logger.debug(f"Found {len(search_results_code)} results in {collection}")
                 for res in search_results_code:
                     similar_codepieces.append(
                         {"id": res.id, "score": res.score, "payload": res.payload}
                     )
         else:
+            logger.info(
+                f"Using extracted filenames ({len(files_and_codeparts.filenames)}) and source code for targeted search"
+            )
+
             for collection in self.file_name_collections:
                 for file_from_query in files_and_codeparts.filenames:
+                    logger.debug(f"Searching {collection} for specific file: {file_from_query}")
                     search_results_file = await vector_engine.search(
                         collection, file_from_query, limit=self.top_k
                     )
+                    logger.debug(
+                        f"Found {len(search_results_file)} results for file {file_from_query}"
+                    )
                     for res in search_results_file:
                         similar_filenames.append(
                             {"id": res.id, "score": res.score, "payload": res.payload}
                         )

             for collection in self.classes_and_functions_collections:
+                logger.debug(f"Searching {collection} with extracted source code")
                 search_results_code = await vector_engine.search(
                     collection, files_and_codeparts.sourcecode, limit=self.top_k
                 )
+                logger.debug(f"Found {len(search_results_code)} results for source code search")
                 for res in search_results_code:
                     similar_codepieces.append(
                         {"id": res.id, "score": res.score, "payload": res.payload}
                     )

+        total_items = len(similar_filenames) + len(similar_codepieces)
+        logger.info(
+            f"Total search results: {total_items} items ({len(similar_filenames)} filenames, {len(similar_codepieces)} code pieces)"
+        )
+
+        if total_items == 0:
+            logger.warning("No search results found, returning empty list")
+            return []
+
+        logger.debug("Getting graph connections for all search results")
         relevant_triplets = await asyncio.gather(
             *[
                 graph_engine.get_connections(similar_piece["id"])
                 for similar_piece in similar_filenames + similar_codepieces
             ]
         )
+        logger.info(f"Retrieved graph connections for {len(relevant_triplets)} items")

         paths = set()
-        for sublist in relevant_triplets:
+        for i, sublist in enumerate(relevant_triplets):
+            logger.debug(f"Processing connections for item {i}: {len(sublist)} connections")
             for tpl in sublist:
                 if isinstance(tpl, tuple) and len(tpl) >= 3:
                     if "file_path" in tpl[0]:
```
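The connection lookup above fans out one graph query per search hit with `asyncio.gather` and then logs the result count. A standalone sketch of that fan-out-and-log pattern, with a stubbed fetch function standing in for `graph_engine.get_connections`:

```python
# Illustrative fan-out pattern; fetch_connections is a stand-in, not a cognee API.
import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("CodeRetriever")


async def fetch_connections(item_id: int) -> list:
    await asyncio.sleep(0.01)  # pretend to hit the graph database
    return [(item_id, "connected_to", item_id + 1)]


async def main():
    item_ids = [1, 2, 3]
    # One task per item; gather preserves input order in its results.
    results = await asyncio.gather(*[fetch_connections(i) for i in item_ids])
    logger.info("Retrieved graph connections for %d items", len(results))


asyncio.run(main())
```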
```diff
@@ -116,23 +164,31 @@ class CodeRetriever(BaseRetriever):
                     if "file_path" in tpl[2]:
                         paths.add(tpl[2]["file_path"])

+        logger.info(f"Found {len(paths)} unique file paths to read")
+
         retrieved_files = {}
         read_tasks = []
         for file_path in paths:

             async def read_file(fp):
                 try:
+                    logger.debug(f"Reading file: {fp}")
                     async with aiofiles.open(fp, "r", encoding="utf-8") as f:
-                        retrieved_files[fp] = await f.read()
+                        content = await f.read()
+                        retrieved_files[fp] = content
+                        logger.debug(f"Successfully read {len(content)} characters from {fp}")
                 except Exception as e:
-                    print(f"Error reading {fp}: {e}")
+                    logger.error(f"Error reading {fp}: {e}")
                     retrieved_files[fp] = ""

             read_tasks.append(read_file(file_path))

         await asyncio.gather(*read_tasks)
+        logger.info(
+            f"Successfully read {len([f for f in retrieved_files.values() if f])} files (out of {len(paths)} total)"
+        )

-        return [
+        result = [
             {
                 "name": file_path,
                 "description": file_path,
```
```diff
@@ -141,6 +197,9 @@ class CodeRetriever(BaseRetriever):
             for file_path in paths
         ]

+        logger.info(f"Returning {len(result)} code file contexts")
+        return result

     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """Returns the code files context."""
         if context is None:
```
```diff
@@ -1,11 +1,14 @@
 from typing import Any, Optional

+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.utils.completion import generate_completion
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError

+logger = get_logger("CompletionRetriever")


 class CompletionRetriever(BaseRetriever):
     """
```
```diff
@@ -56,8 +59,10 @@ class CompletionRetriever(BaseRetriever):

             # Combine all chunks text returned from vector search (number of chunks is determined by top_k
             chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
-            return "\n".join(chunks_payload)
+            combined_context = "\n".join(chunks_payload)
+            return combined_context
         except CollectionNotFoundError as error:
+            logger.error("DocumentChunk_text collection not found")
             raise NoDataError("No data found in the system, please add data first.") from error

     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
```
```diff
@@ -70,22 +75,19 @@ class CompletionRetriever(BaseRetriever):
         Parameters:
         -----------

-            - query (str): The input query for which the completion is generated.
-            - context (Optional[Any]): Optional context to use for generating the completion; if
-              not provided, it will be retrieved using get_context. (default None)
+            - query (str): The query string to be used for generating a completion.
+            - context (Optional[Any]): Optional pre-fetched context to use for generating the
+              completion; if None, it retrieves the context for the query. (default None)

         Returns:
         --------

-            - Any: A list containing the generated completion from the LLM.
+            - Any: The generated completion based on the provided query and context.
         """
         if context is None:
             context = await self.get_context(query)

         completion = await generate_completion(
-            query=query,
-            context=context,
-            user_prompt_path=self.user_prompt_path,
-            system_prompt_path=self.system_prompt_path,
+            query, context, self.user_prompt_path, self.system_prompt_path
         )
-        return [completion]
+        return completion
```
```diff
@@ -10,7 +10,7 @@ from cognee.modules.retrieval.utils.completion import generate_completion
 from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
 from cognee.shared.logging_utils import get_logger

-logger = get_logger()
+logger = get_logger("GraphCompletionRetriever")


 class GraphCompletionRetriever(BaseRetriever):
```
```diff
@@ -1,12 +1,15 @@
 import asyncio
 from typing import Any, Optional

+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError

+logger = get_logger("InsightsRetriever")


 class InsightsRetriever(BaseRetriever):
     """
```
```diff
@@ -63,6 +66,7 @@ class InsightsRetriever(BaseRetriever):
                 vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
             )
         except CollectionNotFoundError as error:
+            logger.error("Entity collections not found")
             raise NoDataError("No data found in the system, please add data first.") from error

         results = [*results[0], *results[1]]
```
```diff
@@ -1,5 +1,5 @@
 from typing import Any, Optional
-import logging
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
 from cognee.infrastructure.llm.get_llm_client import get_llm_client
```
```diff
@@ -8,7 +8,7 @@ from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions import SearchTypeNotSupported
 from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface

-logger = logging.getLogger("NaturalLanguageRetriever")
+logger = get_logger("NaturalLanguageRetriever")


 class NaturalLanguageRetriever(BaseRetriever):
```
```diff
@@ -123,16 +123,12 @@ class NaturalLanguageRetriever(BaseRetriever):
             - Optional[Any]: Returns the context retrieved from the graph database based on the
               query.
         """
-        try:
-            graph_engine = await get_graph_engine()
+        graph_engine = await get_graph_engine()

-            if isinstance(graph_engine, (NetworkXAdapter)):
-                raise SearchTypeNotSupported("Natural language search type not supported.")
+        if isinstance(graph_engine, (NetworkXAdapter)):
+            raise SearchTypeNotSupported("Natural language search type not supported.")

-            return await self._execute_cypher_query(query, graph_engine)
-        except Exception as e:
-            logger.error("Failed to execute natural language search retrieval: %s", str(e))
-            raise e
+        return await self._execute_cypher_query(query, graph_engine)

     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """
```
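This hunk removes a method-level try/except that only logged and re-raised. A purely illustrative sketch of why the two forms behave the same for the caller (the function names below are hypothetical):

```python
# Purely illustrative: a wrapper that only logs and re-raises (the removed
# pattern) surfaces the same exception as letting it propagate unchanged.
import logging

logger = logging.getLogger("NaturalLanguageRetriever")


async def removed_pattern(run):
    try:
        return await run()
    except Exception as e:
        logger.error("Failed to execute natural language search retrieval: %s", str(e))
        raise e


async def current_pattern(run):
    # Errors simply propagate to the caller, which can log them once.
    return await run()
```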
```diff
@@ -1,10 +1,13 @@
 from typing import Any, Optional

+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.exceptions.exceptions import NoDataError
 from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError

+logger = get_logger("SummariesRetriever")


 class SummariesRetriever(BaseRetriever):
     """
```
```diff
@@ -40,16 +43,24 @@ class SummariesRetriever(BaseRetriever):

             - Any: A list of payloads from the retrieved summaries.
         """
+        logger.info(
+            f"Starting summary retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         vector_engine = get_vector_engine()

         try:
             summaries_results = await vector_engine.search(
                 "TextSummary_text", query, limit=self.top_k
             )
+            logger.info(f"Found {len(summaries_results)} summaries from vector search")
         except CollectionNotFoundError as error:
+            logger.error("TextSummary_text collection not found in vector database")
             raise NoDataError("No data found in the system, please add data first.") from error

-        return [summary.payload for summary in summaries_results]
+        summary_payloads = [summary.payload for summary in summaries_results]
+        logger.info(f"Returning {len(summary_payloads)} summary payloads")
+        return summary_payloads

     async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
         """
```
```diff
@@ -70,6 +81,17 @@ class SummariesRetriever(BaseRetriever):

             - Any: The generated completion context, which is either provided or retrieved.
         """
+        logger.info(
+            f"Starting completion generation for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
         if context is None:
+            logger.debug("No context provided, retrieving context from vector database")
             context = await self.get_context(query)
+        else:
+            logger.debug("Using provided context")
+
+        logger.info(
+            f"Returning context with {len(context) if isinstance(context, list) else 1} item(s)"
+        )
         return context
```
```diff
@@ -59,13 +59,13 @@ async def get_memory_fragment(
     node_name: Optional[List[str]] = None,
 ) -> CogneeGraph:
     """Creates and initializes a CogneeGraph memory fragment with optional property projections."""
-    graph_engine = await get_graph_engine()
-    memory_fragment = CogneeGraph()
-
     if properties_to_project is None:
         properties_to_project = ["id", "description", "name", "type", "text"]

+    try:
+        graph_engine = await get_graph_engine()
+        memory_fragment = CogneeGraph()
+
         await memory_fragment.project_graph_from_db(
             graph_engine,
             node_properties_to_project=properties_to_project,
```
```diff
@@ -73,7 +73,13 @@ async def get_memory_fragment(
             node_type=node_type,
             node_name=node_name,
         )
+    except EntityNotFoundError:
+        # This is expected behavior - continue with empty fragment
+        pass
+    except Exception as e:
+        logger.error(f"Error during memory fragment creation: {str(e)}")
+        # Still return the fragment even if projection failed
+        pass

     return memory_fragment
```
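The branch above treats EntityNotFoundError as normal (the caller simply receives an empty fragment), while any other failure is logged and swallowed so a fragment is still returned. A toy sketch of that distinction; the exception class and function here are placeholders for illustration, not cognee's:

```python
# Toy illustration of the error-handling split above.
import logging

logger = logging.getLogger("memory_fragment")


class EntityNotFoundError(Exception):
    """Raised when a projection finds no matching nodes or edges."""


def build_fragment(project):
    fragment = {}
    try:
        project(fragment)
    except EntityNotFoundError:
        # Expected: nothing matched, so the empty fragment is returned as-is.
        pass
    except Exception as e:
        # Unexpected: record it, but still hand back whatever was built.
        logger.error(f"Error during memory fragment creation: {str(e)}")
    return fragment
```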