Merge branch 'dev' into COG-2082
commit e03b2ea709
15 changed files with 581 additions and 119 deletions
@@ -9,7 +9,7 @@ COPY package.json package-lock.json ./

# Install any needed packages specified in package.json
RUN npm ci
-# RUN npm rebuild lightningcss
+RUN npm rebuild lightningcss

# Copy the rest of the application code to the working directory
COPY src ./src

@@ -72,11 +72,38 @@ class KuzuAdapter(GraphDBInterface):

        run_sync(file_storage.ensure_directory_exists())

-        self.db = Database(
-            self.db_path,
-            buffer_pool_size=256 * 1024 * 1024,  # 256MB buffer pool
-            max_db_size=1024 * 1024 * 1024,
-        )
+        try:
+            self.db = Database(
+                self.db_path,
+                buffer_pool_size=2048 * 1024 * 1024,  # 2048MB buffer pool
+                max_db_size=4096 * 1024 * 1024,
+            )
+        except RuntimeError:
+            from .kuzu_migrate import read_kuzu_storage_version
+            import kuzu
+
+            kuzu_db_version = read_kuzu_storage_version(self.db_path)
+            if (
+                kuzu_db_version == "0.9.0" or kuzu_db_version == "0.8.2"
+            ) and kuzu_db_version != kuzu.__version__:
+                # TODO: Write migration script that will handle all user graph databases in multi-user mode
+                # Try to migrate kuzu database to latest version
+                from .kuzu_migrate import kuzu_migration
+
+                kuzu_migration(
+                    new_db=self.db_path + "new",
+                    old_db=self.db_path,
+                    new_version=kuzu.__version__,
+                    old_version=kuzu_db_version,
+                    overwrite=True,
+                )
+
+                self.db = Database(
+                    self.db_path,
+                    buffer_pool_size=2048 * 1024 * 1024,  # 2048MB buffer pool
+                    max_db_size=4096 * 1024 * 1024,
+                )

        self.db.init_database()
        self.connection = Connection(self.db)

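For orientation, a minimal sketch (not part of this commit's diff; the database path is illustrative) of the version probe that the RuntimeError fallback above relies on: the storage code read from catalog.kz is mapped to a release string by the new kuzu_migrate module added below.

import kuzu
from cognee.infrastructure.databases.graph.kuzu.kuzu_migrate import read_kuzu_storage_version

db_path = "/data/cognee/databases/graph_db"  # illustrative path, not from the diff
stored_version = read_kuzu_storage_version(db_path)  # e.g. "0.9.0" for storage code 37
needs_migration = stored_version in ("0.8.2", "0.9.0") and stored_version != kuzu.__version__
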
276	cognee/infrastructure/databases/graph/kuzu/kuzu_migrate.py	Normal file

@@ -0,0 +1,276 @@
#!/usr/bin/env python3
"""
Kuzu Database Migration Script

This script migrates Kuzu databases between different versions by:
1. Setting up isolated Python environments for each Kuzu version
2. Exporting data from the source database using the old version
3. Importing data into the target database using the new version
4. If overwrite is enabled, the target database replaces the source database and the source database is kept with an _old suffix
5. If delete-old is enabled, the target database is renamed to the source database and the source database is deleted

The script automatically handles:
- Environment setup (creates virtual environments as needed)
- Export/import validation
- Error handling and reporting

Usage Examples:
    # Basic migration from 0.9.0 to 0.11.0
    python kuzu_migrate.py --old-version 0.9.0 --new-version 0.11.0 --old-db /path/to/old/database --new-db /path/to/new/database

Requirements:
    - Python 3.7+
    - Internet connection (to download Kuzu packages)
    - Sufficient disk space for virtual environments and temporary exports

Notes:
    - Can only be used to migrate to newer Kuzu versions (targets 0.11.0 onwards)
"""

import tempfile
import sys
import struct
import shutil
import subprocess
import argparse
import os


kuzu_version_mapping = {
    34: "0.7.0",
    35: "0.7.1",
    36: "0.8.2",
    37: "0.9.0",
    38: "0.10.1",
    39: "0.11.0",
}


def read_kuzu_storage_version(kuzu_db_path: str) -> str:
    """
    Reads the Kùzu storage version code from the first bytes of the catalog.kz file.

    :param kuzu_db_path: Path to the Kuzu database file/directory.
    :return: Kuzu version string mapped from the storage version code.
    """
    if os.path.isdir(kuzu_db_path):
        kuzu_version_file_path = os.path.join(kuzu_db_path, "catalog.kz")
        if not os.path.isfile(kuzu_version_file_path):
            raise FileNotFoundError("Kuzu catalog.kz file does not exist")
    else:
        kuzu_version_file_path = kuzu_db_path

    with open(kuzu_version_file_path, "rb") as f:
        # Skip the 3-byte magic "KUZ" and one byte of padding
        f.seek(4)
        # Read the next 8 bytes as a little-endian unsigned 64-bit integer
        data = f.read(8)
        if len(data) < 8:
            raise ValueError(
                f"File '{kuzu_version_file_path}' does not contain a storage version code."
            )
        version_code = struct.unpack("<Q", data)[0]

    if kuzu_version_mapping.get(version_code):
        return kuzu_version_mapping[version_code]
    else:
        raise ValueError("Could not map version_code to proper Kuzu version.")


def ensure_env(version: str, export_dir) -> str:
    """
    Create (if needed) a venv at .kuzu_envs/{version} and install kuzu=={version}.
    Returns the path to the venv's python executable.
    """
    # Use the temporary export directory to host the venvs
    kuzu_envs_dir = os.path.join(export_dir, ".kuzu_envs")

    # venv base under the temporary export directory
    base = os.path.join(kuzu_envs_dir, version)
    py_bin = os.path.join(base, "bin", "python")
    # If the environment already exists, remove it so it is recreated cleanly
    if os.path.isfile(py_bin):
        shutil.rmtree(base)

    print(f"→ Setting up venv for Kùzu {version}...", file=sys.stderr)
    # Create venv
    subprocess.run([sys.executable, "-m", "venv", base], check=True)
    # Install the specific Kùzu version
    subprocess.run([py_bin, "-m", "pip", "install", "--upgrade", "pip"], check=True)
    subprocess.run([py_bin, "-m", "pip", "install", f"kuzu=={version}"], check=True)
    return py_bin


def run_migration_step(python_exe: str, db_path: str, cypher: str):
    """
    Uses the given python_exe to execute a short snippet that
    connects to the Kùzu database and runs a Cypher command.
    """
    snippet = f"""
import kuzu
db = kuzu.Database(r"{db_path}")
conn = kuzu.Connection(db)
conn.execute(r\"\"\"{cypher}\"\"\")
"""
    proc = subprocess.run([python_exe, "-c", snippet], capture_output=True, text=True)
    if proc.returncode != 0:
        print(f"[ERROR] {cypher} failed:\n{proc.stderr}", file=sys.stderr)
        sys.exit(proc.returncode)


def kuzu_migration(new_db, old_db, new_version, old_version=None, overwrite=None, delete_old=None):
    """
    Main migration function that handles the complete migration process.
    """
    print(f"🔄 Migrating Kuzu database from {old_version} to {new_version}", file=sys.stderr)
    print(f"📂 Source: {old_db}", file=sys.stderr)
    print("", file=sys.stderr)

    # If the version of the old Kuzu DB is not provided, try to determine it from the file header
    if not old_version:
        old_version = read_kuzu_storage_version(old_db)

    # Check if old database exists
    if not os.path.exists(old_db):
        print(f"Source database '{old_db}' does not exist.", file=sys.stderr)
        sys.exit(1)

    # Prepare target: ensure the parent directory exists and refuse to overwrite an existing target
    parent_dir = os.path.dirname(new_db)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    if os.path.exists(new_db):
        raise FileExistsError(
            "File already exists at new database location, remove file or change new database file path to continue"
        )

    # Use a temp directory for all processing; it is cleaned up when the with block exits
    with tempfile.TemporaryDirectory() as export_dir:
        # Set up environments
        print(f"Setting up Kuzu {old_version} environment...", file=sys.stderr)
        old_py = ensure_env(old_version, export_dir)
        print(f"Setting up Kuzu {new_version} environment...", file=sys.stderr)
        new_py = ensure_env(new_version, export_dir)

        export_file = os.path.join(export_dir, "kuzu_export")
        print(f"Exporting old DB → {export_dir}", file=sys.stderr)
        run_migration_step(old_py, old_db, f"EXPORT DATABASE '{export_file}'")
        print("Export complete.", file=sys.stderr)

        # Check if export files were created and have content
        schema_file = os.path.join(export_file, "schema.cypher")
        if not os.path.exists(schema_file) or os.path.getsize(schema_file) == 0:
            raise ValueError(f"Schema file not found: {schema_file}")

        print(f"Importing into new DB at {new_db}", file=sys.stderr)
        run_migration_step(new_py, new_db, f"IMPORT DATABASE '{export_file}'")
        print("Import complete.", file=sys.stderr)

        # Rename new kuzu database to old kuzu database name if enabled
        if overwrite or delete_old:
            rename_databases(old_db, old_version, new_db, delete_old)

    print("✅ Kuzu graph database migration finished successfully!")


def rename_databases(old_db: str, old_version: str, new_db: str, delete_old: bool):
    """
    When overwrite is enabled, back up the original old_db (file with .lock and .wal, or directory)
    by renaming it with an *_old suffix, and replace it with the newly imported new_db files.

    When delete_old is enabled, replace the old database with the new one and delete the old database.
    """
    base_dir = os.path.dirname(old_db)
    name = os.path.basename(old_db.rstrip(os.sep))
    # Add _old_ and version info to the backup graph database name
    backup_database_name = f"{name}_old_" + old_version.replace(".", "_")
    backup_base = os.path.join(base_dir, backup_database_name)

    if os.path.isfile(old_db):
        # File-based database: handle main file and accompanying lock/WAL
        for ext in ["", ".lock", ".wal"]:
            src = old_db + ext
            dst = backup_base + ext
            if os.path.exists(src):
                if delete_old:
                    os.remove(src)
                else:
                    os.rename(src, dst)
                    print(f"Renamed '{src}' to '{dst}'", file=sys.stderr)
    elif os.path.isdir(old_db):
        # Directory-based Kuzu database
        backup_dir = backup_base
        if delete_old:
            shutil.rmtree(old_db)
        else:
            os.rename(old_db, backup_dir)
            print(f"Renamed directory '{old_db}' to '{backup_dir}'", file=sys.stderr)
    else:
        print(f"Original database path '{old_db}' not found for renaming.", file=sys.stderr)
        sys.exit(1)

    # Now move new files into place
    for ext in ["", ".lock", ".wal"]:
        src_new = new_db + ext
        dst_new = os.path.join(base_dir, name + ext)
        if os.path.exists(src_new):
            os.rename(src_new, dst_new)
            print(f"Renamed '{src_new}' to '{dst_new}'", file=sys.stderr)


def main():
    p = argparse.ArgumentParser(
        description="Migrate Kùzu DB via PyPI versions",
        epilog="""
Examples:
  %(prog)s --old-version 0.9.0 --new-version 0.11.0 \\
           --old-db /path/to/old/db --new-db /path/to/new/db --overwrite

Note: This script will create virtual environments in the .kuzu_envs/ directory
to isolate different Kuzu versions.
        """,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument(
        "--old-version",
        required=False,
        default=None,
        help="Source Kuzu version (e.g., 0.9.0). If not provided, automatic Kuzu version detection will be attempted.",
    )
    p.add_argument("--new-version", required=True, help="Target Kuzu version (e.g., 0.11.0)")
    p.add_argument("--old-db", required=True, help="Path to source database directory")
    p.add_argument(
        "--new-db",
        required=True,
        help="Path to target database directory; it can't be the same path as the old database. Use the overwrite flag if you want to replace the old database with the new one.",
    )
    p.add_argument(
        "--overwrite",
        required=False,
        action="store_true",
        default=False,
        help="Rename new-db to the old-db name and location; keeps old-db as a backup unless delete-old is set",
    )
    p.add_argument(
        "--delete-old",
        required=False,
        action="store_true",
        default=False,
        help="When overwrite and delete-old are set, old-db will not be stored as a backup",
    )

    args = p.parse_args()

    kuzu_migration(
        new_db=args.new_db,
        old_db=args.old_db,
        new_version=args.new_version,
        old_version=args.old_version,
        overwrite=args.overwrite,
        delete_old=args.delete_old,
    )


if __name__ == "__main__":
    main()

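For reference, a minimal sketch of driving the migration programmatically, mirroring the call the KuzuAdapter fallback makes above (paths and versions are illustrative):

from cognee.infrastructure.databases.graph.kuzu.kuzu_migrate import kuzu_migration

kuzu_migration(
    new_db="/data/cognee/databases/graph_db" + "new",
    old_db="/data/cognee/databases/graph_db",
    new_version="0.11.0",
    old_version="0.9.0",
    overwrite=True,  # move the migrated DB into place; the original is kept as graph_db_old_0_9_0
)
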
@@ -33,7 +33,7 @@ from .neo4j_metrics_utils import (
from .deadlock_retry import deadlock_retry


-logger = get_logger("Neo4jAdapter", level=ERROR)
+logger = get_logger("Neo4jAdapter")

BASE_LABEL = "__Node__"

@@ -870,34 +870,52 @@ class Neo4jAdapter(GraphDBInterface):

            A tuple containing two lists: nodes and edges with their properties.
        """
-        query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
-
-        result = await self.query(query)
-
-        nodes = [
-            (
-                record["properties"]["id"],
-                record["properties"],
-            )
-            for record in result
-        ]
-
-        query = """
-        MATCH (n)-[r]->(m)
-        RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
-        """
-        result = await self.query(query)
-        edges = [
-            (
-                record["properties"]["source_node_id"],
-                record["properties"]["target_node_id"],
-                record["type"],
-                record["properties"],
-            )
-            for record in result
-        ]
-        return (nodes, edges)
+        import time
+
+        start_time = time.time()
+
+        try:
+            # Retrieve nodes
+            query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
+            result = await self.query(query)
+
+            nodes = []
+            for record in result:
+                nodes.append(
+                    (
+                        record["properties"]["id"],
+                        record["properties"],
+                    )
+                )
+
+            # Retrieve edges
+            query = """
+            MATCH (n)-[r]->(m)
+            RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
+            """
+            result = await self.query(query)
+
+            edges = []
+            for record in result:
+                edges.append(
+                    (
+                        record["properties"]["source_node_id"],
+                        record["properties"]["target_node_id"],
+                        record["type"],
+                        record["properties"],
+                    )
+                )
+
+            retrieval_time = time.time() - start_time
+            logger.info(
+                f"Retrieved {len(nodes)} nodes and {len(edges)} edges in {retrieval_time:.2f} seconds"
+            )
+
+            return (nodes, edges)
+        except Exception as e:
+            logger.error(f"Error during graph data retrieval: {str(e)}")
+            raise

    async def get_nodeset_subgraph(
        self, node_type: Type[Any], node_name: List[str]

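For readability, a rough sketch (values invented, not from the diff) of the node and edge tuples that get_graph_data now assembles and returns:

# nodes: (node_id, properties) pairs taken from each record's properties
nodes = [
    ("node-1", {"id": "node-1", "name": "Example", "text": "..."}),
]
# edges: (source_node_id, target_node_id, relationship_type, properties)
edges = [
    ("node-1", "node-2", "relates_to", {"source_node_id": "node-1", "target_node_id": "node-2"}),
]
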
@@ -918,50 +936,71 @@ class Neo4jAdapter(GraphDBInterface):

        - Tuple[List[Tuple[int, dict]], List[Tuple[int, int, str, dict]]}: A tuple
          containing nodes and edges in the requested subgraph.
        """
-        label = node_type.__name__
-
-        query = f"""
-        UNWIND $names AS wantedName
-        MATCH (n:`{label}`)
-        WHERE n.name = wantedName
-        WITH collect(DISTINCT n) AS primary
-        UNWIND primary AS p
-        OPTIONAL MATCH (p)--(nbr)
-        WITH primary, collect(DISTINCT nbr) AS nbrs
-        WITH primary + nbrs AS nodelist
-        UNWIND nodelist AS node
-        WITH collect(DISTINCT node) AS nodes
-        MATCH (a)-[r]-(b)
-        WHERE a IN nodes AND b IN nodes
-        WITH nodes, collect(DISTINCT r) AS rels
-        RETURN
-          [n IN nodes |
-             {{ id: n.id,
-                properties: properties(n) }}] AS rawNodes,
-          [r IN rels |
-             {{ type: type(r),
-                properties: properties(r) }}] AS rawRels
-        """
-
-        result = await self.query(query, {"names": node_name})
-        if not result:
-            return [], []
-
-        raw_nodes = result[0]["rawNodes"]
-        raw_rels = result[0]["rawRels"]
-
-        nodes = [(n["properties"]["id"], n["properties"]) for n in raw_nodes]
-        edges = [
-            (
-                r["properties"]["source_node_id"],
-                r["properties"]["target_node_id"],
-                r["type"],
-                r["properties"],
-            )
-            for r in raw_rels
-        ]
-
-        return nodes, edges
+        import time
+
+        start_time = time.time()
+
+        try:
+            label = node_type.__name__
+
+            query = f"""
+            UNWIND $names AS wantedName
+            MATCH (n:`{label}`)
+            WHERE n.name = wantedName
+            WITH collect(DISTINCT n) AS primary
+            UNWIND primary AS p
+            OPTIONAL MATCH (p)--(nbr)
+            WITH primary, collect(DISTINCT nbr) AS nbrs
+            WITH primary + nbrs AS nodelist
+            UNWIND nodelist AS node
+            WITH collect(DISTINCT node) AS nodes
+            MATCH (a)-[r]-(b)
+            WHERE a IN nodes AND b IN nodes
+            WITH nodes, collect(DISTINCT r) AS rels
+            RETURN
+              [n IN nodes |
+                 {{ id: n.id,
+                    properties: properties(n) }}] AS rawNodes,
+              [r IN rels |
+                 {{ type: type(r),
+                    properties: properties(r) }}] AS rawRels
+            """
+
+            result = await self.query(query, {"names": node_name})
+
+            if not result:
+                return [], []
+
+            raw_nodes = result[0]["rawNodes"]
+            raw_rels = result[0]["rawRels"]
+
+            # Process nodes
+            nodes = []
+            for n in raw_nodes:
+                nodes.append((n["properties"]["id"], n["properties"]))
+
+            # Process edges
+            edges = []
+            for r in raw_rels:
+                edges.append(
+                    (
+                        r["properties"]["source_node_id"],
+                        r["properties"]["target_node_id"],
+                        r["type"],
+                        r["properties"],
+                    )
+                )
+
+            retrieval_time = time.time() - start_time
+            logger.info(
+                f"Retrieved {len(nodes)} nodes and {len(edges)} edges for {node_type.__name__} in {retrieval_time:.2f} seconds"
+            )
+
+            return nodes, edges
+
+        except Exception as e:
+            logger.error(f"Error during nodeset subgraph retrieval: {str(e)}")
+            raise

    async def get_filtered_graph_data(self, attribute_filters):
        """

@@ -1011,8 +1050,8 @@ class Neo4jAdapter(GraphDBInterface):

        edges = [
            (
-                record["source"],
-                record["target"],
+                record["properties"]["source_node_id"],
+                record["properties"]["target_node_id"],
                record["type"],
                record["properties"],
            )

@@ -18,11 +18,8 @@ class UnstructuredDocument(Document):
        except ModuleNotFoundError:
            raise UnstructuredLibraryImportError

-        if self.raw_data_location.startswith("s3://"):
-            async with open_data_file(self.raw_data_location, mode="rb") as f:
-                elements = partition(file=f, content_type=self.mime_type)
-        else:
-            elements = partition(self.raw_data_location, content_type=self.mime_type)
+        async with open_data_file(self.raw_data_location, mode="rb") as f:
+            elements = partition(file=f, content_type=self.mime_type)

        in_memory_file = StringIO("\n\n".join([str(el) for el in elements]))
        in_memory_file.seek(0)

@@ -8,7 +8,7 @@ from cognee.modules.graph.cognee_graph.CogneeGraphElements import Node, Edge
from cognee.modules.graph.cognee_graph.CogneeAbstractGraph import CogneeAbstractGraph
import heapq

-logger = get_logger()
+logger = get_logger("CogneeGraph")


class CogneeGraph(CogneeAbstractGraph):

@@ -66,7 +66,13 @@ class CogneeGraph(CogneeAbstractGraph):
    ) -> None:
        if node_dimension < 1 or edge_dimension < 1:
            raise InvalidValueError(message="Dimensions must be positive integers")

        try:
+            import time
+
+            start_time = time.time()
+
+            # Determine projection strategy
            if node_type is not None and node_name is not None:
                nodes_data, edges_data = await adapter.get_nodeset_subgraph(
                    node_type=node_type, node_name=node_name

@@ -83,16 +89,17 @@ class CogneeGraph(CogneeAbstractGraph):
                nodes_data, edges_data = await adapter.get_filtered_graph_data(
                    attribute_filters=memory_fragment_filter
                )

            if not nodes_data or not edges_data:
                raise EntityNotFoundError(
                    message="Empty filtered graph projected from the database."
                )

            # Process nodes
            for node_id, properties in nodes_data:
                node_attributes = {key: properties.get(key) for key in node_properties_to_project}
                self.add_node(Node(str(node_id), node_attributes, dimension=node_dimension))

            # Process edges
            for source_id, target_id, relationship_type, properties in edges_data:
                source_node = self.get_node(str(source_id))
                target_node = self.get_node(str(target_id))

|
|||
|
||||
source_node.add_skeleton_edge(edge)
|
||||
target_node.add_skeleton_edge(edge)
|
||||
|
||||
else:
|
||||
raise EntityNotFoundError(
|
||||
message=f"Edge references nonexistent nodes: {source_id} -> {target_id}"
|
||||
)
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
print(f"Error projecting graph: {e}")
|
||||
raise e
|
||||
# Final statistics
|
||||
projection_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Graph projection completed: {len(self.nodes)} nodes, {len(self.edges)} edges in {projection_time:.2f}s"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during graph projection: {str(e)}")
|
||||
raise
|
||||
|
||||
async def map_vector_distances_to_graph_nodes(self, node_distances) -> None:
|
||||
mapped_nodes = 0
|
||||
for category, scored_results in node_distances.items():
|
||||
for scored_result in scored_results:
|
||||
node_id = str(scored_result.id)
|
||||
|
|
@@ -131,6 +144,7 @@ class CogneeGraph(CogneeAbstractGraph):
                node = self.get_node(node_id)
                if node:
                    node.add_attribute("vector_distance", score)
+                    mapped_nodes += 1

    async def map_vector_distances_to_graph_edges(
        self, vector_engine, query_vector, edge_distances

@@ -150,18 +164,16 @@ class CogneeGraph(CogneeAbstractGraph):

            for edge in self.edges:
                relationship_type = edge.attributes.get("relationship_type")
-                if not relationship_type or relationship_type not in embedding_map:
-                    print(f"Edge {edge} has an unknown or missing relationship type.")
-                    continue
-
-                edge.attributes["vector_distance"] = embedding_map[relationship_type]
+                if relationship_type and relationship_type in embedding_map:
+                    edge.attributes["vector_distance"] = embedding_map[relationship_type]

        except Exception as ex:
-            print(f"Error mapping vector distances to edges: {ex}")
+            logger.error(f"Error mapping vector distances to edges: {str(ex)}")
            raise ex

    async def calculate_top_triplet_importances(self, k: int) -> List:
        min_heap = []

        for i, edge in enumerate(self.edges):
            source_node = self.get_node(edge.node1.id)
            target_node = self.get_node(edge.node2.id)

@@ -33,7 +33,7 @@ async def get_formatted_graph_data(dataset_id: UUID, user_id: UUID):
            lambda edge: {
                "source": str(edge[0]),
                "target": str(edge[1]),
-                "label": edge[2],
+                "label": str(edge[2]),
            },
            edges,
        )

@@ -1,10 +1,13 @@
from typing import Any, Optional

+from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError

+logger = get_logger("ChunksRetriever")
+

class ChunksRetriever(BaseRetriever):
    """

@@ -41,14 +44,22 @@ class ChunksRetriever(BaseRetriever):

        - Any: A list of document chunk payloads retrieved from the search.
        """
+        logger.info(
+            f"Starting chunk retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
        vector_engine = get_vector_engine()

        try:
            found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
+            logger.info(f"Found {len(found_chunks)} chunks from vector search")
        except CollectionNotFoundError as error:
+            logger.error("DocumentChunk_text collection not found in vector database")
            raise NoDataError("No data found in the system, please add data first.") from error

-        return [result.payload for result in found_chunks]
+        chunk_payloads = [result.payload for result in found_chunks]
+        logger.info(f"Returning {len(chunk_payloads)} chunk payloads")
+        return chunk_payloads

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """

@@ -70,6 +81,17 @@ class ChunksRetriever(BaseRetriever):
        - Any: The context used for the completion or the retrieved context if none was
          provided.
        """
+        logger.info(
+            f"Starting completion generation for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
        if context is None:
+            logger.debug("No context provided, retrieving context from vector database")
            context = await self.get_context(query)
+        else:
+            logger.debug("Using provided context")
+
+        logger.info(
+            f"Returning context with {len(context) if isinstance(context, list) else 1} item(s)"
+        )
        return context

@@ -3,6 +3,7 @@ import asyncio
import aiofiles
from pydantic import BaseModel

+from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine

@@ -13,6 +14,8 @@ from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.l
    read_query_prompt,
)

+logger = get_logger("CodeRetriever")
+

class CodeRetriever(BaseRetriever):
    """Retriever for handling code-based searches."""

@@ -39,26 +42,43 @@ class CodeRetriever(BaseRetriever):

    async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
        """Process the query using LLM to extract file names and source code parts."""
+        logger.debug(
+            f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
        system_prompt = read_query_prompt("codegraph_retriever_system.txt")
        llm_client = get_llm_client()

        try:
-            return await llm_client.acreate_structured_output(
+            result = await llm_client.acreate_structured_output(
                text_input=query,
                system_prompt=system_prompt,
                response_model=self.CodeQueryInfo,
            )
+            logger.info(
+                f"LLM extracted {len(result.filenames)} filenames and {len(result.sourcecode)} chars of source code"
+            )
+            return result
        except Exception as e:
+            logger.error(f"Failed to retrieve structured output from LLM: {str(e)}")
            raise RuntimeError("Failed to retrieve structured output from LLM") from e

    async def get_context(self, query: str) -> Any:
        """Find relevant code files based on the query."""
+        logger.info(
+            f"Starting code retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
        if not query or not isinstance(query, str):
+            logger.error("Invalid query: must be a non-empty string")
            raise ValueError("The query must be a non-empty string.")

        try:
            vector_engine = get_vector_engine()
            graph_engine = await get_graph_engine()
+            logger.debug("Successfully initialized vector and graph engines")
        except Exception as e:
+            logger.error(f"Database initialization error: {str(e)}")
            raise RuntimeError("Database initialization error in code_graph_retriever, ") from e

        files_and_codeparts = await self._process_query(query)

@@ -67,52 +87,80 @@ class CodeRetriever(BaseRetriever):
        similar_codepieces = []

        if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
+            logger.info("No specific files/code extracted from query, performing general search")
+
            for collection in self.file_name_collections:
+                logger.debug(f"Searching {collection} collection with general query")
                search_results_file = await vector_engine.search(
                    collection, query, limit=self.top_k
                )
+                logger.debug(f"Found {len(search_results_file)} results in {collection}")
                for res in search_results_file:
                    similar_filenames.append(
                        {"id": res.id, "score": res.score, "payload": res.payload}
                    )

            for collection in self.classes_and_functions_collections:
+                logger.debug(f"Searching {collection} collection with general query")
                search_results_code = await vector_engine.search(
                    collection, query, limit=self.top_k
                )
+                logger.debug(f"Found {len(search_results_code)} results in {collection}")
                for res in search_results_code:
                    similar_codepieces.append(
                        {"id": res.id, "score": res.score, "payload": res.payload}
                    )
        else:
+            logger.info(
+                f"Using extracted filenames ({len(files_and_codeparts.filenames)}) and source code for targeted search"
+            )
+
            for collection in self.file_name_collections:
                for file_from_query in files_and_codeparts.filenames:
+                    logger.debug(f"Searching {collection} for specific file: {file_from_query}")
                    search_results_file = await vector_engine.search(
                        collection, file_from_query, limit=self.top_k
                    )
+                    logger.debug(
+                        f"Found {len(search_results_file)} results for file {file_from_query}"
+                    )
                    for res in search_results_file:
                        similar_filenames.append(
                            {"id": res.id, "score": res.score, "payload": res.payload}
                        )

            for collection in self.classes_and_functions_collections:
+                logger.debug(f"Searching {collection} with extracted source code")
                search_results_code = await vector_engine.search(
                    collection, files_and_codeparts.sourcecode, limit=self.top_k
                )
+                logger.debug(f"Found {len(search_results_code)} results for source code search")
                for res in search_results_code:
                    similar_codepieces.append(
                        {"id": res.id, "score": res.score, "payload": res.payload}
                    )

+        total_items = len(similar_filenames) + len(similar_codepieces)
+        logger.info(
+            f"Total search results: {total_items} items ({len(similar_filenames)} filenames, {len(similar_codepieces)} code pieces)"
+        )
+
+        if total_items == 0:
+            logger.warning("No search results found, returning empty list")
+            return []
+
+        logger.debug("Getting graph connections for all search results")
        relevant_triplets = await asyncio.gather(
            *[
                graph_engine.get_connections(similar_piece["id"])
                for similar_piece in similar_filenames + similar_codepieces
            ]
        )
+        logger.info(f"Retrieved graph connections for {len(relevant_triplets)} items")

        paths = set()
-        for sublist in relevant_triplets:
+        for i, sublist in enumerate(relevant_triplets):
+            logger.debug(f"Processing connections for item {i}: {len(sublist)} connections")
            for tpl in sublist:
                if isinstance(tpl, tuple) and len(tpl) >= 3:
                    if "file_path" in tpl[0]:

@@ -120,23 +168,31 @@ class CodeRetriever(BaseRetriever):
                    if "file_path" in tpl[2]:
                        paths.add(tpl[2]["file_path"])

+        logger.info(f"Found {len(paths)} unique file paths to read")
+
        retrieved_files = {}
        read_tasks = []
        for file_path in paths:

            async def read_file(fp):
                try:
+                    logger.debug(f"Reading file: {fp}")
                    async with aiofiles.open(fp, "r", encoding="utf-8") as f:
-                        retrieved_files[fp] = await f.read()
+                        content = await f.read()
+                        retrieved_files[fp] = content
+                    logger.debug(f"Successfully read {len(content)} characters from {fp}")
                except Exception as e:
-                    print(f"Error reading {fp}: {e}")
+                    logger.error(f"Error reading {fp}: {e}")
+                    retrieved_files[fp] = ""

            read_tasks.append(read_file(file_path))

        await asyncio.gather(*read_tasks)
+        logger.info(
+            f"Successfully read {len([f for f in retrieved_files.values() if f])} files (out of {len(paths)} total)"
+        )

-        return [
+        result = [
            {
                "name": file_path,
                "description": file_path,

@@ -145,6 +201,9 @@ class CodeRetriever(BaseRetriever):
            for file_path in paths
        ]

+        logger.info(f"Returning {len(result)} code file contexts")
+        return result

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Returns the code files context."""
        if context is None:

@@ -1,11 +1,14 @@
from typing import Any, Optional

+from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError

+logger = get_logger("CompletionRetriever")
+

class CompletionRetriever(BaseRetriever):
    """

@@ -56,8 +59,10 @@ class CompletionRetriever(BaseRetriever):

            # Combine all chunks text returned from vector search (number of chunks is determined by top_k)
            chunks_payload = [found_chunk.payload["text"] for found_chunk in found_chunks]
-            return "\n".join(chunks_payload)
+            combined_context = "\n".join(chunks_payload)
+            return combined_context
        except CollectionNotFoundError as error:
+            logger.error("DocumentChunk_text collection not found")
            raise NoDataError("No data found in the system, please add data first.") from error

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:

@@ -70,22 +75,19 @@ class CompletionRetriever(BaseRetriever):
        Parameters:
        -----------

-        - query (str): The input query for which the completion is generated.
-        - context (Optional[Any]): Optional context to use for generating the completion; if
-          not provided, it will be retrieved using get_context. (default None)
+        - query (str): The query string to be used for generating a completion.
+        - context (Optional[Any]): Optional pre-fetched context to use for generating the
+          completion; if None, it retrieves the context for the query. (default None)

        Returns:
        --------

-        - Any: A list containing the generated completion from the LLM.
+        - Any: The generated completion based on the provided query and context.
        """
        if context is None:
            context = await self.get_context(query)

        completion = await generate_completion(
-            query=query,
-            context=context,
-            user_prompt_path=self.user_prompt_path,
-            system_prompt_path=self.system_prompt_path,
+            query, context, self.user_prompt_path, self.system_prompt_path
        )
-        return [completion]
+        return completion

@@ -10,7 +10,7 @@ from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.shared.logging_utils import get_logger

-logger = get_logger()
+logger = get_logger("GraphCompletionRetriever")


class GraphCompletionRetriever(BaseRetriever):

@@ -1,12 +1,15 @@
import asyncio
from typing import Any, Optional

+from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError

+logger = get_logger("InsightsRetriever")
+

class InsightsRetriever(BaseRetriever):
    """

@@ -63,6 +66,7 @@ class InsightsRetriever(BaseRetriever):
                vector_engine.search("EntityType_name", query_text=query, limit=self.top_k),
            )
        except CollectionNotFoundError as error:
+            logger.error("Entity collections not found")
            raise NoDataError("No data found in the system, please add data first.") from error

        results = [*results[0], *results[1]]

@@ -1,5 +1,5 @@
from typing import Any, Optional
-import logging
+from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.networkx.adapter import NetworkXAdapter
from cognee.infrastructure.llm.structured_output_framework.llitellm_instructor.llm.get_llm_client import (

@@ -12,7 +12,7 @@ from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions import SearchTypeNotSupported
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface

-logger = logging.getLogger("NaturalLanguageRetriever")
+logger = get_logger("NaturalLanguageRetriever")


class NaturalLanguageRetriever(BaseRetriever):

@@ -127,16 +127,12 @@ class NaturalLanguageRetriever(BaseRetriever):
        - Optional[Any]: Returns the context retrieved from the graph database based on the
          query.
        """
-        try:
-            graph_engine = await get_graph_engine()
+        graph_engine = await get_graph_engine()

-            if isinstance(graph_engine, (NetworkXAdapter)):
-                raise SearchTypeNotSupported("Natural language search type not supported.")
+        if isinstance(graph_engine, (NetworkXAdapter)):
+            raise SearchTypeNotSupported("Natural language search type not supported.")

-            return await self._execute_cypher_query(query, graph_engine)
-        except Exception as e:
-            logger.error("Failed to execute natural language search retrieval: %s", str(e))
-            raise e
+        return await self._execute_cypher_query(query, graph_engine)

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """

@@ -1,10 +1,13 @@
from typing import Any, Optional

+from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.exceptions.exceptions import NoDataError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError

+logger = get_logger("SummariesRetriever")
+

class SummariesRetriever(BaseRetriever):
    """

@@ -40,16 +43,24 @@ class SummariesRetriever(BaseRetriever):

        - Any: A list of payloads from the retrieved summaries.
        """
+        logger.info(
+            f"Starting summary retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
        vector_engine = get_vector_engine()

        try:
            summaries_results = await vector_engine.search(
                "TextSummary_text", query, limit=self.top_k
            )
+            logger.info(f"Found {len(summaries_results)} summaries from vector search")
        except CollectionNotFoundError as error:
+            logger.error("TextSummary_text collection not found in vector database")
            raise NoDataError("No data found in the system, please add data first.") from error

-        return [summary.payload for summary in summaries_results]
+        summary_payloads = [summary.payload for summary in summaries_results]
+        logger.info(f"Returning {len(summary_payloads)} summary payloads")
+        return summary_payloads

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """

@@ -70,6 +81,17 @@ class SummariesRetriever(BaseRetriever):

        - Any: The generated completion context, which is either provided or retrieved.
        """
+        logger.info(
+            f"Starting completion generation for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
+        )
+
        if context is None:
+            logger.debug("No context provided, retrieving context from vector database")
            context = await self.get_context(query)
+        else:
+            logger.debug("Using provided context")
+
+        logger.info(
+            f"Returning context with {len(context) if isinstance(context, list) else 1} item(s)"
+        )
        return context

@@ -59,13 +59,13 @@ async def get_memory_fragment(
    node_name: Optional[List[str]] = None,
) -> CogneeGraph:
    """Creates and initializes a CogneeGraph memory fragment with optional property projections."""
-    graph_engine = await get_graph_engine()
-    memory_fragment = CogneeGraph()

    if properties_to_project is None:
        properties_to_project = ["id", "description", "name", "type", "text"]

+    try:
+        graph_engine = await get_graph_engine()
+        memory_fragment = CogneeGraph()
+
        await memory_fragment.project_graph_from_db(
            graph_engine,
            node_properties_to_project=properties_to_project,

@@ -73,7 +73,13 @@ async def get_memory_fragment(
            node_type=node_type,
            node_name=node_name,
        )
+    except EntityNotFoundError:
+        # This is expected behavior - continue with empty fragment
+        pass
+    except Exception as e:
+        logger.error(f"Error during memory fragment creation: {str(e)}")
+        # Still return the fragment even if projection failed
+        pass

    return memory_fragment