Merge branch 'dev' into COG-2082

Vasilije 2025-07-24 14:09:32 +02:00 committed by GitHub
commit e03b2ea709
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 581 additions and 119 deletions

View file

@ -9,7 +9,7 @@ COPY package.json package-lock.json ./
# Install any needed packages specified in package.json
RUN npm ci
# RUN npm rebuild lightningcss
RUN npm rebuild lightningcss
# Copy the rest of the application code to the working directory
COPY src ./src

View file

@ -72,11 +72,38 @@ class KuzuAdapter(GraphDBInterface):
run_sync(file_storage.ensure_directory_exists())
self.db = Database(
self.db_path,
buffer_pool_size=256 * 1024 * 1024, # 256MB buffer pool
max_db_size=1024 * 1024 * 1024,
)
try:
self.db = Database(
self.db_path,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
except RuntimeError:
from .kuzu_migrate import read_kuzu_storage_version
import kuzu
kuzu_db_version = read_kuzu_storage_version(self.db_path)
if (
kuzu_db_version == "0.9.0" or kuzu_db_version == "0.8.2"
) and kuzu_db_version != kuzu.__version__:
# TODO: Write migration script that will handle all user graph databases in multi-user mode
# Try to migrate kuzu database to latest version
from .kuzu_migrate import kuzu_migration
kuzu_migration(
new_db=self.db_path + "new",
old_db=self.db_path,
new_version=kuzu.__version__,
old_version=kuzu_db_version,
overwrite=True,
)
self.db = Database(
self.db_path,
buffer_pool_size=2048 * 1024 * 1024, # 2048MB buffer pool
max_db_size=4096 * 1024 * 1024,
)
self.db.init_database()
self.connection = Connection(self.db)

View file

@ -0,0 +1,276 @@
#!/usr/bin/env python3
"""
Kuzu Database Migration Script
This script migrates Kuzu databases between different versions by:
1. Setting up isolated Python environments for each Kuzu version
2. Exporting data from the source database using the old version
3. Importing data into the target database using the new version
4. If overwrite is enabled, the target database replaces the source database and the source database is kept as a backup with the suffix _old_<version>
5. If delete-old is enabled, the target database is renamed to the source database name and the source database is deleted
The script automatically handles:
- Environment setup (creates virtual environments as needed)
- Export/import validation
- Error handling and reporting
Usage Examples:
# Basic migration from 0.9.0 to 0.11.0
python kuzu_migrate.py --old-version 0.9.0 --new-version 0.11.0 --old-db /path/to/old/database --new-db /path/to/new/database
Requirements:
- Python 3.7+
- Internet connection (to download Kuzu packages)
- Sufficient disk space for virtual environments and temporary exports
Notes:
- Can only be used to migrate to newer Kuzu versions, from 0.11.0 onwards
"""
import tempfile
import sys
import struct
import shutil
import subprocess
import argparse
import os
kuzu_version_mapping = {
34: "0.7.0",
35: "0.7.1",
36: "0.8.2",
37: "0.9.0",
38: "0.10.1",
39: "0.11.0",
}
def read_kuzu_storage_version(kuzu_db_path: str) -> str:
"""
Reads the Kùzu storage version code from the first bytes of the catalog.kz file.
:param kuzu_db_path: Path to the Kuzu database file/directory.
:return: Kuzu release version string (e.g. "0.9.0") mapped from the storage version code.
"""
if os.path.isdir(kuzu_db_path):
kuzu_version_file_path = os.path.join(kuzu_db_path, "catalog.kz")
if not os.path.isfile(kuzu_version_file_path):
raise FileNotFoundError("Kuzu catalog.kz file does not exist")
else:
kuzu_version_file_path = kuzu_db_path
with open(kuzu_version_file_path, "rb") as f:
# Skip the 3-byte magic "KUZ" and one byte of padding
f.seek(4)
# Read the next 8 bytes as a little-endian unsigned 64-bit integer
data = f.read(8)
if len(data) < 8:
raise ValueError(
f"File '{kuzu_version_file_path}' does not contain a storage version code."
)
version_code = struct.unpack("<Q", data)[0]
if kuzu_version_mapping.get(version_code):
return kuzu_version_mapping[version_code]
else:
raise ValueError("Could not map version_code to proper Kuzu version.")
def ensure_env(version: str, export_dir) -> str:
"""
Create a venv at {export_dir}/.kuzu_envs/{version} (recreating it if it already exists) and install kuzu=={version}.
Returns the path to the venv's python executable.
"""
# Use the temporary export directory to hold the venvs
kuzu_envs_dir = os.path.join(export_dir, ".kuzu_envs")
# venv base under the export directory
base = os.path.join(kuzu_envs_dir, version)
py_bin = os.path.join(base, "bin", "python")
# If the environment already exists, remove it so it can be recreated cleanly
if os.path.isfile(py_bin):
shutil.rmtree(base)
print(f"→ Setting up venv for Kùzu {version}...", file=sys.stderr)
# Create venv
subprocess.run([sys.executable, "-m", "venv", base], check=True)
# Install the specific Kùzu version
subprocess.run([py_bin, "-m", "pip", "install", "--upgrade", "pip"], check=True)
subprocess.run([py_bin, "-m", "pip", "install", f"kuzu=={version}"], check=True)
return py_bin
def run_migration_step(python_exe: str, db_path: str, cypher: str):
"""
Uses the given python_exe to execute a short snippet that
connects to the Kùzu database and runs a Cypher command.
"""
snippet = f"""
import kuzu
db = kuzu.Database(r"{db_path}")
conn = kuzu.Connection(db)
conn.execute(r\"\"\"{cypher}\"\"\")
"""
proc = subprocess.run([python_exe, "-c", snippet], capture_output=True, text=True)
if proc.returncode != 0:
print(f"[ERROR] {cypher} failed:\n{proc.stderr}", file=sys.stderr)
sys.exit(proc.returncode)
def kuzu_migration(new_db, old_db, new_version, old_version=None, overwrite=None, delete_old=None):
"""
Main migration function that handles the complete migration process.
"""
print(f"🔄 Migrating Kuzu database from {old_version} to {new_version}", file=sys.stderr)
print(f"📂 Source: {old_db}", file=sys.stderr)
print("", file=sys.stderr)
# If the old Kuzu DB version is not provided, try to determine it from the database files
if not old_version:
old_version = read_kuzu_storage_version(old_db)
# Check if old database exists
if not os.path.exists(old_db):
print(f"Source database '{old_db}' does not exist.", file=sys.stderr)
sys.exit(1)
# Prepare target - ensure the parent directory exists and fail if the target already exists
parent_dir = os.path.dirname(new_db)
if parent_dir:
os.makedirs(parent_dir, exist_ok=True)
if os.path.exists(new_db):
raise FileExistsError(
"File already exists at new database location, remove file or change new database file path to continue"
)
# Use a temp directory for all processing; it is cleaned up when the with block exits
with tempfile.TemporaryDirectory() as export_dir:
# Set up environments
print(f"Setting up Kuzu {old_version} environment...", file=sys.stderr)
old_py = ensure_env(old_version, export_dir)
print(f"Setting up Kuzu {new_version} environment...", file=sys.stderr)
new_py = ensure_env(new_version, export_dir)
export_file = os.path.join(export_dir, "kuzu_export")
print(f"Exporting old DB → {export_dir}", file=sys.stderr)
run_migration_step(old_py, old_db, f"EXPORT DATABASE '{export_file}'")
print("Export complete.", file=sys.stderr)
# Check if export files were created and have content
schema_file = os.path.join(export_file, "schema.cypher")
if not os.path.exists(schema_file) or os.path.getsize(schema_file) == 0:
raise ValueError(f"Schema file not found or empty: {schema_file}")
print(f"Importing into new DB at {new_db}", file=sys.stderr)
run_migration_step(new_py, new_db, f"IMPORT DATABASE '{export_file}'")
print("Import complete.", file=sys.stderr)
# Rename new kuzu database to old kuzu database name if enabled
if overwrite or delete_old:
rename_databases(old_db, old_version, new_db, delete_old)
print("✅ Kuzu graph database migration finished successfully!")
def rename_databases(old_db: str, old_version: str, new_db: str, delete_old: bool):
"""
When overwrite is enabled, back up the original old_db (a file with its .lock and .wal companions, or a directory)
by renaming it to *_old_<version>, and replace it with the newly imported new_db files.
When delete_old is enabled, the old database is deleted instead of being kept as a backup.
"""
base_dir = os.path.dirname(old_db)
name = os.path.basename(old_db.rstrip(os.sep))
# Append _old_ and the source version to the backup database name
backup_database_name = f"{name}_old_" + old_version.replace(".", "_")
backup_base = os.path.join(base_dir, backup_database_name)
if os.path.isfile(old_db):
# File-based database: handle main file and accompanying lock/WAL
for ext in ["", ".lock", ".wal"]:
src = old_db + ext
dst = backup_base + ext
if os.path.exists(src):
if delete_old:
os.remove(src)
else:
os.rename(src, dst)
print(f"Renamed '{src}' to '{dst}'", file=sys.stderr)
elif os.path.isdir(old_db):
# Directory-based Kuzu database
backup_dir = backup_base
if delete_old:
shutil.rmtree(old_db)
else:
os.rename(old_db, backup_dir)
print(f"Renamed directory '{old_db}' to '{backup_dir}'", file=sys.stderr)
else:
print(f"Original database path '{old_db}' not found for renaming.", file=sys.stderr)
sys.exit(1)
# Now move new files into place
for ext in ["", ".lock", ".wal"]:
src_new = new_db + ext
dst_new = os.path.join(base_dir, name + ext)
if os.path.exists(src_new):
os.rename(src_new, dst_new)
print(f"Renamed '{src_new}' to '{dst_new}'", file=sys.stderr)
def main():
p = argparse.ArgumentParser(
description="Migrate Kùzu DB via PyPI versions",
epilog="""
Examples:
%(prog)s --old-version 0.9.0 --new-version 0.11.0 \\
--old-db /path/to/old/db --new-db /path/to/new/db --overwrite
Note: This script will create virtual environments in a temporary .kuzu_envs/ directory
to isolate different Kuzu versions.
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
p.add_argument(
"--old-version",
required=False,
default=None,
help="Source Kuzu version (e.g., 0.9.0). If not provided automatic kuzu version detection will be attempted.",
)
p.add_argument("--new-version", required=True, help="Target Kuzu version (e.g., 0.11.0)")
p.add_argument("--old-db", required=True, help="Path to source database directory")
p.add_argument(
"--new-db",
required=True,
help="Path to target database directory, it can't be the same path as the old database. Use the overwrite flag if you want to replace the old database with the new one.",
)
p.add_argument(
"--overwrite",
required=False,
action="store_true",
default=False,
help="Rename new-db to the old-db name and location, keeps old-db as backup if delete-old is not True",
)
p.add_argument(
"--delete-old",
required=False,
action="store_true",
default=False,
help="When overwrite and delete-old is True old-db will not be stored as backup",
)
args = p.parse_args()
kuzu_migration(
new_db=args.new_db,
old_db=args.old_db,
new_version=args.new_version,
old_version=args.old_version,
overwrite=args.overwrite,
delete_old=args.delete_old,
)
if __name__ == "__main__":
main()
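
For reference, a minimal sketch of driving the migration helper directly from Python, mirroring the automatic fallback added to KuzuAdapter above. The database path is hypothetical, and the bare import assumes kuzu_migrate.py is importable from the working directory:

import kuzu
from kuzu_migrate import read_kuzu_storage_version, kuzu_migration

db_path = "/data/cognee/graph_kuzu"  # hypothetical database location
on_disk_version = read_kuzu_storage_version(db_path)  # e.g. "0.9.0", read from catalog.kz
if on_disk_version != kuzu.__version__:
    kuzu_migration(
        new_db=db_path + "new",  # temporary target, moved into place by overwrite
        old_db=db_path,
        new_version=kuzu.__version__,
        old_version=on_disk_version,
        overwrite=True,  # keeps the old database as <name>_old_<version>
    )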

View file

@ -33,7 +33,7 @@ from .neo4j_metrics_utils import (
from .deadlock_retry import deadlock_retry
logger = get_logger("Neo4jAdapter", level=ERROR)
logger = get_logger("Neo4jAdapter")
BASE_LABEL = "__Node__"
@ -870,34 +870,52 @@ class Neo4jAdapter(GraphDBInterface):
A tuple containing two lists: nodes and edges with their properties.
"""
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
import time
result = await self.query(query)
start_time = time.time()
nodes = [
(
record["properties"]["id"],
record["properties"],
try:
# Retrieve nodes
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
result = await self.query(query)
nodes = []
for record in result:
nodes.append(
(
record["properties"]["id"],
record["properties"],
)
)
# Retrieve edges
query = """
MATCH (n)-[r]->(m)
RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
"""
result = await self.query(query)
edges = []
for record in result:
edges.append(
(
record["properties"]["source_node_id"],
record["properties"]["target_node_id"],
record["type"],
record["properties"],
)
)
retrieval_time = time.time() - start_time
logger.info(
f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
)
for record in result
]
query = """
MATCH (n)-[r]->(m)
RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
"""
result = await self.query(query)
edges = [
(
record["properties"]["source_node_id"],
record["properties"]["target_node_id"],
record["type"],
record["properties"],
)
for record in result
]
return (nodes, edges)
return (nodes, edges)
except Exception as e:
logger.error(f"Error during graph data retrieval: {str(e)}")
raise
async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
@ -918,50 +936,71 @@ class Neo4jAdapter(GraphDBInterface):
- Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]]: A tuple
containing nodes and edges in the requested subgraph.
"""
label = node_type.__name__
import time
query = f"""
UNWIND $names AS wantedName
MATCH (n:`{label}`)
WHERE n.name = wantedName
WITH collect(DISTINCT n) AS primary
UNWIND primary AS p
OPTIONAL MATCH (p)--(nbr)
WITH primary, collect(DISTINCT nbr) AS nbrs
WITH primary + nbrs AS nodelist
UNWIND nodelist AS node
WITH collect(DISTINCT node) AS nodes
MATCH (a)-[r]-(b)
WHERE a IN nodes AND b IN nodes
WITH nodes, collect(DISTINCT r) AS rels
RETURN
[n IN nodes |
{{ id: n.id,
properties: properties(n) }}] AS rawNodes,
[r IN rels |
{{ type: type(r),
properties: properties(r) }}] AS rawRels
"""
start_time = time.time()
result = await self.query(query, {"names": node_name})
if not result:
return [], []
try:
label = node_type.__name__
raw_nodes = result[0]["rawNodes"]
raw_rels = result[0]["rawRels"]
query = f"""
UNWIND $names AS wantedName
MATCH (n:`{label}`)
WHERE n.name = wantedName
WITH collect(DISTINCT n) AS primary
UNWIND primary AS p
OPTIONAL MATCH (p)--(nbr)
WITH primary, collect(DISTINCT nbr) AS nbrs
WITH primary + nbrs AS nodelist
UNWIND nodelist AS node
WITH collect(DISTINCT node) AS nodes
MATCH (a)-[r]-(b)
WHERE a IN nodes AND b IN nodes
WITH nodes, collect(DISTINCT r) AS rels
RETURN
[n IN nodes |
{{ id: n.id,
properties: properties(n) }}] AS rawNodes,
[r IN rels |
{{ type: type(r),
properties: properties(r) }}] AS rawRels
"""
nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes]
edges = [
(
r["properties"]["source_node_id"],
r["properties"]["target_node_id"],
r["type"],
r["properties"],
result = await self.query(query, {"names": node_name})
if not result:
return [], []
raw_nodes = result[0]["rawNodes"]
raw_rels = result[0]["rawRels"]
# Process nodes
nodes = []
for n in raw_nodes:
nodes.append((n["properties"]["id"], n["properties"]))
# Process edges
edges = []
for r in raw_rels:
edges.append(
(
r["properties"]["source_node_id"],
r["properties"]["target_node_id"],
r["type"],
r["properties"],
)
)
retrieval_time = time.time() - start_time
logger.info(
f"Retrieved {len(nodes)} nodes and {len(edges)} edges for {node_type.__name__} in {retrieval_time:.2f} seconds"
)
for r in raw_rels
]
return nodes, edges
return nodes, edges
except Exception as e:
logger.error(f"Error during nodeset subgraph retrieval: {str(e)}")
raise
async def get_filtered_graph_data(self, attribute_filters):
"""
@ -1011,8 +1050,8 @@ class Neo4jAdapter(GraphDBInterface):
edges = [
(
record["source"],
record["target"],
record["properties"]["source_node_id"],
record["properties"]["target_node_id"],
record["type"],
record["properties"],
)
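
For context, a minimal usage sketch of the reworked subgraph retrieval above. The adapter instance and the Entity node type are illustrative names, not part of this diff:

# Nodes come back as (id, properties) tuples; edges as
# (source_node_id, target_node_id, relationship_type, properties) tuples.
nodes, edges = await adapter.get_nodeset_subgraph(node_type=Entity, node_name=["Alice", "Bob"])
logger.info(f"Subgraph contains {len(nodes)} nodes and {len(edges)} edges")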

View file

@ -18,11 +18,8 @@ class UnstructuredDocument(Document):
except ModuleNotFoundError:
raise UnstructuredLibraryImportError
if self.raw_data_location.startswith("s3://"):
async with open_data_file(self.raw_data_location, mode="rb") as f:
elements = partition(file=f, content_type=self.mime_type)
else:
elements = partition(self.raw_data_location, content_type=self.mime_type)
async with open_data_file(self.raw_data_location, mode="rb") as f:
elements = partition(file=f, content_type=self.mime_type)
in_memory_file = StringIO("\n\n".join([str(el) for el in elements]))
in_memory_file.seek(0)

View file

@ -8,7 +8,7 @@ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
import heapq
logger = get_logger()
logger = get_logger("CogneeGraph")
class CogneeGraph(CogneeAbstractGraph):
@ -66,7 +66,13 @@ class CogneeGraph(CogneeAbstractGraph):
) -> None:
if node_dimension < 1 or edge_dimension < 1:
raise InvalidValueError(message="Dimensions must be positive integers")
try:
import time
start_time = time.time()
# Determine projection strategy
if node_type is not None and node_name is not None:
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
node_type=node_type, node_name=node_name
@ -83,16 +89,17 @@ class CogneeGraph(CogneeAbstractGraph):
nodes_data, edges_data = await adapter.get_filtered_graph_data(
attribute_filters=memory_fragment_filter
)
if not nodes_data or not edges_data:
raise EntityNotFoundError(
message="Empty filtered graph projected from the database."
)
# Process nodes
for node_id, properties in nodes_data:
node_attributes = {key: properties.get(key) for key in node_properties_to_project}
self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))
# Process edges
for source_id, target_id, relationship_type, properties in edges_data:
source_node = self.get_node(str(source_id))
target_node = self.get_node(str(target_id))
@ -113,17 +120,23 @@ class CogneeGraph(CogneeAbstractGraph):
source_node.add_skeleton_edge(edge)
target_node.add_skeleton_edge(edge)
else:
raise EntityNotFoundError(
message=f"Edge references nonexistent nodes: {source_id} -> {target_id}"
)
except (ValueError, TypeError) as e:
print(f"Error projecting graph: {e}")
raise e
# Final statistics
projection_time = time.time() - start_time
logger.info(
f"Graph projection completed: {len(self.nodes)} nodes, {len(self.edges)} edges in {projection_time:.2f}s"
)
except Exception as e:
logger.error(f"Error during graph projection: {str(e)}")
raise
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
mapped_nodes = 0
for category, scored_results in node_distances.items():
for scored_result in scored_results:
node_id = str(scored_result.id)
@ -131,6 +144,7 @@ class CogneeGraph(CogneeAbstractGraph):
node = self.get_node(node_id)
if node:
node.add_attribute("vector_distance", score)
mapped_nodes += 1
async def map_vector_distances_to_graph_edges(
self, vector_engine, query_vector, edge_distances
@ -150,18 +164,16 @@ class CogneeGraph(CogneeAbstractGraph):
for edge in self.edges:
relationship_type = edge.attributes.get("relationship_type")
if not relationship_type or relationship_type not in embedding_map:
print(f"Edge {edge} has an unknown or missing relationship type.")
continue
edge.attributes["vector_distance"] = embedding_map[relationship_type]
if relationship_type and relationship_type in embedding_map:
edge.attributes["vector_distance"] = embedding_map[relationship_type]
except Exception as ex:
print(f"Error mapping vector distances to edges: {ex}")
logger.error(f"Error mapping vector distances to edges: {str(ex)}")
raise ex
async def calculate_top_triplet_importances(self, k: int) -> List:
min_heap = []
for i, edge in enumerate(self.edges):
source_node = self.get_node(edge.node1.id)
target_node = self.get_node(edge.node2.id)

View file

@ -33,7 +33,7 @@ async def get_formatted_graph_data(dataset_id: UUID, user_id: UUID):
lambda edge: {
"source": str(edge[0]),
"target": str(edge[1]),
"label": edge[2],
"label": str(edge[2]),
},
edges,
)

View file

@ -1,10 +1,13 @@
from typing import Any, Optional
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
logger = get_logger("ChunksRetriever")
class ChunksRetriever(BaseRetriever):
"""
@ -41,14 +44,22 @@ class ChunksRetriever(BaseRetriever):
- Any: A list of document chunk payloads retrieved from the search.
"""
logger.info(
f"Starting chunk retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
vector_engine = get_vector_engine()
try:
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
logger.info(f"Found {len(found_chunks)} chunks from vector search")
except CollectionNotFoundError as error:
logger.error("DocumentChunk_text collection not found in vector database")
raise NoDataError("No data found in the system, please add data first.") from error
return [result.payload for result in found_chunks]
chunk_payloads = [result.payload for result in found_chunks]
logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
return chunk_payloads
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""
@ -70,6 +81,17 @@ class ChunksRetriever(BaseRetriever):
- Any: The context used for the completion or the retrieved context if none was
provided.
"""
logger.info(
f"Starting completion generation for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
if context is None:
logger.debug("No context provided, retrieving context from vector database")
context = await self.get_context(query)
else:
logger.debug("Using provided context")
logger.info(
f"Returning context with {len(context) if isinstance(context, list) else 1} item(s)"
)
return context
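
For orientation, a minimal sketch of exercising this retriever end to end. The constructor argument is an assumption (it is not shown in this diff); the DocumentChunk_text collection and the get_context/get_completion methods are taken from the code above:

retriever = ChunksRetriever(top_k=5)  # hypothetical constructor; top_k bounds the vector search
chunk_payloads = await retriever.get_context("What does the Kuzu adapter do?")
context = await retriever.get_completion("What does the Kuzu adapter do?")  # falls back to get_context when no context is passed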

View file

@ -3,6 +3,7 @@ import asyncio
import aiofiles
from pydantic import BaseModel
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
@ -13,6 +14,8 @@ from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.l
read_query_prompt,
)
logger = get_logger("CodeRetriever")
class CodeRetriever(BaseRetriever):
"""Retriever for handling code-based searches."""
@ -39,26 +42,43 @@ class CodeRetriever(BaseRetriever):
async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
"""Process the query using LLM to extract file names and source code parts."""
logger.debug(
f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
system_prompt = read_query_prompt("codegraph_retriever_system.txt")
llm_client = get_llm_client()
try:
return await llm_client.acreate_structured_output(
result = await llm_client.acreate_structured_output(
text_input=query,
system_prompt=system_prompt,
response_model=self.CodeQueryInfo,
)
logger.info(
f"LLM extracted {len(result.filenames)} filenames and {len(result.sourcecode)} chars of source code"
)
return result
except Exception as e:
logger.error(f"Failed to retrieve structured output from LLM: {str(e)}")
raise RuntimeError("Failed to retrieve structured output from LLM") from e
async def get_context(self, query: str) -> Any:
"""Find relevant code files based on the query."""
logger.info(
f"Starting code retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
if not query or not isinstance(query, str):
logger.error("Invalid query: must be a non-empty string")
raise ValueError("The query must be a non-empty string.")
try:
vector_engine = get_vector_engine()
graph_engine = await get_graph_engine()
logger.debug("Successfully initialized vector and graph engines")
except Exception as e:
logger.error(f"Database initialization error: {str(e)}")
raise RuntimeError("Database initialization error in code_graph_retriever, ") from e
files_and_codeparts = await self._process_query(query)
@ -67,52 +87,80 @@ class CodeRetriever(BaseRetriever):
similar_codepieces = []
if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
logger.info("No specific files/code extracted from query, performing general search")
for collection in self.file_name_collections:
logger.debug(f"Searching {collection} collection with general query")
search_results_file = await vector_engine.search(
collection, query, limit=self.top_k
)
logger.debug(f"Found {len(search_results_file)} results in {collection}")
for res in search_results_file:
similar_filenames.append(
{"id": res.id, "score": res.score, "payload": res.payload}
)
for collection in self.classes_and_functions_collections:
logger.debug(f"Searching {collection} collection with general query")
search_results_code = await vector_engine.search(
collection, query, limit=self.top_k
)
logger.debug(f"Found {len(search_results_code)} results in {collection}")
for res in search_results_code:
similar_codepieces.append(
{"id": res.id, "score": res.score, "payload": res.payload}
)
else:
logger.info(
f"Using extracted filenames ({len(files_and_codeparts.filenames)}) and source code for targeted search"
)
for collection in self.file_name_collections:
for file_from_query in files_and_codeparts.filenames:
logger.debug(f"Searching {collection} for specific file: {file_from_query}")
search_results_file = await vector_engine.search(
collection, file_from_query, limit=self.top_k
)
logger.debug(
f"Found {len(search_results_file)} results for file {file_from_query}"
)
for res in search_results_file:
similar_filenames.append(
{"id": res.id, "score": res.score, "payload": res.payload}
)
for collection in self.classes_and_functions_collections:
logger.debug(f"Searching {collection} with extracted source code")
search_results_code = await vector_engine.search(
collection, files_and_codeparts.sourcecode, limit=self.top_k
)
logger.debug(f"Found {len(search_results_code)} results for source code search")
for res in search_results_code:
similar_codepieces.append(
{"id": res.id, "score": res.score, "payload": res.payload}
)
total_items = len(similar_filenames) + len(similar_codepieces)
logger.info(
f"Total search results: {total_items} items ({len(similar_filenames)} filenames, {len(similar_codepieces)} code pieces)"
)
if total_items == 0:
logger.warning("No search results found, returning empty list")
return []
logger.debug("Getting graph connections for all search results")
relevant_triplets = await asyncio.gather(
*[
graph_engine.get_connections(similar_piece["id"])
for similar_piece in similar_filenames + similar_codepieces
]
)
logger.info(f"Retrieved graph connections for {len(relevant_triplets)} items")
paths = set()
for sublist in relevant_triplets:
for i, sublist in enumerate(relevant_triplets):
logger.debug(f"Processing connections for item {i}: {len(sublist)} connections")
for tpl in sublist:
if isinstance(tpl, tuple) and len(tpl) >= 3:
if "file_path" in tpl[0]:
@ -120,23 +168,31 @@ class CodeRetriever(BaseRetriever):
if "file_path" in tpl[2]:
paths.add(tpl[2]["file_path"])
logger.info(f"Found {len(paths)} unique file paths to read")
retrieved_files = {}
read_tasks = []
for file_path in paths:
async def read_file(fp):
try:
logger.debug(f"Reading file: {fp}")
async with aiofiles.open(fp, "r", encoding="utf-8") as f:
retrieved_files[fp] = await f.read()
content = await f.read()
retrieved_files[fp] = content
logger.debug(f"Successfully read {len(content)} characters from {fp}")
except Exception as e:
print(f"Error reading {fp}: {e}")
logger.error(f"Error reading {fp}: {e}")
retrieved_files[fp] = ""
read_tasks.append(read_file(file_path))
await asyncio.gather(*read_tasks)
logger.info(
f"Successfully read {len([f for f in retrieved_files.values() if f])} files (out of {len(paths)} total)"
)
return [
result = [
{
"name": file_path,
"description": file_path,
@ -145,6 +201,9 @@ class CodeRetriever(BaseRetriever):
for file_path in paths
]
logger.info(f"Returning {len(result)} code file contexts")
return result
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""Returns the code files context."""
if context is None:

View file

@ -1,11 +1,14 @@
from typing import Any, Optional
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
logger = get_logger("CompletionRetriever")
class CompletionRetriever(BaseRetriever):
"""
@ -56,8 +59,10 @@ class CompletionRetriever(BaseRetriever):
# Combine the text of all chunks returned from the vector search (the number of chunks is determined by top_k)
chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
return "\n".join(chunks_payload)
combined_context = "\n".join(chunks_payload)
return combined_context
except CollectionNotFoundError as error:
logger.error("DocumentChunk_text collection not found")
raise NoDataError("No data found in the system, please add data first.") from error
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
@ -70,22 +75,19 @@ class CompletionRetriever(BaseRetriever):
Parameters:
-----------
- query (str): The input query for which the completion is generated.
- context (Optional[Any]): Optional context to use for generating the completion; if
not provided, it will be retrieved using get_context. (default None)
- query (str): The query string to be used for generating a completion.
- context (Optional[Any]): Optional pre-fetched context to use for generating the
completion; if None, it retrieves the context for the query. (default None)
Returns:
--------
- Any: A list containing the generated completion from the LLM.
- Any: The generated completion based on the provided query and context.
"""
if context is None:
context = await self.get_context(query)
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
query, context, self.user_prompt_path, self.system_prompt_path
)
return [completion]
return completion

View file

@ -10,7 +10,7 @@ from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.shared.logging_utils import get_logger
logger = get_logger()
logger = get_logger("GraphCompletionRetriever")
class GraphCompletionRetriever(BaseRetriever):

View file

@ -1,12 +1,15 @@
import asyncio
from typing import Any, Optional
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
logger = get_logger("InsightsRetriever")
class InsightsRetriever(BaseRetriever):
"""
@ -63,6 +66,7 @@ class InsightsRetriever(BaseRetriever):
vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
)
except CollectionNotFoundError as error:
logger.error("Entity collections not found")
raise NoDataError("No data found in the system, please add data first.") from error
results = [*results[0], *results[1]]

View file

@ -1,5 +1,5 @@
from typing import Any, Optional
import logging
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (
@ -12,7 +12,7 @@ from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions import SearchTypeNotSupported
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
logger = logging.getLogger("NaturalLanguageRetriever")
logger = get_logger("NaturalLanguageRetriever")
class NaturalLanguageRetriever(BaseRetriever):
@ -127,16 +127,12 @@ class NaturalLanguageRetriever(BaseRetriever):
- Optional[Any]: Returns the context retrieved from the graph database based on the
query.
"""
try:
graph_engine = await get_graph_engine()
graph_engine = await get_graph_engine()
if isinstance(graph_engine, (NetworkXAdapter)):
raise SearchTypeNotSupported("Natural language search type not supported.")
if isinstance(graph_engine, (NetworkXAdapter)):
raise SearchTypeNotSupported("Natural language search type not supported.")
return await self._execute_cypher_query(query, graph_engine)
except Exception as e:
logger.error("Failed to execute natural language search retrieval: %s", str(e))
raise e
return await self._execute_cypher_query(query, graph_engine)
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""

View file

@ -1,10 +1,13 @@
from typing import Any, Optional
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
logger = get_logger("SummariesRetriever")
class SummariesRetriever(BaseRetriever):
"""
@ -40,16 +43,24 @@ class SummariesRetriever(BaseRetriever):
- Any: A list of payloads from the retrieved summaries.
"""
logger.info(
f"Starting summary retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
vector_engine = get_vector_engine()
try:
summaries_results = await vector_engine.search(
"TextSummary_text", query, limit=self.top_k
)
logger.info(f"Found {len(summaries_results)} summaries from vector search")
except CollectionNotFoundError as error:
logger.error("TextSummary_text collection not found in vector database")
raise NoDataError("No data found in the system, please add data first.") from error
return [summary.payload for summary in summaries_results]
summary_payloads = [summary.payload for summary in summaries_results]
logger.info(f"Returning {len(summary_payloads)} summary payloads")
return summary_payloads
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""
@ -70,6 +81,17 @@ class SummariesRetriever(BaseRetriever):
- Any: The generated completion context, which is either provided or retrieved.
"""
logger.info(
f"Starting completion generation for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
)
if context is None:
logger.debug("No context provided, retrieving context from vector database")
context = await self.get_context(query)
else:
logger.debug("Using provided context")
logger.info(
f"Returning context with {len(context) if isinstance(context, list) else 1} item(s)"
)
return context

View file

@ -59,13 +59,13 @@ async def get_memory_fragment(
node_name: Optional[List[str]] = None,
) -> CogneeGraph:
"""Creates and initializes a CogneeGraph memory fragment with optional property projections."""
graph_engine = await get_graph_engine()
memory_fragment = CogneeGraph()
if properties_to_project is None:
properties_to_project = ["id", "description", "name", "type", "text"]
try:
graph_engine = await get_graph_engine()
memory_fragment = CogneeGraph()
await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=properties_to_project,
@ -73,7 +73,13 @@ async def get_memory_fragment(
node_type=node_type,
node_name=node_name,
)
except EntityNotFoundError:
# This is expected behavior - continue with empty fragment
pass
except Exception as e:
logger.error(f"Error during memory fragment creation: {str(e)}")
# Still return the fragment even if projection failed
pass
return memory_fragment