Compare commits
20 commits
main
...
chore/mypy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a3d2e7bd2a | ||
|
|
ed83f76f25 | ||
|
|
3a65160839 | ||
|
|
1d282f3f83 | ||
|
|
1a5914e0ab | ||
|
|
d2a2ade643 | ||
|
|
0fb962e29a | ||
|
|
b9cd847e9d | ||
|
|
aa74672aeb | ||
|
|
9a2bf0f137 | ||
|
|
b59087841b | ||
|
|
deaf3debbf | ||
|
|
eebca89855 | ||
|
|
4ae41fede3 | ||
|
|
26f5ab4f0f | ||
|
|
baffd9187e | ||
|
|
85f37a6ee5 | ||
|
|
aa686cefe8 | ||
|
|
636d38c018 | ||
|
|
7d80701381 |
18 changed files with 648 additions and 1524 deletions
76
.github/workflows/database_protocol_mypy_check.yml
vendored
Normal file
76
.github/workflows/database_protocol_mypy_check.yml
vendored
Normal file
|
|
@ -0,0 +1,76 @@
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
name: Database Adapter MyPy Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
push:
|
||||||
|
branches: [ main, dev ]
|
||||||
|
paths:
|
||||||
|
- 'cognee/infrastructure/databases/**'
|
||||||
|
- 'tools/check_*_adapters.sh'
|
||||||
|
- 'mypy.ini'
|
||||||
|
- '.github/workflows/database_protocol_mypy_check.yml'
|
||||||
|
pull_request:
|
||||||
|
branches: [ main, dev ]
|
||||||
|
paths:
|
||||||
|
- 'cognee/infrastructure/databases/**'
|
||||||
|
- 'tools/check_*_adapters.sh'
|
||||||
|
- 'mypy.ini'
|
||||||
|
- '.github/workflows/database_protocol_mypy_check.yml'
|
||||||
|
|
||||||
|
env:
|
||||||
|
RUNTIME__LOG_LEVEL: ERROR
|
||||||
|
ENV: 'dev'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
mypy-database-adapters:
|
||||||
|
name: MyPy Database Adapter Type Check
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Check out repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Cognee Setup
|
||||||
|
uses: ./.github/actions/cognee_setup
|
||||||
|
with:
|
||||||
|
python-version: '3.11.x'
|
||||||
|
|
||||||
|
- name: Discover and Check Vector Database Adapters
|
||||||
|
run: ./tools/check_vector_adapters.sh
|
||||||
|
|
||||||
|
- name: Discover and Check Graph Database Adapters
|
||||||
|
run: ./tools/check_graph_adapters.sh
|
||||||
|
|
||||||
|
- name: Discover and Check Hybrid Database Adapters
|
||||||
|
run: ./tools/check_hybrid_adapters.sh
|
||||||
|
|
||||||
|
- name: Protocol Compliance Summary
|
||||||
|
run: |
|
||||||
|
echo "✅ Database Adapter MyPy Check Passed!"
|
||||||
|
echo ""
|
||||||
|
echo "🔍 Auto-Discovery Approach:"
|
||||||
|
echo " • Vector Adapters: cognee/infrastructure/databases/vector/**/*Adapter.py"
|
||||||
|
echo " • Graph Adapters: cognee/infrastructure/databases/graph/**/*adapter.py"
|
||||||
|
echo " • Hybrid Adapters: cognee/infrastructure/databases/hybrid/**/*Adapter.py"
|
||||||
|
echo ""
|
||||||
|
echo "🚀 Using Dedicated Scripts:"
|
||||||
|
echo " • Vector: ./tools/check_vector_adapters.sh"
|
||||||
|
echo " • Graph: ./tools/check_graph_adapters.sh"
|
||||||
|
echo " • Hybrid: ./tools/check_hybrid_adapters.sh"
|
||||||
|
echo " • All: ./tools/check_all_adapters.sh"
|
||||||
|
echo ""
|
||||||
|
echo "🎯 Purpose: Enforce that database adapters are properly typed"
|
||||||
|
echo "🔧 MyPy Configuration: mypy.ini (strict mode enabled)"
|
||||||
|
echo "🚀 Maintenance-Free: Automatically discovers new adapters"
|
||||||
|
echo ""
|
||||||
|
echo "⚠️ This workflow FAILS on any type errors to ensure adapter quality."
|
||||||
|
echo " All database adapters must be properly typed."
|
||||||
|
echo ""
|
||||||
|
echo "🛠️ To fix type issues locally, run:"
|
||||||
|
echo " ./tools/check_all_adapters.sh # Check all adapters"
|
||||||
|
echo " ./tools/check_vector_adapters.sh # Check vector adapters only"
|
||||||
|
echo " mypy <adapter_file_path> --config-file mypy.ini # Check specific file"
|
||||||
|
|
@ -118,7 +118,7 @@ class HealthChecker:
|
||||||
|
|
||||||
# Test basic operation with actual graph query
|
# Test basic operation with actual graph query
|
||||||
if hasattr(engine, "execute"):
|
if hasattr(engine, "execute"):
|
||||||
# For SQL-like graph DBs (Neo4j, Memgraph)
|
# For SQL-like graph DBs (Neo4j)
|
||||||
await engine.execute("MATCH () RETURN count(*) LIMIT 1")
|
await engine.execute("MATCH () RETURN count(*) LIMIT 1")
|
||||||
elif hasattr(engine, "query"):
|
elif hasattr(engine, "query"):
|
||||||
# For other graph engines
|
# For other graph engines
|
||||||
|
|
|
||||||
|
|
@ -179,5 +179,5 @@ def create_graph_engine(
|
||||||
|
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Unsupported graph database provider: {graph_database_provider}. "
|
f"Unsupported graph database provider: {graph_database_provider}. "
|
||||||
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
|
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'neptune', 'neptune_analytics'])}"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from kuzu.database import Database
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Dict, Any, List, Union, Optional, Tuple, Type
|
from typing import Dict, Any, List, Union, Optional, Tuple, Type, AsyncGenerator
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.infrastructure.utils.run_sync import run_sync
|
from cognee.infrastructure.utils.run_sync import run_sync
|
||||||
|
|
@ -22,7 +22,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.storage.utils import JSONEncoder
|
from cognee.modules.storage.utils import JSONEncoder
|
||||||
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
||||||
from cognee.tasks.temporal_graph.models import Timestamp
|
from cognee.modules.engine.models.Timestamp import Timestamp
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -146,15 +146,21 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
logger.error(f"Failed to initialize Kuzu database: {e}")
|
logger.error(f"Failed to initialize Kuzu database: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def _get_connection(self) -> Connection:
|
||||||
|
"""Get the connection to the Kuzu database."""
|
||||||
|
if not self.connection:
|
||||||
|
raise RuntimeError("Kuzu database connection not initialized")
|
||||||
|
return self.connection
|
||||||
|
|
||||||
async def push_to_s3(self) -> None:
|
async def push_to_s3(self) -> None:
|
||||||
if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
|
if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
|
||||||
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
|
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
|
||||||
|
|
||||||
s3_file_storage = S3FileStorage("")
|
s3_file_storage = S3FileStorage("")
|
||||||
|
|
||||||
if self.connection:
|
if self._get_connection():
|
||||||
async with self.KUZU_ASYNC_LOCK:
|
async with self.KUZU_ASYNC_LOCK:
|
||||||
self.connection.execute("CHECKPOINT;")
|
self._get_connection().execute("CHECKPOINT;")
|
||||||
|
|
||||||
s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
|
s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
|
||||||
|
|
||||||
|
|
@ -167,7 +173,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
|
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
|
||||||
|
|
||||||
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
|
async def query(
|
||||||
|
self, query: str, params: Optional[Dict[str, Any]] = None
|
||||||
|
) -> List[Tuple[Any, ...]]:
|
||||||
"""
|
"""
|
||||||
Execute a Kuzu query asynchronously with automatic reconnection.
|
Execute a Kuzu query asynchronously with automatic reconnection.
|
||||||
|
|
||||||
|
|
@ -190,23 +198,32 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
params = params or {}
|
params = params or {}
|
||||||
|
|
||||||
def blocking_query():
|
def blocking_query() -> List[Tuple[Any, ...]]:
|
||||||
try:
|
try:
|
||||||
if not self.connection:
|
if not self._get_connection():
|
||||||
logger.debug("Reconnecting to Kuzu database...")
|
logger.debug("Reconnecting to Kuzu database...")
|
||||||
self._initialize_connection()
|
self._initialize_connection()
|
||||||
|
|
||||||
result = self.connection.execute(query, params)
|
if not self._get_connection():
|
||||||
|
raise RuntimeError("Failed to establish database connection")
|
||||||
|
|
||||||
|
result = self._get_connection().execute(query, params)
|
||||||
rows = []
|
rows = []
|
||||||
|
|
||||||
while result.has_next():
|
if not isinstance(result, list):
|
||||||
row = result.get_next()
|
result = [result]
|
||||||
processed_rows = []
|
|
||||||
for val in row:
|
# Handle QueryResult vs List[QueryResult] union type
|
||||||
if hasattr(val, "as_py"):
|
for single_result in result:
|
||||||
val = val.as_py()
|
while single_result.has_next():
|
||||||
processed_rows.append(val)
|
row = single_result.get_next()
|
||||||
rows.append(tuple(processed_rows))
|
processed_rows = []
|
||||||
|
for val in row:
|
||||||
|
if hasattr(val, "as_py"):
|
||||||
|
val = val.as_py()
|
||||||
|
processed_rows.append(val)
|
||||||
|
rows.append(tuple(processed_rows))
|
||||||
|
|
||||||
return rows
|
return rows
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Query execution failed: {str(e)}")
|
logger.error(f"Query execution failed: {str(e)}")
|
||||||
|
|
@ -215,7 +232,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
return await loop.run_in_executor(self.executor, blocking_query)
|
return await loop.run_in_executor(self.executor, blocking_query)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def get_session(self):
|
async def get_session(self) -> AsyncGenerator[Optional[Connection], None]:
|
||||||
"""
|
"""
|
||||||
Get a database session.
|
Get a database session.
|
||||||
|
|
||||||
|
|
@ -224,7 +241,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
and on exit performs cleanup if necessary.
|
and on exit performs cleanup if necessary.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
yield self.connection
|
yield self._get_connection()
|
||||||
finally:
|
finally:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -255,7 +272,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
def _edge_query_and_params(
|
def _edge_query_and_params(
|
||||||
self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any]
|
self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any]
|
||||||
) -> Tuple[str, dict]:
|
) -> Tuple[str, Dict[str, Any]]:
|
||||||
"""Build the edge creation query and parameters."""
|
"""Build the edge creation query and parameters."""
|
||||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||||
query = """
|
query = """
|
||||||
|
|
@ -305,7 +322,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
result = await self.query(query_str, {"id": node_id})
|
result = await self.query(query_str, {"id": node_id})
|
||||||
return result[0][0] if result else False
|
return result[0][0] if result else False
|
||||||
|
|
||||||
async def add_node(self, node: DataPoint) -> None:
|
async def add_node(
|
||||||
|
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add a single node to the graph if it doesn't exist.
|
Add a single node to the graph if it doesn't exist.
|
||||||
|
|
||||||
|
|
@ -319,20 +338,32 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
- node (DataPoint): The node to be added, represented as a DataPoint.
|
- node (DataPoint): The node to be added, represented as a DataPoint.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
properties = node.model_dump() if hasattr(node, "model_dump") else vars(node)
|
if isinstance(node, str):
|
||||||
|
# Handle string node ID with properties parameter
|
||||||
|
node_properties = properties or {}
|
||||||
|
core_properties = {
|
||||||
|
"id": node,
|
||||||
|
"name": str(node_properties.get("name", "")),
|
||||||
|
"type": str(node_properties.get("type", "")),
|
||||||
|
}
|
||||||
|
# Use the passed properties, excluding core fields
|
||||||
|
other_properties = {
|
||||||
|
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Handle DataPoint object
|
||||||
|
node_properties = node.model_dump()
|
||||||
|
core_properties = {
|
||||||
|
"id": str(node_properties.get("id", "")),
|
||||||
|
"name": str(node_properties.get("name", "")),
|
||||||
|
"type": str(node_properties.get("type", "")),
|
||||||
|
}
|
||||||
|
# Remove core fields from other properties
|
||||||
|
other_properties = {
|
||||||
|
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
|
||||||
|
}
|
||||||
|
|
||||||
# Extract core fields with defaults if not present
|
core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)
|
||||||
core_properties = {
|
|
||||||
"id": str(properties.get("id", "")),
|
|
||||||
"name": str(properties.get("name", "")),
|
|
||||||
"type": str(properties.get("type", "")),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Remove core fields from other properties
|
|
||||||
for key in core_properties:
|
|
||||||
properties.pop(key, None)
|
|
||||||
|
|
||||||
core_properties["properties"] = json.dumps(properties, cls=JSONEncoder)
|
|
||||||
|
|
||||||
# Add timestamps for new node
|
# Add timestamps for new node
|
||||||
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||||
|
|
@ -360,7 +391,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
logger.error(f"Failed to add node: {e}")
|
logger.error(f"Failed to add node: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@record_graph_changes
|
@record_graph_changes # type: ignore
|
||||||
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple nodes to the graph in a batch operation.
|
Add multiple nodes to the graph in a batch operation.
|
||||||
|
|
@ -568,7 +599,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
return result[0][0] if result else False
|
return result[0][0] if result else False
|
||||||
|
|
||||||
async def has_edges(self, edges: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
|
async def has_edges(
|
||||||
|
self, edges: List[Tuple[str, str, str, Dict[str, Any]]]
|
||||||
|
) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Check if multiple edges exist in a batch operation.
|
Check if multiple edges exist in a batch operation.
|
||||||
|
|
||||||
|
|
@ -599,7 +632,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
"to_id": str(to_node), # Ensure string type
|
"to_id": str(to_node), # Ensure string type
|
||||||
"relationship_name": str(edge_label), # Ensure string type
|
"relationship_name": str(edge_label), # Ensure string type
|
||||||
}
|
}
|
||||||
for from_node, to_node, edge_label in edges
|
for from_node, to_node, edge_label, _ in edges
|
||||||
]
|
]
|
||||||
|
|
||||||
# Batch check query with direct string comparison
|
# Batch check query with direct string comparison
|
||||||
|
|
@ -615,9 +648,21 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
results = await self.query(query, {"edges": edge_params})
|
results = await self.query(query, {"edges": edge_params})
|
||||||
|
|
||||||
# Convert results back to tuples and ensure string types
|
# Convert results back to tuples and ensure string types
|
||||||
existing_edges = [(str(row[0]), str(row[1]), str(row[2])) for row in results]
|
# Find the original edge properties for each existing edge
|
||||||
|
# TODO: get review on this
|
||||||
|
existing_edges = []
|
||||||
|
for row in results:
|
||||||
|
from_id, to_id, rel_name = str(row[0]), str(row[1]), str(row[2])
|
||||||
|
# Find the original properties from the input edges
|
||||||
|
original_props = {}
|
||||||
|
for orig_from, orig_to, orig_rel, orig_props in edges:
|
||||||
|
if orig_from == from_id and orig_to == to_id and orig_rel == rel_name:
|
||||||
|
original_props = orig_props
|
||||||
|
break
|
||||||
|
existing_edges.append((from_id, to_id, rel_name, original_props))
|
||||||
|
|
||||||
logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked")
|
logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked")
|
||||||
|
# TODO: otherwise, we can just return dummy properties since they are not used apparently
|
||||||
return existing_edges
|
return existing_edges
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -626,10 +671,10 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
async def add_edge(
|
async def add_edge(
|
||||||
self,
|
self,
|
||||||
from_node: str,
|
source_id: str,
|
||||||
to_node: str,
|
target_id: str,
|
||||||
relationship_name: str,
|
relationship_name: str,
|
||||||
edge_properties: Dict[str, Any] = {},
|
properties: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add an edge between two nodes.
|
Add an edge between two nodes.
|
||||||
|
|
@ -641,23 +686,23 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- from_node (str): The identifier of the source node from which the edge originates.
|
- source_id (str): The identifier of the source node from which the edge originates.
|
||||||
- to_node (str): The identifier of the target node to which the edge points.
|
- target_id (str): The identifier of the target node to which the edge points.
|
||||||
- relationship_name (str): The label of the edge to be created, representing the
|
- relationship_name (str): The label of the edge to be created, representing the
|
||||||
relationship name.
|
relationship name.
|
||||||
- edge_properties (Dict[str, Any]): A dictionary containing properties for the edge.
|
- properties (Optional[Dict[str, Any]]): A dictionary containing properties for the edge.
|
||||||
(default {})
|
(default None)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
query, params = self._edge_query_and_params(
|
query, params = self._edge_query_and_params(
|
||||||
from_node, to_node, relationship_name, edge_properties
|
source_id, target_id, relationship_name, properties or {}
|
||||||
)
|
)
|
||||||
await self.query(query, params)
|
await self.query(query, params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to add edge: {e}")
|
logger.error(f"Failed to add edge: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@record_graph_changes
|
@record_graph_changes # type: ignore
|
||||||
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple edges in a batch operation.
|
Add multiple edges in a batch operation.
|
||||||
|
|
@ -712,7 +757,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
logger.error(f"Failed to add edges in batch: {e}")
|
logger.error(f"Failed to add edges in batch: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_edges(self, node_id: str) -> List[Tuple[Dict[str, Any], str, Dict[str, Any]]]:
|
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Get all edges connected to a node.
|
Get all edges connected to a node.
|
||||||
|
|
||||||
|
|
@ -727,9 +772,8 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
|
|
||||||
- List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: A list of tuples where each
|
- List[Tuple[str, str, str, Dict[str, Any]]]: A list of tuples where each
|
||||||
tuple contains (source_node, relationship_name, target_node), with source_node and
|
tuple contains (source_id, relationship_name, target_id, edge_properties).
|
||||||
target_node as dictionaries of node properties.
|
|
||||||
"""
|
"""
|
||||||
query_str = """
|
query_str = """
|
||||||
MATCH (n:Node)-[r]-(m:Node)
|
MATCH (n:Node)-[r]-(m:Node)
|
||||||
|
|
@ -750,12 +794,14 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
results = await self.query(query_str, {"node_id": node_id})
|
results = await self.query(query_str, {"node_id": node_id})
|
||||||
edges = []
|
edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
|
||||||
for row in results:
|
for row in results:
|
||||||
if row and len(row) == 3:
|
if row and len(row) == 3:
|
||||||
source_node = self._parse_node_properties(row[0])
|
source_node = self._parse_node_properties(row[0])
|
||||||
|
relationship_name = row[1]
|
||||||
target_node = self._parse_node_properties(row[2])
|
target_node = self._parse_node_properties(row[2])
|
||||||
edges.append((source_node, row[1], target_node))
|
# TODO: any edge properties we can add? Adding empty to avoid modifying query without reason
|
||||||
|
edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, wihle typing expects nodes to be strings
|
||||||
return edges
|
return edges
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get edges for node {node_id}: {e}")
|
logger.error(f"Failed to get edges for node {node_id}: {e}")
|
||||||
|
|
@ -977,7 +1023,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_connections(
|
async def get_connections(
|
||||||
self, node_id: str
|
self, node_id: Union[str, UUID]
|
||||||
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
|
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Get all nodes connected to a given node.
|
Get all nodes connected to a given node.
|
||||||
|
|
@ -1019,7 +1065,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
results = await self.query(query_str, {"node_id": node_id})
|
# Convert UUID to string if needed
|
||||||
|
node_id_str = str(node_id)
|
||||||
|
results = await self.query(query_str, {"node_id": node_id_str})
|
||||||
edges = []
|
edges = []
|
||||||
for row in results:
|
for row in results:
|
||||||
if row and len(row) == 3:
|
if row and len(row) == 3:
|
||||||
|
|
@ -1177,7 +1225,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
async def get_nodeset_subgraph(
|
async def get_nodeset_subgraph(
|
||||||
self, node_type: Type[Any], node_name: List[str]
|
self, node_type: Type[Any], node_name: List[str]
|
||||||
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
|
) -> Tuple[List[Tuple[int, Dict[str, Any]]], List[Tuple[int, int, str, Dict[str, Any]]]]:
|
||||||
"""
|
"""
|
||||||
Get subgraph for a set of nodes based on type and names.
|
Get subgraph for a set of nodes based on type and names.
|
||||||
|
|
||||||
|
|
@ -1225,9 +1273,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
RETURN n.id, n.name, n.type, n.properties
|
RETURN n.id, n.name, n.type, n.properties
|
||||||
"""
|
"""
|
||||||
node_rows = await self.query(nodes_query, {"ids": all_ids})
|
node_rows = await self.query(nodes_query, {"ids": all_ids})
|
||||||
nodes: List[Tuple[str, dict]] = []
|
nodes: List[Tuple[str, Dict[str, Any]]] = []
|
||||||
for node_id, name, typ, props in node_rows:
|
for node_id, name, typ, props in node_rows:
|
||||||
data = {"id": node_id, "name": name, "type": typ}
|
data: Dict[str, Any] = {"id": node_id, "name": name, "type": typ}
|
||||||
if props:
|
if props:
|
||||||
try:
|
try:
|
||||||
data.update(json.loads(props))
|
data.update(json.loads(props))
|
||||||
|
|
@ -1241,22 +1289,22 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
RETURN a.id, b.id, r.relationship_name, r.properties
|
RETURN a.id, b.id, r.relationship_name, r.properties
|
||||||
"""
|
"""
|
||||||
edge_rows = await self.query(edges_query, {"ids": all_ids})
|
edge_rows = await self.query(edges_query, {"ids": all_ids})
|
||||||
edges: List[Tuple[str, str, str, dict]] = []
|
edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
|
||||||
for from_id, to_id, rel_type, props in edge_rows:
|
for from_id, to_id, rel_type, props in edge_rows:
|
||||||
data = {}
|
edge_data: Dict[str, Any] = {}
|
||||||
if props:
|
if props:
|
||||||
try:
|
try:
|
||||||
data = json.loads(props)
|
edge_data = json.loads(props)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
|
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
|
||||||
|
|
||||||
edges.append((from_id, to_id, rel_type, data))
|
edges.append((from_id, to_id, rel_type, edge_data))
|
||||||
|
|
||||||
return nodes, edges
|
return nodes, edges # type: ignore # Interface expects int IDs but string IDs are more natural for graph DBs
|
||||||
|
|
||||||
async def get_filtered_graph_data(
|
async def get_filtered_graph_data(
|
||||||
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
|
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
|
||||||
):
|
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Get filtered nodes and relationships based on attributes.
|
Get filtered nodes and relationships based on attributes.
|
||||||
|
|
||||||
|
|
@ -1299,7 +1347,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
return ([n[0] for n in nodes], [e[0] for e in edges])
|
return ([n[0] for n in nodes], [e[0] for e in edges])
|
||||||
|
|
||||||
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
|
async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Get metrics on graph structure and connectivity.
|
Get metrics on graph structure and connectivity.
|
||||||
|
|
||||||
|
|
@ -1322,8 +1370,8 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
try:
|
try:
|
||||||
# Get basic graph data
|
# Get basic graph data
|
||||||
nodes, edges = await self.get_model_independent_graph_data()
|
nodes, edges = await self.get_model_independent_graph_data()
|
||||||
num_nodes = len(nodes[0]["nodes"]) if nodes else 0
|
num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string?
|
||||||
num_edges = len(edges[0]["elements"]) if edges else 0
|
num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string?
|
||||||
|
|
||||||
# Calculate mandatory metrics
|
# Calculate mandatory metrics
|
||||||
mandatory_metrics = {
|
mandatory_metrics = {
|
||||||
|
|
@ -1483,8 +1531,8 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
It raises exceptions for failures occurring during deletion processes.
|
It raises exceptions for failures occurring during deletion processes.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if self.connection:
|
if self._get_connection():
|
||||||
self.connection.close()
|
self._get_connection().close()
|
||||||
self.connection = None
|
self.connection = None
|
||||||
if self.db:
|
if self.db:
|
||||||
self.db.close()
|
self.db.close()
|
||||||
|
|
@ -1515,7 +1563,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
occur during file deletions or initializations carefully.
|
occur during file deletions or initializations carefully.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if self.connection:
|
if self._get_connection():
|
||||||
self.connection = None
|
self.connection = None
|
||||||
if self.db:
|
if self.db:
|
||||||
self.db.close()
|
self.db.close()
|
||||||
|
|
@ -1531,20 +1579,30 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
# Reinitialize the database
|
# Reinitialize the database
|
||||||
self._initialize_connection()
|
self._initialize_connection()
|
||||||
|
|
||||||
|
if not self._get_connection():
|
||||||
|
raise RuntimeError("Failed to establish database connection")
|
||||||
|
|
||||||
# Verify the database is empty
|
# Verify the database is empty
|
||||||
result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)")
|
result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)")
|
||||||
count = result.get_next()[0] if result.has_next() else 0
|
if not isinstance(result, list):
|
||||||
|
result = [result]
|
||||||
|
for single_result in result:
|
||||||
|
_next = single_result.get_next()
|
||||||
|
if not isinstance(_next, list):
|
||||||
|
raise RuntimeError("Expected list of results")
|
||||||
|
count = _next[0] if _next else 0
|
||||||
if count > 0:
|
if count > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Database still contains {count} nodes after clearing, forcing deletion"
|
f"Database still contains {count} nodes after clearing, forcing deletion"
|
||||||
)
|
)
|
||||||
self.connection.execute("MATCH (n:Node) DETACH DELETE n")
|
self._get_connection().execute("MATCH (n:Node) DETACH DELETE n")
|
||||||
logger.info("Database cleared successfully")
|
logger.info("Database cleared successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during database clearing: {e}")
|
logger.error(f"Error during database clearing: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_document_subgraph(self, data_id: str):
|
async def get_document_subgraph(self, data_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Get all nodes that should be deleted when removing a document.
|
Get all nodes that should be deleted when removing a document.
|
||||||
|
|
||||||
|
|
@ -1616,7 +1674,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
"orphan_types": result[0][4],
|
"orphan_types": result[0][4],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_degree_one_nodes(self, node_type: str):
|
async def get_degree_one_nodes(self, node_type: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Get all nodes that have only one connection.
|
Get all nodes that have only one connection.
|
||||||
|
|
||||||
|
|
@ -1769,8 +1827,8 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
ids: List[str] = []
|
ids: List[str] = []
|
||||||
|
|
||||||
if time_from and time_to:
|
if time_from and time_to:
|
||||||
time_from = date_to_int(time_from)
|
time_from_int = date_to_int(time_from)
|
||||||
time_to = date_to_int(time_to)
|
time_to_int = date_to_int(time_to)
|
||||||
|
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (n:Node)
|
MATCH (n:Node)
|
||||||
|
|
@ -1782,13 +1840,13 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
||||||
ELSE CAST(t_str AS INT64)
|
ELSE CAST(t_str AS INT64)
|
||||||
END AS t
|
END AS t
|
||||||
WHERE t >= {time_from}
|
WHERE t >= {time_from_int}
|
||||||
AND t <= {time_to}
|
AND t <= {time_to_int}
|
||||||
RETURN n.id as id
|
RETURN n.id as id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
elif time_from:
|
elif time_from:
|
||||||
time_from = date_to_int(time_from)
|
time_from_int = date_to_int(time_from)
|
||||||
|
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (n:Node)
|
MATCH (n:Node)
|
||||||
|
|
@ -1800,12 +1858,12 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
||||||
ELSE CAST(t_str AS INT64)
|
ELSE CAST(t_str AS INT64)
|
||||||
END AS t
|
END AS t
|
||||||
WHERE t >= {time_from}
|
WHERE t >= {time_from_int}
|
||||||
RETURN n.id as id
|
RETURN n.id as id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
elif time_to:
|
elif time_to:
|
||||||
time_to = date_to_int(time_to)
|
time_to_int = date_to_int(time_to)
|
||||||
|
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MATCH (n:Node)
|
MATCH (n:Node)
|
||||||
|
|
@ -1817,12 +1875,12 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
||||||
ELSE CAST(t_str AS INT64)
|
ELSE CAST(t_str AS INT64)
|
||||||
END AS t
|
END AS t
|
||||||
WHERE t <= {time_to}
|
WHERE t <= {time_to_int}
|
||||||
RETURN n.id as id
|
RETURN n.id as id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return ids
|
return ", ".join(f"'{uid}'" for uid in ids)
|
||||||
|
|
||||||
time_nodes = await self.query(cypher)
|
time_nodes = await self.query(cypher)
|
||||||
time_ids_list = [item[0] for item in time_nodes]
|
time_ids_list = [item[0] for item in time_nodes]
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
import json
|
import json
|
||||||
from typing import Dict, Any, List, Optional, Tuple
|
from typing import Dict, Any, List, Optional, Tuple, Union
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
@ -14,7 +14,7 @@ logger = get_logger()
|
||||||
class UUIDEncoder(json.JSONEncoder):
|
class UUIDEncoder(json.JSONEncoder):
|
||||||
"""Custom JSON encoder that handles UUID objects."""
|
"""Custom JSON encoder that handles UUID objects."""
|
||||||
|
|
||||||
def default(self, obj):
|
def default(self, obj: Union[UUID, Any]) -> Any:
|
||||||
if isinstance(obj, UUID):
|
if isinstance(obj, UUID):
|
||||||
return str(obj)
|
return str(obj)
|
||||||
return super().default(obj)
|
return super().default(obj)
|
||||||
|
|
@ -36,7 +36,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
||||||
self.api_url = api_url
|
self.api_url = api_url
|
||||||
self.username = username
|
self.username = username
|
||||||
self.password = password
|
self.password = password
|
||||||
self._session = None
|
self._session: Optional[aiohttp.ClientSession] = None
|
||||||
self._schema_initialized = False
|
self._schema_initialized = False
|
||||||
|
|
||||||
async def _get_session(self) -> aiohttp.ClientSession:
|
async def _get_session(self) -> aiohttp.ClientSession:
|
||||||
|
|
@ -45,13 +45,13 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
||||||
self._session = aiohttp.ClientSession()
|
self._session = aiohttp.ClientSession()
|
||||||
return self._session
|
return self._session
|
||||||
|
|
||||||
async def close(self):
|
async def close(self) -> None:
|
||||||
"""Close the adapter and its session."""
|
"""Close the adapter and its session."""
|
||||||
if self._session and not self._session.closed:
|
if self._session and not self._session.closed:
|
||||||
await self._session.close()
|
await self._session.close()
|
||||||
self._session = None
|
self._session = None
|
||||||
|
|
||||||
async def _make_request(self, endpoint: str, data: dict) -> dict:
|
async def _make_request(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Make a request to the Kuzu API."""
|
"""Make a request to the Kuzu API."""
|
||||||
url = f"{self.api_url}{endpoint}"
|
url = f"{self.api_url}{endpoint}"
|
||||||
session = await self._get_session()
|
session = await self._get_session()
|
||||||
|
|
@ -73,13 +73,15 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
||||||
status=response.status,
|
status=response.status,
|
||||||
message=error_detail,
|
message=error_detail,
|
||||||
)
|
)
|
||||||
return await response.json()
|
return await response.json() # type: ignore
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
logger.error(f"API request failed: {str(e)}")
|
logger.error(f"API request failed: {str(e)}")
|
||||||
logger.error(f"Request data: {data}")
|
logger.error(f"Request data: {data}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
|
async def query(
|
||||||
|
self, query: str, params: Optional[dict[str, Any]] = None
|
||||||
|
) -> List[Tuple[Any, ...]]:
|
||||||
"""Execute a Kuzu query via the REST API."""
|
"""Execute a Kuzu query via the REST API."""
|
||||||
try:
|
try:
|
||||||
# Initialize schema if needed
|
# Initialize schema if needed
|
||||||
|
|
@ -126,7 +128,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
||||||
logger.error(f"Failed to check schema: {e}")
|
logger.error(f"Failed to check schema: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _create_schema(self):
|
async def _create_schema(self) -> None:
|
||||||
"""Create the required schema tables."""
|
"""Create the required schema tables."""
|
||||||
try:
|
try:
|
||||||
# Create Node table if it doesn't exist
|
# Create Node table if it doesn't exist
|
||||||
|
|
@ -180,7 +182,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
||||||
logger.error(f"Failed to create schema: {e}")
|
logger.error(f"Failed to create schema: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _initialize_schema(self):
|
async def _initialize_schema(self) -> None:
|
||||||
"""Initialize the database schema if it doesn't exist."""
|
"""Initialize the database schema if it doesn't exist."""
|
||||||
if self._schema_initialized:
|
if self._schema_initialized:
|
||||||
return
|
return
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -8,11 +8,11 @@ from neo4j import AsyncSession
|
||||||
from neo4j import AsyncGraphDatabase
|
from neo4j import AsyncGraphDatabase
|
||||||
from neo4j.exceptions import Neo4jError
|
from neo4j.exceptions import Neo4jError
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Optional, Any, List, Dict, Type, Tuple
|
from typing import Optional, Any, List, Dict, Type, Tuple, Union, AsyncGenerator
|
||||||
|
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
||||||
from cognee.tasks.temporal_graph.models import Timestamp
|
from cognee.modules.engine.models.Timestamp import Timestamp
|
||||||
from cognee.shared.logging_utils import get_logger, ERROR
|
from cognee.shared.logging_utils import get_logger, ERROR
|
||||||
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
||||||
GraphDBInterface,
|
GraphDBInterface,
|
||||||
|
|
@ -79,14 +79,14 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def get_session(self) -> AsyncSession:
|
async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""
|
"""
|
||||||
Get a session for database operations.
|
Get a session for database operations.
|
||||||
"""
|
"""
|
||||||
async with self.driver.session(database=self.graph_database_name) as session:
|
async with self.driver.session(database=self.graph_database_name) as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
@deadlock_retry()
|
@deadlock_retry() # type: ignore
|
||||||
async def query(
|
async def query(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -112,6 +112,7 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
async with self.get_session() as session:
|
async with self.get_session() as session:
|
||||||
result = await session.run(query, parameters=params)
|
result = await session.run(query, parameters=params)
|
||||||
data = await result.data()
|
data = await result.data()
|
||||||
|
# TODO: why we don't return List[Dict[str, Any]]?
|
||||||
return data
|
return data
|
||||||
except Neo4jError as error:
|
except Neo4jError as error:
|
||||||
logger.error("Neo4j query error: %s", error, exc_info=True)
|
logger.error("Neo4j query error: %s", error, exc_info=True)
|
||||||
|
|
@ -141,21 +142,29 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
return results[0]["node_exists"] if len(results) > 0 else False
|
return results[0]["node_exists"] if len(results) > 0 else False
|
||||||
|
|
||||||
async def add_node(self, node: DataPoint):
|
async def add_node(
|
||||||
|
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add a new node to the database based on the provided DataPoint object.
|
Add a new node to the database based on the provided DataPoint object or string ID.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- node (DataPoint): An instance of DataPoint representing the node to add.
|
- node (Union[DataPoint, str]): An instance of DataPoint or string ID representing the node to add.
|
||||||
|
- properties (Optional[Dict[str, Any]]): Properties to set on the node when node is a string ID.
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
The result of the query execution, typically the ID of the added node.
|
|
||||||
"""
|
"""
|
||||||
serialized_properties = self.serialize_properties(node.model_dump())
|
if isinstance(node, str):
|
||||||
|
# TODO: this was not handled in the original code, check if it is correct
|
||||||
|
# Handle string node ID with properties parameter
|
||||||
|
node_id = node
|
||||||
|
node_label = "Node" # Default label for string nodes
|
||||||
|
serialized_properties = self.serialize_properties(properties or {})
|
||||||
|
else:
|
||||||
|
# Handle DataPoint object
|
||||||
|
node_id = str(node.id)
|
||||||
|
node_label = type(node).__name__
|
||||||
|
serialized_properties = self.serialize_properties(node.model_dump())
|
||||||
|
|
||||||
query = dedent(
|
query = dedent(
|
||||||
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
|
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
|
||||||
|
|
@ -167,16 +176,16 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"node_id": str(node.id),
|
"node_id": node_id,
|
||||||
"node_label": type(node).__name__,
|
"node_label": node_label,
|
||||||
"properties": serialized_properties,
|
"properties": serialized_properties,
|
||||||
}
|
}
|
||||||
|
|
||||||
return await self.query(query, params)
|
await self.query(query, params)
|
||||||
|
|
||||||
@record_graph_changes
|
@record_graph_changes # type: ignore
|
||||||
@override_distributed(queued_add_nodes)
|
@override_distributed(queued_add_nodes) # type: ignore
|
||||||
async def add_nodes(self, nodes: list[DataPoint]) -> None:
|
async def add_nodes(self, nodes: List[DataPoint]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple nodes to the database in a single query.
|
Add multiple nodes to the database in a single query.
|
||||||
|
|
||||||
|
|
@ -201,7 +210,7 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
|
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
|
||||||
"""
|
"""
|
||||||
|
|
||||||
nodes = [
|
node_params = [
|
||||||
{
|
{
|
||||||
"node_id": str(node.id),
|
"node_id": str(node.id),
|
||||||
"label": type(node).__name__,
|
"label": type(node).__name__,
|
||||||
|
|
@ -210,10 +219,9 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
for node in nodes
|
for node in nodes
|
||||||
]
|
]
|
||||||
|
|
||||||
results = await self.query(query, dict(nodes=nodes))
|
await self.query(query, dict(nodes=node_params))
|
||||||
return results
|
|
||||||
|
|
||||||
async def extract_node(self, node_id: str):
|
async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Retrieve a single node from the database by its ID.
|
Retrieve a single node from the database by its ID.
|
||||||
|
|
||||||
|
|
@ -231,7 +239,7 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return results[0] if len(results) > 0 else None
|
return results[0] if len(results) > 0 else None
|
||||||
|
|
||||||
async def extract_nodes(self, node_ids: List[str]):
|
async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Retrieve multiple nodes from the database by their IDs.
|
Retrieve multiple nodes from the database by their IDs.
|
||||||
|
|
||||||
|
|
@ -256,7 +264,7 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return [result["node"] for result in results]
|
return [result["node"] for result in results]
|
||||||
|
|
||||||
async def delete_node(self, node_id: str):
|
async def delete_node(self, node_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Remove a node from the database identified by its ID.
|
Remove a node from the database identified by its ID.
|
||||||
|
|
||||||
|
|
@ -273,7 +281,7 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
|
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
|
||||||
params = {"node_id": node_id}
|
params = {"node_id": node_id}
|
||||||
|
|
||||||
return await self.query(query, params)
|
await self.query(query, params)
|
||||||
|
|
||||||
async def delete_nodes(self, node_ids: list[str]) -> None:
|
async def delete_nodes(self, node_ids: list[str]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -296,18 +304,18 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
params = {"node_ids": node_ids}
|
params = {"node_ids": node_ids}
|
||||||
|
|
||||||
return await self.query(query, params)
|
await self.query(query, params)
|
||||||
|
|
||||||
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
|
async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if an edge exists between two nodes with the specified IDs and edge label.
|
Check if an edge exists between two nodes with the specified IDs and edge label.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- from_node (UUID): The ID of the node from which the edge originates.
|
- source_id (str): The ID of the node from which the edge originates.
|
||||||
- to_node (UUID): The ID of the node to which the edge points.
|
- target_id (str): The ID of the node to which the edge points.
|
||||||
- edge_label (str): The label of the edge to check for existence.
|
- relationship_name (str): The label of the edge to check for existence.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
|
|
@ -315,27 +323,28 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
- bool: True if the edge exists, otherwise False.
|
- bool: True if the edge exists, otherwise False.
|
||||||
"""
|
"""
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (from_node: `{BASE_LABEL}`)-[:`{edge_label}`]->(to_node: `{BASE_LABEL}`)
|
MATCH (from_node: `{BASE_LABEL}`)-[:`{relationship_name}`]->(to_node: `{BASE_LABEL}`)
|
||||||
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id
|
WHERE from_node.id = $source_id AND to_node.id = $target_id
|
||||||
RETURN COUNT(relationship) > 0 AS edge_exists
|
RETURN COUNT(relationship) > 0 AS edge_exists
|
||||||
"""
|
"""
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"from_node_id": str(from_node),
|
"source_id": str(source_id),
|
||||||
"to_node_id": str(to_node),
|
"target_id": str(target_id),
|
||||||
}
|
}
|
||||||
|
|
||||||
edge_exists = await self.query(query, params)
|
edge_exists = await self.query(query, params)
|
||||||
|
assert isinstance(edge_exists, bool), "Edge existence check should return a boolean"
|
||||||
return edge_exists
|
return edge_exists
|
||||||
|
|
||||||
async def has_edges(self, edges):
|
async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[bool]:
|
||||||
"""
|
"""
|
||||||
Check if multiple edges exist based on provided edge criteria.
|
Check if multiple edges exist based on provided edge criteria.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- edges: A list of edge specifications to check for existence.
|
- edges: A list of edge specifications to check for existence. (source_id, target_id, relationship_name, properties)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
|
|
@ -369,29 +378,24 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
async def add_edge(
|
async def add_edge(
|
||||||
self,
|
self,
|
||||||
from_node: UUID,
|
source_id: str,
|
||||||
to_node: UUID,
|
target_id: str,
|
||||||
relationship_name: str,
|
relationship_name: str,
|
||||||
edge_properties: Optional[Dict[str, Any]] = {},
|
properties: Optional[Dict[str, Any]] = None,
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Create a new edge between two nodes with specified properties.
|
Create a new edge between two nodes with specified properties.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
||||||
- from_node (UUID): The ID of the source node of the edge.
|
- source_id (str): The ID of the source node of the edge.
|
||||||
- to_node (UUID): The ID of the target node of the edge.
|
- target_id (str): The ID of the target node of the edge.
|
||||||
- relationship_name (str): The type/label of the edge to create.
|
- relationship_name (str): The type/label of the edge to create.
|
||||||
- edge_properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
|
- properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
|
||||||
to the edge. (default {})
|
to the edge. (default None)
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
The result of the query execution, typically indicating the created edge.
|
|
||||||
"""
|
"""
|
||||||
serialized_properties = self.serialize_properties(edge_properties)
|
serialized_properties = self.serialize_properties(properties or {})
|
||||||
|
|
||||||
query = dedent(
|
query = dedent(
|
||||||
f"""\
|
f"""\
|
||||||
|
|
@ -405,13 +409,13 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"from_node": str(from_node),
|
"from_node": str(source_id), # Adding str as callsites may still be passing UUID
|
||||||
"to_node": str(to_node),
|
"to_node": str(target_id),
|
||||||
"relationship_name": relationship_name,
|
"relationship_name": relationship_name,
|
||||||
"properties": serialized_properties,
|
"properties": serialized_properties,
|
||||||
}
|
}
|
||||||
|
|
||||||
return await self.query(query, params)
|
await self.query(query, params)
|
||||||
|
|
||||||
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
|
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -445,9 +449,9 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
return flattened
|
return flattened
|
||||||
|
|
||||||
@record_graph_changes
|
@record_graph_changes # type: ignore
|
||||||
@override_distributed(queued_add_edges)
|
@override_distributed(queued_add_edges) # type: ignore
|
||||||
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
|
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
|
||||||
"""
|
"""
|
||||||
Add multiple edges between nodes in a single query.
|
Add multiple edges between nodes in a single query.
|
||||||
|
|
||||||
|
|
@ -478,10 +482,10 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
) YIELD rel
|
) YIELD rel
|
||||||
RETURN rel"""
|
RETURN rel"""
|
||||||
|
|
||||||
edges = [
|
edge_params = [
|
||||||
{
|
{
|
||||||
"from_node": str(edge[0]),
|
"from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
|
||||||
"to_node": str(edge[1]),
|
"to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
|
||||||
"relationship_name": edge[2],
|
"relationship_name": edge[2],
|
||||||
"properties": self._flatten_edge_properties(
|
"properties": self._flatten_edge_properties(
|
||||||
{
|
{
|
||||||
|
|
@ -495,13 +499,12 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results = await self.query(query, dict(edges=edges))
|
await self.query(query, dict(edges=edge_params))
|
||||||
return results
|
|
||||||
except Neo4jError as error:
|
except Neo4jError as error:
|
||||||
logger.error("Neo4j query error: %s", error, exc_info=True)
|
logger.error("Neo4j query error: %s", error, exc_info=True)
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
async def get_edges(self, node_id: str):
|
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Retrieve all edges connected to a specified node.
|
Retrieve all edges connected to a specified node.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Dict, Any
|
||||||
from chromadb import AsyncHttpClient, Settings
|
from chromadb import AsyncHttpClient, Settings
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.storage.utils import get_own_properties
|
from cognee.modules.storage.utils import get_own_properties
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.infrastructure.engine.models.DataPoint import MetaData
|
||||||
from cognee.infrastructure.engine.utils import parse_id
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||||
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||||
|
|
@ -35,9 +36,9 @@ class IndexSchema(DataPoint):
|
||||||
|
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
metadata: dict = {"index_fields": ["text"]}
|
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||||
|
|
||||||
def model_dump(self):
|
def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Serialize the instance data for storage.
|
Serialize the instance data for storage.
|
||||||
|
|
||||||
|
|
@ -49,11 +50,11 @@ class IndexSchema(DataPoint):
|
||||||
|
|
||||||
A dictionary containing serialized data processed for ChromaDB storage.
|
A dictionary containing serialized data processed for ChromaDB storage.
|
||||||
"""
|
"""
|
||||||
data = super().model_dump()
|
data = super().model_dump(**kwargs)
|
||||||
return process_data_for_chroma(data)
|
return process_data_for_chroma(data)
|
||||||
|
|
||||||
|
|
||||||
def process_data_for_chroma(data):
|
def process_data_for_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert complex data types to a format suitable for ChromaDB storage.
|
Convert complex data types to a format suitable for ChromaDB storage.
|
||||||
|
|
||||||
|
|
@ -73,7 +74,7 @@ def process_data_for_chroma(data):
|
||||||
|
|
||||||
A dictionary containing the processed key-value pairs suitable for ChromaDB storage.
|
A dictionary containing the processed key-value pairs suitable for ChromaDB storage.
|
||||||
"""
|
"""
|
||||||
processed_data = {}
|
processed_data: Dict[str, Any] = {}
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if isinstance(value, UUID):
|
if isinstance(value, UUID):
|
||||||
processed_data[key] = str(value)
|
processed_data[key] = str(value)
|
||||||
|
|
@ -90,7 +91,7 @@ def process_data_for_chroma(data):
|
||||||
return processed_data
|
return processed_data
|
||||||
|
|
||||||
|
|
||||||
def restore_data_from_chroma(data):
|
def restore_data_from_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Restore original data structure from ChromaDB storage format.
|
Restore original data structure from ChromaDB storage format.
|
||||||
|
|
||||||
|
|
@ -152,8 +153,8 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "ChromaDB"
|
name = "ChromaDB"
|
||||||
url: str
|
url: str | None
|
||||||
api_key: str
|
api_key: str | None
|
||||||
connection: AsyncHttpClient = None
|
connection: AsyncHttpClient = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -216,7 +217,9 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
collections = await self.get_collection_names()
|
collections = await self.get_collection_names()
|
||||||
return collection_name in collections
|
return collection_name in collections
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema=None):
|
async def create_collection(
|
||||||
|
self, collection_name: str, payload_schema: Optional[Any] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Create a new collection in ChromaDB if it does not already exist.
|
Create a new collection in ChromaDB if it does not already exist.
|
||||||
|
|
||||||
|
|
@ -254,7 +257,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
client = await self.get_connection()
|
client = await self.get_connection()
|
||||||
return await client.get_collection(collection_name)
|
return await client.get_collection(collection_name)
|
||||||
|
|
||||||
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
|
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
|
||||||
"""
|
"""
|
||||||
Create and upsert data points into the specified collection in ChromaDB.
|
Create and upsert data points into the specified collection in ChromaDB.
|
||||||
|
|
||||||
|
|
@ -282,7 +285,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
|
ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
|
||||||
"""
|
"""
|
||||||
Create a vector index as a ChromaDB collection based on provided names.
|
Create a vector index as a ChromaDB collection based on provided names.
|
||||||
|
|
||||||
|
|
@ -296,7 +299,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
|
|
||||||
async def index_data_points(
|
async def index_data_points(
|
||||||
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
|
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Index the provided data points based on the specified index property in ChromaDB.
|
Index the provided data points based on the specified index property in ChromaDB.
|
||||||
|
|
||||||
|
|
@ -315,10 +318,11 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
||||||
)
|
)
|
||||||
for data_point in data_points
|
for data_point in data_points
|
||||||
|
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
|
||||||
"""
|
"""
|
||||||
Retrieve data points by their IDs from a ChromaDB collection.
|
Retrieve data points by their IDs from a ChromaDB collection.
|
||||||
|
|
||||||
|
|
@ -350,12 +354,12 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_text: str = None,
|
query_text: Optional[str] = None,
|
||||||
query_vector: List[float] = None,
|
query_vector: Optional[List[float]] = None,
|
||||||
limit: int = 15,
|
limit: int = 15,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
normalized: bool = True,
|
normalized: bool = True,
|
||||||
):
|
) -> List[ScoredResult]:
|
||||||
"""
|
"""
|
||||||
Search for items in a collection using either a text or a vector query.
|
Search for items in a collection using either a text or a vector query.
|
||||||
|
|
||||||
|
|
@ -437,7 +441,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
query_texts: List[str],
|
query_texts: List[str],
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
with_vectors: bool = False,
|
with_vectors: bool = False,
|
||||||
):
|
) -> List[List[ScoredResult]]:
|
||||||
"""
|
"""
|
||||||
Perform multiple searches in a single request for efficiency, returning results for each
|
Perform multiple searches in a single request for efficiency, returning results for each
|
||||||
query.
|
query.
|
||||||
|
|
@ -507,7 +511,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
|
|
||||||
return all_results
|
return all_results
|
||||||
|
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> bool:
|
||||||
"""
|
"""
|
||||||
Remove data points from a collection based on their IDs.
|
Remove data points from a collection based on their IDs.
|
||||||
|
|
||||||
|
|
@ -528,7 +532,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
await collection.delete(ids=data_point_ids)
|
await collection.delete(ids=data_point_ids)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def prune(self):
|
async def prune(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Delete all collections in the ChromaDB database.
|
Delete all collections in the ChromaDB database.
|
||||||
|
|
||||||
|
|
@ -538,12 +542,12 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
Returns True upon successful deletion of all collections.
|
Returns True upon successful deletion of all collections.
|
||||||
"""
|
"""
|
||||||
client = await self.get_connection()
|
client = await self.get_connection()
|
||||||
collections = await self.list_collections()
|
collection_names = await self.get_collection_names()
|
||||||
for collection_name in collections:
|
for collection_name in collection_names:
|
||||||
await client.delete_collection(collection_name)
|
await client.delete_collection(collection_name)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def get_collection_names(self):
|
async def get_collection_names(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Retrieve the names of all collections in the ChromaDB database.
|
Retrieve the names of all collections in the ChromaDB database.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,25 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from os import path
|
from os import path
|
||||||
|
from uuid import UUID
|
||||||
import lancedb
|
import lancedb
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from lancedb.pydantic import LanceModel, Vector
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
from typing import (
|
||||||
|
Generic,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
get_args,
|
||||||
|
get_origin,
|
||||||
|
get_type_hints,
|
||||||
|
Dict,
|
||||||
|
Any,
|
||||||
|
)
|
||||||
|
|
||||||
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
|
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.infrastructure.engine.models.DataPoint import MetaData
|
||||||
from cognee.infrastructure.engine.utils import parse_id
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
from cognee.infrastructure.files.storage import get_file_storage
|
from cognee.infrastructure.files.storage import get_file_storage
|
||||||
from cognee.modules.storage.utils import copy_model, get_own_properties
|
from cognee.modules.storage.utils import copy_model, get_own_properties
|
||||||
|
|
@ -30,16 +43,16 @@ class IndexSchema(DataPoint):
|
||||||
to include 'text'.
|
to include 'text'.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: str
|
id: UUID
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
metadata: dict = {"index_fields": ["text"]}
|
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||||
|
|
||||||
|
|
||||||
class LanceDBAdapter(VectorDBInterface):
|
class LanceDBAdapter(VectorDBInterface):
|
||||||
name = "LanceDB"
|
name = "LanceDB"
|
||||||
url: str
|
url: Optional[str]
|
||||||
api_key: str
|
api_key: Optional[str]
|
||||||
connection: lancedb.AsyncConnection = None
|
connection: lancedb.AsyncConnection = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -53,7 +66,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
self.embedding_engine = embedding_engine
|
self.embedding_engine = embedding_engine
|
||||||
self.VECTOR_DB_LOCK = asyncio.Lock()
|
self.VECTOR_DB_LOCK = asyncio.Lock()
|
||||||
|
|
||||||
async def get_connection(self):
|
async def get_connection(self) -> lancedb.AsyncConnection:
|
||||||
"""
|
"""
|
||||||
Establishes and returns a connection to the LanceDB.
|
Establishes and returns a connection to the LanceDB.
|
||||||
|
|
||||||
|
|
@ -107,12 +120,11 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
collection_names = await connection.table_names()
|
collection_names = await connection.table_names()
|
||||||
return collection_name in collection_names
|
return collection_name in collection_names
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema: BaseModel):
|
async def create_collection(
|
||||||
|
self, collection_name: str, payload_schema: Optional[Any] = None
|
||||||
|
) -> None:
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
payload_schema = self.get_data_point_schema(payload_schema)
|
|
||||||
data_point_types = get_type_hints(payload_schema)
|
|
||||||
|
|
||||||
class LanceDataPoint(LanceModel):
|
class LanceDataPoint(LanceModel):
|
||||||
"""
|
"""
|
||||||
Represents a data point in the Lance model with an ID, vector, and associated payload.
|
Represents a data point in the Lance model with an ID, vector, and associated payload.
|
||||||
|
|
@ -123,28 +135,28 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
- payload: Additional data or metadata associated with the data point.
|
- payload: Additional data or metadata associated with the data point.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: data_point_types["id"]
|
id: UUID
|
||||||
vector: Vector(vector_size)
|
vector: Vector[vector_size]
|
||||||
payload: payload_schema
|
payload: Dict[str, Any]
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
async with self.VECTOR_DB_LOCK:
|
async with self.VECTOR_DB_LOCK:
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
connection = await self.get_connection()
|
connection = await self.get_connection()
|
||||||
return await connection.create_table(
|
await connection.create_table(
|
||||||
name=collection_name,
|
name=collection_name,
|
||||||
schema=LanceDataPoint,
|
schema=LanceDataPoint,
|
||||||
exist_ok=True,
|
exist_ok=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_collection(self, collection_name: str):
|
async def get_collection(self, collection_name: str) -> Any:
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
|
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
|
||||||
|
|
||||||
connection = await self.get_connection()
|
connection = await self.get_connection()
|
||||||
return await connection.open_table(collection_name)
|
return await connection.open_table(collection_name)
|
||||||
|
|
||||||
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
|
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
|
||||||
payload_schema = type(data_points[0])
|
payload_schema = type(data_points[0])
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
|
|
@ -175,14 +187,14 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: IdType
|
id: IdType
|
||||||
vector: Vector(vector_size)
|
vector: Vector[vector_size]
|
||||||
payload: PayloadSchema
|
payload: PayloadSchema
|
||||||
|
|
||||||
def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
|
def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any:
|
||||||
properties = get_own_properties(data_point)
|
properties = get_own_properties(data_point)
|
||||||
properties["id"] = str(properties["id"])
|
properties["id"] = str(properties["id"])
|
||||||
|
|
||||||
return LanceDataPoint[str, self.get_data_point_schema(type(data_point))](
|
return LanceDataPoint(
|
||||||
id=str(data_point.id),
|
id=str(data_point.id),
|
||||||
vector=vector,
|
vector=vector,
|
||||||
payload=properties,
|
payload=properties,
|
||||||
|
|
@ -201,7 +213,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
.execute(lance_data_points)
|
.execute(lance_data_points)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: list[str]) -> List[ScoredResult]:
|
||||||
collection = await self.get_collection(collection_name)
|
collection = await self.get_collection(collection_name)
|
||||||
|
|
||||||
if len(data_point_ids) == 1:
|
if len(data_point_ids) == 1:
|
||||||
|
|
@ -221,12 +233,12 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_text: str = None,
|
query_text: Optional[str] = None,
|
||||||
query_vector: List[float] = None,
|
query_vector: Optional[List[float]] = None,
|
||||||
limit: int = 15,
|
limit: int = 15,
|
||||||
with_vector: bool = False,
|
with_vector: bool = False,
|
||||||
normalized: bool = True,
|
normalized: bool = True,
|
||||||
):
|
) -> List[ScoredResult]:
|
||||||
if query_text is None and query_vector is None:
|
if query_text is None and query_vector is None:
|
||||||
raise MissingQueryParameterError()
|
raise MissingQueryParameterError()
|
||||||
|
|
||||||
|
|
@ -264,9 +276,9 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_texts: List[str],
|
query_texts: List[str],
|
||||||
limit: int = None,
|
limit: Optional[int] = None,
|
||||||
with_vectors: bool = False,
|
with_vectors: bool = False,
|
||||||
):
|
) -> List[List[ScoredResult]]:
|
||||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||||
|
|
||||||
return await asyncio.gather(
|
return await asyncio.gather(
|
||||||
|
|
@ -274,40 +286,41 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
self.search(
|
self.search(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
query_vector=query_vector,
|
query_vector=query_vector,
|
||||||
limit=limit,
|
limit=limit or 15,
|
||||||
with_vector=with_vectors,
|
with_vector=with_vectors,
|
||||||
)
|
)
|
||||||
for query_vector in query_vectors
|
for query_vector in query_vectors
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> None:
|
||||||
collection = await self.get_collection(collection_name)
|
collection = await self.get_collection(collection_name)
|
||||||
|
|
||||||
# Delete one at a time to avoid commit conflicts
|
# Delete one at a time to avoid commit conflicts
|
||||||
for data_point_id in data_point_ids:
|
for data_point_id in data_point_ids:
|
||||||
await collection.delete(f"id = '{data_point_id}'")
|
await collection.delete(f"id = '{data_point_id}'")
|
||||||
|
|
||||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
|
||||||
await self.create_collection(
|
await self.create_collection(
|
||||||
f"{index_name}_{index_property_name}", payload_schema=IndexSchema
|
f"{index_name}_{index_property_name}", payload_schema=IndexSchema
|
||||||
)
|
)
|
||||||
|
|
||||||
async def index_data_points(
|
async def index_data_points(
|
||||||
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
|
self, index_name: str, index_property_name: str, data_points: List[DataPoint]
|
||||||
):
|
) -> None:
|
||||||
await self.create_data_points(
|
await self.create_data_points(
|
||||||
f"{index_name}_{index_property_name}",
|
f"{index_name}_{index_property_name}",
|
||||||
[
|
[
|
||||||
IndexSchema(
|
IndexSchema(
|
||||||
id=str(data_point.id),
|
id=data_point.id,
|
||||||
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
||||||
)
|
)
|
||||||
for data_point in data_points
|
for data_point in data_points
|
||||||
|
if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def prune(self):
|
async def prune(self) -> None:
|
||||||
connection = await self.get_connection()
|
connection = await self.get_connection()
|
||||||
collection_names = await connection.table_names()
|
collection_names = await connection.table_names()
|
||||||
|
|
||||||
|
|
@ -316,12 +329,15 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
await collection.delete("id IS NOT NULL")
|
await collection.delete("id IS NOT NULL")
|
||||||
await connection.drop_table(collection_name)
|
await connection.drop_table(collection_name)
|
||||||
|
|
||||||
if self.url.startswith("/"):
|
if self.url and self.url.startswith("/"):
|
||||||
db_dir_path = path.dirname(self.url)
|
db_dir_path = path.dirname(self.url)
|
||||||
db_file_name = path.basename(self.url)
|
db_file_name = path.basename(self.url)
|
||||||
await get_file_storage(db_dir_path).remove_all(db_file_name)
|
await get_file_storage(db_dir_path).remove_all(db_file_name)
|
||||||
|
|
||||||
def get_data_point_schema(self, model_type: BaseModel):
|
def get_data_point_schema(self, model_type: Optional[Any]) -> Any:
|
||||||
|
if model_type is None:
|
||||||
|
return DataPoint
|
||||||
|
|
||||||
related_models_fields = []
|
related_models_fields = []
|
||||||
|
|
||||||
for field_name, field_config in model_type.model_fields.items():
|
for field_name, field_config in model_type.model_fields.items():
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
@ -14,8 +14,10 @@ class ScoredResult(BaseModel):
|
||||||
better outcome.
|
better outcome.
|
||||||
- payload (Dict[str, Any]): Additional information related to the score, stored as
|
- payload (Dict[str, Any]): Additional information related to the score, stored as
|
||||||
key-value pairs in a dictionary.
|
key-value pairs in a dictionary.
|
||||||
|
- vector (Optional[List[float]]): Optional vector embedding associated with the result.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: UUID
|
id: UUID
|
||||||
score: float # Lower score is better
|
score: float # Lower score is better
|
||||||
payload: Dict[str, Any]
|
payload: Dict[str, Any]
|
||||||
|
vector: Optional[List[float]] = None
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Optional, get_type_hints
|
from typing import List, Optional, get_type_hints, Dict, Any
|
||||||
from sqlalchemy.inspection import inspect
|
from sqlalchemy.inspection import inspect
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
from sqlalchemy.dialects.postgresql import insert
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
|
from sqlalchemy import JSON, Table, select, delete, MetaData
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
from sqlalchemy.exc import ProgrammingError
|
from sqlalchemy.exc import ProgrammingError
|
||||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
||||||
|
|
@ -12,6 +12,7 @@ from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationE
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
from cognee.infrastructure.engine.models.DataPoint import MetaData as DataPointMetaData
|
||||||
from cognee.infrastructure.engine.utils import parse_id
|
from cognee.infrastructure.engine.utils import parse_id
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
|
||||||
|
|
@ -42,7 +43,7 @@ class IndexSchema(DataPoint):
|
||||||
|
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
metadata: dict = {"index_fields": ["text"]}
|
metadata: DataPointMetaData = {"index_fields": ["text"], "type": "IndexSchema"}
|
||||||
|
|
||||||
|
|
||||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
@ -122,8 +123,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
stop=stop_after_attempt(5),
|
stop=stop_after_attempt(5),
|
||||||
wait=wait_exponential(multiplier=2, min=1, max=6),
|
wait=wait_exponential(multiplier=2, min=1, max=6),
|
||||||
)
|
)
|
||||||
async def create_collection(self, collection_name: str, payload_schema=None):
|
async def create_collection(
|
||||||
data_point_types = get_type_hints(DataPoint)
|
self, collection_name: str, payload_schema: Optional[Any] = None
|
||||||
|
) -> None:
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
async with self.VECTOR_DB_LOCK:
|
async with self.VECTOR_DB_LOCK:
|
||||||
|
|
@ -147,29 +149,31 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
__tablename__ = collection_name
|
__tablename__ = collection_name
|
||||||
__table_args__ = {"extend_existing": True}
|
__table_args__ = {"extend_existing": True}
|
||||||
# PGVector requires one column to be the primary key
|
# PGVector requires one column to be the primary key
|
||||||
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
id: Mapped[str] = mapped_column(primary_key=True)
|
||||||
payload = Column(JSON)
|
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
|
||||||
vector = Column(self.Vector(vector_size))
|
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
|
||||||
|
|
||||||
def __init__(self, id, payload, vector):
|
def __init__(
|
||||||
|
self, id: str, payload: Dict[str, Any], vector: List[float]
|
||||||
|
) -> None:
|
||||||
self.id = id
|
self.id = id
|
||||||
self.payload = payload
|
self.payload = payload
|
||||||
self.vector = vector
|
self.vector = vector
|
||||||
|
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
if len(Base.metadata.tables.keys()) > 0:
|
if len(Base.metadata.tables.keys()) > 0:
|
||||||
await connection.run_sync(
|
from sqlalchemy import Table
|
||||||
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
|
|
||||||
)
|
table: Table = PGVectorDataPoint.__table__ # type: ignore
|
||||||
|
await connection.run_sync(Base.metadata.create_all, tables=[table])
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
retry=retry_if_exception_type(DeadlockDetectedError),
|
retry=retry_if_exception_type(DeadlockDetectedError),
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=2, min=1, max=6),
|
wait=wait_exponential(multiplier=2, min=1, max=6),
|
||||||
)
|
)
|
||||||
@override_distributed(queued_add_data_points)
|
@override_distributed(queued_add_data_points) # type: ignore
|
||||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
|
||||||
data_point_types = get_type_hints(DataPoint)
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
await self.create_collection(
|
await self.create_collection(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
|
|
@ -196,11 +200,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
__tablename__ = collection_name
|
__tablename__ = collection_name
|
||||||
__table_args__ = {"extend_existing": True}
|
__table_args__ = {"extend_existing": True}
|
||||||
# PGVector requires one column to be the primary key
|
# PGVector requires one column to be the primary key
|
||||||
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
id: Mapped[str] = mapped_column(primary_key=True)
|
||||||
payload = Column(JSON)
|
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
|
||||||
vector = Column(self.Vector(vector_size))
|
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
|
||||||
|
|
||||||
def __init__(self, id, payload, vector):
|
def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
|
||||||
self.id = id
|
self.id = id
|
||||||
self.payload = payload
|
self.payload = payload
|
||||||
self.vector = vector
|
self.vector = vector
|
||||||
|
|
@ -225,13 +229,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
# else:
|
# else:
|
||||||
pgvector_data_points.append(
|
pgvector_data_points.append(
|
||||||
PGVectorDataPoint(
|
PGVectorDataPoint(
|
||||||
id=data_point.id,
|
id=str(data_point.id),
|
||||||
vector=data_vectors[data_index],
|
vector=data_vectors[data_index],
|
||||||
payload=serialize_data(data_point.model_dump()),
|
payload=serialize_data(data_point.model_dump()),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(obj):
|
def to_dict(obj: Any) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
column.key: getattr(obj, column.key)
|
column.key: getattr(obj, column.key)
|
||||||
for column in inspect(obj).mapper.column_attrs
|
for column in inspect(obj).mapper.column_attrs
|
||||||
|
|
@ -245,12 +249,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
await session.execute(insert_statement)
|
await session.execute(insert_statement)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def create_vector_index(self, index_name: str, index_property_name: str):
|
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
|
||||||
await self.create_collection(f"{index_name}_{index_property_name}")
|
await self.create_collection(f"{index_name}_{index_property_name}")
|
||||||
|
|
||||||
async def index_data_points(
|
async def index_data_points(
|
||||||
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
|
self, index_name: str, index_property_name: str, data_points: List[DataPoint]
|
||||||
):
|
) -> None:
|
||||||
await self.create_data_points(
|
await self.create_data_points(
|
||||||
f"{index_name}_{index_property_name}",
|
f"{index_name}_{index_property_name}",
|
||||||
[
|
[
|
||||||
|
|
@ -262,11 +266,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_table(self, collection_name: str) -> Table:
|
async def get_table(self, table_name: str, schema_name: Optional[str] = None) -> Table:
|
||||||
"""
|
"""
|
||||||
Dynamically loads a table using the given collection name
|
Dynamically loads a table using the given table name
|
||||||
with an async engine.
|
with an async engine. Schema parameter is ignored for vector collections.
|
||||||
"""
|
"""
|
||||||
|
collection_name = table_name
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
# Create a MetaData instance to load table information
|
# Create a MetaData instance to load table information
|
||||||
metadata = MetaData()
|
metadata = MetaData()
|
||||||
|
|
@ -279,15 +284,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
f"Collection '{collection_name}' not found!",
|
f"Collection '{collection_name}' not found!",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
|
||||||
# Get PGVectorDataPoint Table from database
|
# Get PGVectorDataPoint Table from database
|
||||||
PGVectorDataPoint = await self.get_table(collection_name)
|
PGVectorDataPoint = await self.get_table(collection_name)
|
||||||
|
|
||||||
async with self.get_async_session() as session:
|
async with self.get_async_session() as session:
|
||||||
results = await session.execute(
|
query_result = await session.execute(
|
||||||
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
|
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
|
||||||
)
|
)
|
||||||
results = results.all()
|
results = query_result.all()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)
|
ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)
|
||||||
|
|
@ -311,9 +316,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
# Get PGVectorDataPoint Table from database
|
# Get PGVectorDataPoint Table from database
|
||||||
PGVectorDataPoint = await self.get_table(collection_name)
|
PGVectorDataPoint = await self.get_table(collection_name)
|
||||||
|
|
||||||
# NOTE: This needs to be initialized in case search doesn't return a value
|
|
||||||
closest_items = []
|
|
||||||
|
|
||||||
# Use async session to connect to the database
|
# Use async session to connect to the database
|
||||||
async with self.get_async_session() as session:
|
async with self.get_async_session() as session:
|
||||||
query = select(
|
query = select(
|
||||||
|
|
@ -325,12 +327,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
query = query.limit(limit)
|
query = query.limit(limit)
|
||||||
|
|
||||||
# Find closest vectors to query_vector
|
# Find closest vectors to query_vector
|
||||||
closest_items = await session.execute(query)
|
query_results = await session.execute(query)
|
||||||
|
|
||||||
vector_list = []
|
vector_list = []
|
||||||
|
|
||||||
# Extract distances and find min/max for normalization
|
# Extract distances and find min/max for normalization
|
||||||
for vector in closest_items.all():
|
for vector in query_results.all():
|
||||||
vector_list.append(
|
vector_list.append(
|
||||||
{
|
{
|
||||||
"id": parse_id(str(vector.id)),
|
"id": parse_id(str(vector.id)),
|
||||||
|
|
@ -349,7 +351,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
# Create and return ScoredResult objects
|
# Create and return ScoredResult objects
|
||||||
return [
|
return [
|
||||||
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
|
ScoredResult(id=row["id"], payload=row["payload"] or {}, score=row["score"])
|
||||||
for row in vector_list
|
for row in vector_list
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -357,9 +359,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_texts: List[str],
|
query_texts: List[str],
|
||||||
limit: int = None,
|
limit: Optional[int] = None,
|
||||||
with_vectors: bool = False,
|
with_vectors: bool = False,
|
||||||
):
|
) -> List[List[ScoredResult]]:
|
||||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||||
|
|
||||||
return await asyncio.gather(
|
return await asyncio.gather(
|
||||||
|
|
@ -367,14 +369,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
self.search(
|
self.search(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
query_vector=query_vector,
|
query_vector=query_vector,
|
||||||
limit=limit,
|
limit=limit or 15,
|
||||||
with_vector=with_vectors,
|
with_vector=with_vectors,
|
||||||
)
|
)
|
||||||
for query_vector in query_vectors
|
for query_vector in query_vectors
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> Any:
|
||||||
async with self.get_async_session() as session:
|
async with self.get_async_session() as session:
|
||||||
# Get PGVectorDataPoint Table from database
|
# Get PGVectorDataPoint Table from database
|
||||||
PGVectorDataPoint = await self.get_table(collection_name)
|
PGVectorDataPoint = await self.get_table(collection_name)
|
||||||
|
|
@ -384,6 +386,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def prune(self):
|
async def prune(self) -> None:
|
||||||
# Clean up the database if it was set up as temporary
|
# Clean up the database if it was set up as temporary
|
||||||
await self.delete_database()
|
await self.delete_database()
|
||||||
|
|
|
||||||
|
|
@ -1,109 +0,0 @@
|
||||||
import os
|
|
||||||
|
|
||||||
import pathlib
|
|
||||||
import cognee
|
|
||||||
from cognee.infrastructure.files.storage import get_storage_config
|
|
||||||
from cognee.modules.search.operations import get_history
|
|
||||||
from cognee.modules.users.methods import get_default_user
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
|
||||||
from cognee.modules.search.types import SearchType
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
cognee.config.set_graph_database_provider("memgraph")
|
|
||||||
data_directory_path = str(
|
|
||||||
pathlib.Path(
|
|
||||||
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_memgraph")
|
|
||||||
).resolve()
|
|
||||||
)
|
|
||||||
cognee.config.data_root_directory(data_directory_path)
|
|
||||||
cognee_directory_path = str(
|
|
||||||
pathlib.Path(
|
|
||||||
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_memgraph")
|
|
||||||
).resolve()
|
|
||||||
)
|
|
||||||
cognee.config.system_root_directory(cognee_directory_path)
|
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
|
||||||
await cognee.prune.prune_system(metadata=True)
|
|
||||||
|
|
||||||
dataset_name = "cs_explanations"
|
|
||||||
|
|
||||||
explanation_file_path = os.path.join(
|
|
||||||
pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
|
|
||||||
)
|
|
||||||
await cognee.add([explanation_file_path], dataset_name)
|
|
||||||
|
|
||||||
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
|
|
||||||
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
|
|
||||||
Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible.
|
|
||||||
The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly.
|
|
||||||
Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate.
|
|
||||||
In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
|
|
||||||
"""
|
|
||||||
|
|
||||||
await cognee.add([text], dataset_name)
|
|
||||||
|
|
||||||
await cognee.cognify([dataset_name])
|
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
|
||||||
|
|
||||||
vector_engine = get_vector_engine()
|
|
||||||
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
|
|
||||||
random_node_name = random_node.payload["text"]
|
|
||||||
|
|
||||||
search_results = await cognee.search(
|
|
||||||
query_type=SearchType.INSIGHTS, query_text=random_node_name
|
|
||||||
)
|
|
||||||
assert len(search_results) != 0, "The search results list is empty."
|
|
||||||
print("\n\nExtracted sentences are:\n")
|
|
||||||
for result in search_results:
|
|
||||||
print(f"{result}\n")
|
|
||||||
|
|
||||||
search_results = await cognee.search(query_type=SearchType.CHUNKS, query_text=random_node_name)
|
|
||||||
assert len(search_results) != 0, "The search results list is empty."
|
|
||||||
print("\n\nExtracted chunks are:\n")
|
|
||||||
for result in search_results:
|
|
||||||
print(f"{result}\n")
|
|
||||||
|
|
||||||
search_results = await cognee.search(
|
|
||||||
query_type=SearchType.SUMMARIES, query_text=random_node_name
|
|
||||||
)
|
|
||||||
assert len(search_results) != 0, "Query related summaries don't exist."
|
|
||||||
print("\nExtracted results are:\n")
|
|
||||||
for result in search_results:
|
|
||||||
print(f"{result}\n")
|
|
||||||
|
|
||||||
search_results = await cognee.search(
|
|
||||||
query_type=SearchType.NATURAL_LANGUAGE,
|
|
||||||
query_text=f"Find nodes connected to node with name {random_node_name}",
|
|
||||||
)
|
|
||||||
assert len(search_results) != 0, "Query related natural language don't exist."
|
|
||||||
print("\nExtracted results are:\n")
|
|
||||||
for result in search_results:
|
|
||||||
print(f"{result}\n")
|
|
||||||
|
|
||||||
user = await get_default_user()
|
|
||||||
history = await get_history(user.id)
|
|
||||||
|
|
||||||
assert len(history) == 8, "Search history is not correct."
|
|
||||||
|
|
||||||
await cognee.prune.prune_data()
|
|
||||||
data_root_directory = get_storage_config()["data_root_directory"]
|
|
||||||
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
|
|
||||||
|
|
||||||
await cognee.prune.prune_system(metadata=True)
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
|
||||||
|
|
||||||
graph_engine = await get_graph_engine()
|
|
||||||
nodes, edges = await graph_engine.get_graph_data()
|
|
||||||
assert len(nodes) == 0 and len(edges) == 0, "Memgraph graph database is not empty"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Script entry point: drive the async test coroutine to completion.
    from asyncio import run

    run(main())
|
|
||||||
26
mypy.ini
26
mypy.ini
|
|
@ -1,7 +1,7 @@
|
||||||
[mypy]
|
[mypy]
|
||||||
python_version=3.8
|
python_version=3.10
|
||||||
ignore_missing_imports=false
|
ignore_missing_imports=false
|
||||||
strict_optional=false
|
strict_optional=true
|
||||||
warn_redundant_casts=true
|
warn_redundant_casts=true
|
||||||
disallow_any_generics=true
|
disallow_any_generics=true
|
||||||
disallow_untyped_defs=true
|
disallow_untyped_defs=true
|
||||||
|
|
@ -10,6 +10,12 @@ warn_return_any=true
|
||||||
namespace_packages=true
|
namespace_packages=true
|
||||||
warn_unused_ignores=true
|
warn_unused_ignores=true
|
||||||
show_error_codes=true
|
show_error_codes=true
|
||||||
|
disallow_incomplete_defs=true
|
||||||
|
disallow_untyped_decorators=true
|
||||||
|
no_implicit_optional=true
|
||||||
|
warn_unreachable=true
|
||||||
|
warn_no_return=true
|
||||||
|
warn_unused_configs=true
|
||||||
#exclude=reflection/module_cases/*
|
#exclude=reflection/module_cases/*
|
||||||
exclude=docs/examples/archive/*|tests/reflection/module_cases/*
|
exclude=docs/examples/archive/*|tests/reflection/module_cases/*
|
||||||
|
|
||||||
|
|
@ -18,6 +24,22 @@ disallow_untyped_defs=false
|
||||||
warn_return_any=false
|
warn_return_any=false
|
||||||
|
|
||||||
|
|
||||||
|
[mypy-cognee.infrastructure.databases.*]
|
||||||
|
ignore_missing_imports=true
|
||||||
|
|
||||||
|
# Third-party database libraries that lack type stubs
|
||||||
|
[mypy-chromadb.*]
|
||||||
|
ignore_missing_imports=true
|
||||||
|
|
||||||
|
[mypy-lancedb.*]
|
||||||
|
ignore_missing_imports=true
|
||||||
|
|
||||||
|
[mypy-asyncpg.*]
|
||||||
|
ignore_missing_imports=true
|
||||||
|
|
||||||
|
[mypy-pgvector.*]
|
||||||
|
ignore_missing_imports=true
|
||||||
|
|
||||||
[mypy-docs.*]
|
[mypy-docs.*]
|
||||||
disallow_untyped_defs=false
|
disallow_untyped_defs=false
|
||||||
|
|
||||||
|
|
|
||||||
82
notebooks/neptune-analytics-example.ipynb
vendored
82
notebooks/neptune-analytics-example.ipynb
vendored
|
|
@ -83,16 +83,16 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import pathlib\n",
|
"import pathlib\n",
|
||||||
"from cognee import config, add, cognify, search, SearchType, prune, visualize_graph\n",
|
"from cognee import config, add, cognify, search, SearchType, prune, visualize_graph\n",
|
||||||
"from dotenv import load_dotenv"
|
"from dotenv import load_dotenv"
|
||||||
],
|
]
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
|
@ -106,7 +106,9 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# load environment variables from file .env\n",
|
"# load environment variables from file .env\n",
|
||||||
"load_dotenv()\n",
|
"load_dotenv()\n",
|
||||||
|
|
@ -145,9 +147,7 @@
|
||||||
" \"vector_db_url\": f\"neptune-graph://{graph_identifier}\", # Neptune Analytics endpoint with the format neptune-graph://<GRAPH_ID>\n",
|
" \"vector_db_url\": f\"neptune-graph://{graph_identifier}\", # Neptune Analytics endpoint with the format neptune-graph://<GRAPH_ID>\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
")"
|
")"
|
||||||
],
|
]
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
|
@ -159,19 +159,19 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Prune data and system metadata before running, only if we want \"fresh\" state.\n",
|
"# Prune data and system metadata before running, only if we want \"fresh\" state.\n",
|
||||||
"await prune.prune_data()\n",
|
"await prune.prune_data()\n",
|
||||||
"await prune.prune_system(metadata=True)"
|
"await prune.prune_system(metadata=True)"
|
||||||
],
|
]
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Setup data and cognify\n",
|
"## Setup data and cognify\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -180,7 +180,9 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Add sample text to the dataset\n",
|
"# Add sample text to the dataset\n",
|
||||||
"sample_text_1 = \"\"\"Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune\n",
|
"sample_text_1 = \"\"\"Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune\n",
|
||||||
|
|
@ -205,9 +207,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"# Cognify the text data.\n",
|
"# Cognify the text data.\n",
|
||||||
"await cognify([dataset_name])"
|
"await cognify([dataset_name])"
|
||||||
],
|
]
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
|
@ -215,14 +215,16 @@
|
||||||
"source": [
|
"source": [
|
||||||
"## Graph Memory visualization\n",
|
"## Graph Memory visualization\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Initialize Memgraph as a Graph Memory store and save to .artefacts/graph_visualization.html\n",
|
"Initialize a Graph Memory store and save to .artefacts/graph_visualization.html\n",
|
||||||
"\n",
|
"\n",
|
||||||
""
|
""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Get a graphistry url (Register for a free account at https://www.graphistry.com)\n",
|
"# Get a graphistry url (Register for a free account at https://www.graphistry.com)\n",
|
||||||
"# url = await render_graph()\n",
|
"# url = await render_graph()\n",
|
||||||
|
|
@ -235,9 +237,7 @@
|
||||||
" ).resolve()\n",
|
" ).resolve()\n",
|
||||||
")\n",
|
")\n",
|
||||||
"await visualize_graph(graph_file_path)"
|
"await visualize_graph(graph_file_path)"
|
||||||
],
|
]
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
|
@ -250,19 +250,19 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Completion query that uses graph data to form context.\n",
|
"# Completion query that uses graph data to form context.\n",
|
||||||
"graph_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.GRAPH_COMPLETION)\n",
|
"graph_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.GRAPH_COMPLETION)\n",
|
||||||
"print(\"\\nGraph completion result is:\")\n",
|
"print(\"\\nGraph completion result is:\")\n",
|
||||||
"print(graph_completion)"
|
"print(graph_completion)"
|
||||||
],
|
]
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## SEARCH: RAG Completion\n",
|
"## SEARCH: RAG Completion\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -271,19 +271,19 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Completion query that uses document chunks to form context.\n",
|
"# Completion query that uses document chunks to form context.\n",
|
||||||
"rag_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.RAG_COMPLETION)\n",
|
"rag_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.RAG_COMPLETION)\n",
|
||||||
"print(\"\\nRAG Completion result is:\")\n",
|
"print(\"\\nRAG Completion result is:\")\n",
|
||||||
"print(rag_completion)"
|
"print(rag_completion)"
|
||||||
],
|
]
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## SEARCH: Graph Insights\n",
|
"## SEARCH: Graph Insights\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -291,8 +291,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Search graph insights\n",
|
"# Search graph insights\n",
|
||||||
"insights_results = await search(query_text=\"Neptune Analytics\", query_type=SearchType.INSIGHTS)\n",
|
"insights_results = await search(query_text=\"Neptune Analytics\", query_type=SearchType.INSIGHTS)\n",
|
||||||
|
|
@ -302,13 +304,11 @@
|
||||||
" tgt_node = result[2].get(\"name\", result[2][\"type\"])\n",
|
" tgt_node = result[2].get(\"name\", result[2][\"type\"])\n",
|
||||||
" relationship = result[1].get(\"relationship_name\", \"__relationship__\")\n",
|
" relationship = result[1].get(\"relationship_name\", \"__relationship__\")\n",
|
||||||
" print(f\"- {src_node} -[{relationship}]-> {tgt_node}\")"
|
" print(f\"- {src_node} -[{relationship}]-> {tgt_node}\")"
|
||||||
],
|
]
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## SEARCH: Entity Summaries\n",
|
"## SEARCH: Entity Summaries\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -316,8 +316,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Query all summaries related to query.\n",
|
"# Query all summaries related to query.\n",
|
||||||
"summaries = await search(query_text=\"Neptune Analytics\", query_type=SearchType.SUMMARIES)\n",
|
"summaries = await search(query_text=\"Neptune Analytics\", query_type=SearchType.SUMMARIES)\n",
|
||||||
|
|
@ -326,13 +328,11 @@
|
||||||
" type = summary[\"type\"]\n",
|
" type = summary[\"type\"]\n",
|
||||||
" text = summary[\"text\"]\n",
|
" text = summary[\"text\"]\n",
|
||||||
" print(f\"- {type}: {text}\")"
|
" print(f\"- {type}: {text}\")"
|
||||||
],
|
]
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## SEARCH: Chunks\n",
|
"## SEARCH: Chunks\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
@ -340,8 +340,10 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"metadata": {},
|
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"chunks = await search(query_text=\"Neptune Analytics\", query_type=SearchType.CHUNKS)\n",
|
"chunks = await search(query_text=\"Neptune Analytics\", query_type=SearchType.CHUNKS)\n",
|
||||||
"print(\"\\nChunk results are:\")\n",
|
"print(\"\\nChunk results are:\")\n",
|
||||||
|
|
@ -349,9 +351,7 @@
|
||||||
" type = chunk[\"type\"]\n",
|
" type = chunk[\"type\"]\n",
|
||||||
" text = chunk[\"text\"]\n",
|
" text = chunk[\"text\"]\n",
|
||||||
" print(f\"- {type}: {text}\")"
|
" print(f\"- {type}: {text}\")"
|
||||||
],
|
]
|
||||||
"outputs": [],
|
|
||||||
"execution_count": null
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
|
||||||
41
tools/check_all_adapters.sh
Executable file
41
tools/check_all_adapters.sh
Executable file
|
|
@ -0,0 +1,41 @@
|
||||||
|
#!/bin/bash

# All Database Adapters MyPy Check Script
# Runs the vector, graph, and hybrid adapter checks in sequence.

set -e  # Abort on the first failing check

echo "🚀 Running MyPy checks on all database adapters..."
echo ""

# Always operate from the project root, regardless of invocation directory.
cd "$(dirname "$0")/.."

# Section banners and their corresponding check scripts, index-aligned.
banners=(
    "1️⃣ VECTOR DATABASE ADAPTERS"
    "2️⃣ GRAPH DATABASE ADAPTERS"
    "3️⃣ HYBRID DATABASE ADAPTERS"
)
checks=(
    "./tools/check_vector_adapters.sh"
    "./tools/check_graph_adapters.sh"
    "./tools/check_hybrid_adapters.sh"
)

for idx in "${!banners[@]}"; do
    # A blank line separates consecutive sections (none before the first).
    if [ "$idx" -gt 0 ]; then
        echo ""
    fi
    echo "========================================="
    echo "${banners[$idx]}"
    echo "========================================="
    "${checks[$idx]}"
done

echo ""
echo "🎉 All Database Adapters MyPy Checks Complete!"
echo ""
echo "🔍 Auto-Discovery Approach:"
echo " • Vector Adapters: cognee/infrastructure/databases/vector/**/*Adapter.py"
echo " • Graph Adapters: cognee/infrastructure/databases/graph/**/*adapter.py"
echo " • Hybrid Adapters: cognee/infrastructure/databases/hybrid/**/*Adapter.py"
echo ""
echo "🎯 Purpose: Enforce that database adapters are properly typed"
echo "🔧 MyPy Configuration: mypy.ini (strict mode enabled)"
echo "🚀 Maintenance-Free: Automatically discovers new adapters"
|
||||||
41
tools/check_graph_adapters.sh
Executable file
41
tools/check_graph_adapters.sh
Executable file
|
|
@ -0,0 +1,41 @@
|
||||||
|
#!/bin/bash

# Graph Database Adapters MyPy Check Script
# Auto-discovers graph adapter modules and type-checks each one with MyPy.

set -e  # Exit on any error

echo "🔍 Discovering Graph Database Adapters..."

# Ensure we're in the project root directory
cd "$(dirname "$0")/.."

# Activate virtual environment
source .venv/bin/activate

# Find all *adapter.py files in graph database directories, excluding utility files.
# -type f restricts matches to regular files (consistent with the vector/hybrid
# check scripts); the previous "-name '*adapter.py' -o -name 'adapter.py'" was
# redundant (the glob already matches adapter.py) and left -type f unappliable
# without parentheses.
graph_adapters=$(find cognee/infrastructure/databases/graph -name "*adapter.py" -type f | grep -v "use_graph_adapter.py" | sort)

if [ -z "$graph_adapters" ]; then
    # No adapters is a success, not a failure.
    echo "No graph database adapters found"
    exit 0
else
    echo "Found graph database adapters:"
    echo "$graph_adapters" | sed 's/^/ • /'
    echo ""

    echo "Running MyPy on graph database adapters..."

    # Use while read to properly handle each file (paths are one per line).
    echo "$graph_adapters" | while read -r adapter; do
        if [ -n "$adapter" ]; then
            echo "Checking: $adapter"
            uv run mypy "$adapter" \
                --config-file mypy.ini \
                --show-error-codes \
                --no-error-summary
            echo ""
        fi
    done
fi

echo "✅ Graph Database Adapters MyPy Check Complete!"
|
||||||
41
tools/check_hybrid_adapters.sh
Executable file
41
tools/check_hybrid_adapters.sh
Executable file
|
|
@ -0,0 +1,41 @@
|
||||||
|
#!/bin/bash

# Hybrid Database Adapters MyPy Check Script
# Locates every *Adapter.py under the hybrid database tree and type-checks it.

set -e  # Exit on any error

echo "🔍 Discovering Hybrid Database Adapters..."

# Run from the project root no matter where the script was invoked.
cd "$(dirname "$0")/.."

# Activate virtual environment
source .venv/bin/activate

# Auto-discover hybrid adapter modules.
hybrid_adapters=$(find cognee/infrastructure/databases/hybrid -name "*Adapter.py" -type f | sort)

# Guard clause: an empty discovery list is a clean no-op.
if [ -z "$hybrid_adapters" ]; then
    echo "No hybrid database adapters found"
    exit 0
fi

echo "Found hybrid database adapters:"
echo "$hybrid_adapters" | sed 's/^/ • /'
echo ""

echo "Running MyPy on hybrid database adapters..."

# Here-string feeds the list line by line so each path is handled individually.
while read -r adapter; do
    if [ -n "$adapter" ]; then
        echo "Checking: $adapter"
        uv run mypy "$adapter" --config-file mypy.ini --show-error-codes --no-error-summary
        echo ""
    fi
done <<< "$hybrid_adapters"

echo "✅ Hybrid Database Adapters MyPy Check Complete!"
|
||||||
41
tools/check_vector_adapters.sh
Executable file
41
tools/check_vector_adapters.sh
Executable file
|
|
@ -0,0 +1,41 @@
|
||||||
|
#!/bin/bash

# Vector Database Adapters MyPy Check Script
# Discovers every *Adapter.py under the vector database tree and runs MyPy on it.

set -e  # Exit on any error

echo "🔍 Discovering Vector Database Adapters..."

# Make all relative paths resolve against the project root.
cd "$(dirname "$0")/.."

# Activate virtual environment
source .venv/bin/activate

# Auto-discover vector adapter modules.
vector_adapters=$(find cognee/infrastructure/databases/vector -name "*Adapter.py" -type f | sort)

# Guard clause: nothing discovered means nothing to check — exit successfully.
if [ -z "$vector_adapters" ]; then
    echo "No vector database adapters found"
    exit 0
fi

echo "Found vector database adapters:"
echo "$vector_adapters" | sed 's/^/ • /'
echo ""

echo "Running MyPy on vector database adapters..."

# Read line by line so each discovered path is checked individually.
while read -r adapter; do
    if [ -n "$adapter" ]; then
        echo "Checking: $adapter"
        uv run mypy "$adapter" --config-file mypy.ini --show-error-codes --no-error-summary
        echo ""
    fi
done <<< "$vector_adapters"

echo "✅ Vector Database Adapters MyPy Check Complete!"
|
||||||
Loading…
Add table
Reference in a new issue