Compare commits

...
Sign in to create a new pull request.

20 commits

Author SHA1 Message Date
Daulet Amirkhanov
a3d2e7bd2a ruff format 2025-09-07 20:51:24 +01:00
Daulet Amirkhanov
ed83f76f25 clean up todos from lancedb 2025-09-07 20:48:58 +01:00
Daulet Amirkhanov
3a65160839 move check adapters scripts to /tools and update mypy workflow 2025-09-07 20:45:17 +01:00
Daulet Amirkhanov
1d282f3f83 ruff check fix 2025-09-07 20:39:39 +01:00
Daulet Amirkhanov
1a5914e0ab ruff format 2025-09-07 20:39:39 +01:00
Daulet Amirkhanov
d2a2ade643 mypy fix: Fix ChromaDBAdapter mypy errors 2025-09-07 20:39:39 +01:00
Daulet Amirkhanov
0fb962e29a mypy: version Neo4j adapter 2025-09-07 20:39:13 +01:00
Daulet Amirkhanov
b9cd847e9d Remove Memgraph and references to it 2025-09-07 20:38:37 +01:00
Daulet Amirkhanov
aa74672aeb mypy: fix RemoteKuzuAdapter mypy errors 2025-09-07 20:38:26 +01:00
Daulet Amirkhanov
9a2bf0f137 kuzu - improve type inference for connection 2025-09-07 20:38:14 +01:00
Daulet Amirkhanov
b59087841b mypy: first fix KuzuAdapter mypy errors 2025-09-07 20:38:14 +01:00
Daulet Amirkhanov
deaf3debbf mypy: fix PGVectorAdapter mypy errors 2025-09-07 20:37:39 +01:00
Daulet Amirkhanov
eebca89855 mypy: fix LanceDBAdapter mypy errors 2025-09-07 20:37:09 +01:00
Daulet Amirkhanov
4ae41fede3 mypy fix: Fix ChromaDBAdapter mypy errors 2025-09-03 18:10:01 +01:00
Daulet Amirkhanov
26f5ab4f0f mypy: ignore missing imports for third party adapter libraries 2025-09-03 17:34:20 +01:00
Daulet Amirkhanov
baffd9187e adding temporary mypy scripts 2025-09-03 17:33:59 +01:00
Daulet Amirkhanov
85f37a6ee5 make protocols_mypy workflow manually dispatchable 2025-09-03 17:03:40 +01:00
Daulet Amirkhanov
aa686cefe8 Potential fix for code scanning alert no. 150: Workflow does not contain permissions
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-09-03 16:59:09 +01:00
Daulet Amirkhanov
636d38c018 refactor: remove old MyPy workflow and add new database adapter MyPy check workflow 2025-09-03 16:59:09 +01:00
Daulet Amirkhanov
7d80701381 chore: update mypy and create a GitHub workflow 2025-09-03 16:59:09 +01:00
18 changed files with 648 additions and 1524 deletions

View file

@ -0,0 +1,76 @@
permissions:
contents: read
name: Database Adapter MyPy Check
on:
workflow_dispatch:
push:
branches: [ main, dev ]
paths:
- 'cognee/infrastructure/databases/**'
- 'tools/check_*_adapters.sh'
- 'mypy.ini'
- '.github/workflows/database_protocol_mypy_check.yml'
pull_request:
branches: [ main, dev ]
paths:
- 'cognee/infrastructure/databases/**'
- 'tools/check_*_adapters.sh'
- 'mypy.ini'
- '.github/workflows/database_protocol_mypy_check.yml'
env:
RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'
jobs:
mypy-database-adapters:
name: MyPy Database Adapter Type Check
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Discover and Check Vector Database Adapters
run: ./tools/check_vector_adapters.sh
- name: Discover and Check Graph Database Adapters
run: ./tools/check_graph_adapters.sh
- name: Discover and Check Hybrid Database Adapters
run: ./tools/check_hybrid_adapters.sh
- name: Protocol Compliance Summary
run: |
echo "✅ Database Adapter MyPy Check Passed!"
echo ""
echo "🔍 Auto-Discovery Approach:"
echo " • Vector Adapters: cognee/infrastructure/databases/vector/**/*Adapter.py"
echo " • Graph Adapters: cognee/infrastructure/databases/graph/**/*adapter.py"
echo " • Hybrid Adapters: cognee/infrastructure/databases/hybrid/**/*Adapter.py"
echo ""
echo "🚀 Using Dedicated Scripts:"
echo " • Vector: ./tools/check_vector_adapters.sh"
echo " • Graph: ./tools/check_graph_adapters.sh"
echo " • Hybrid: ./tools/check_hybrid_adapters.sh"
echo " • All: ./tools/check_all_adapters.sh"
echo ""
echo "🎯 Purpose: Enforce that database adapters are properly typed"
echo "🔧 MyPy Configuration: mypy.ini (strict mode enabled)"
echo "🚀 Maintenance-Free: Automatically discovers new adapters"
echo ""
echo "⚠️ This workflow FAILS on any type errors to ensure adapter quality."
echo " All database adapters must be properly typed."
echo ""
echo "🛠️ To fix type issues locally, run:"
echo " ./tools/check_all_adapters.sh # Check all adapters"
echo " ./tools/check_vector_adapters.sh # Check vector adapters only"
echo " mypy <adapter_file_path> --config-file mypy.ini # Check specific file"

View file

@ -118,7 +118,7 @@ class HealthChecker:
# Test basic operation with actual graph query # Test basic operation with actual graph query
if hasattr(engine, "execute"): if hasattr(engine, "execute"):
# For SQL-like graph DBs (Neo4j, Memgraph) # For SQL-like graph DBs (Neo4j)
await engine.execute("MATCH () RETURN count(*) LIMIT 1") await engine.execute("MATCH () RETURN count(*) LIMIT 1")
elif hasattr(engine, "query"): elif hasattr(engine, "query"):
# For other graph engines # For other graph engines

View file

@ -179,5 +179,5 @@ def create_graph_engine(
raise EnvironmentError( raise EnvironmentError(
f"Unsupported graph database provider: {graph_database_provider}. " f"Unsupported graph database provider: {graph_database_provider}. "
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}" f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'neptune', 'neptune_analytics'])}"
) )

View file

@ -10,7 +10,7 @@ from kuzu.database import Database
from datetime import datetime, timezone from datetime import datetime, timezone
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type from typing import Dict, Any, List, Union, Optional, Tuple, Type, AsyncGenerator
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync from cognee.infrastructure.utils.run_sync import run_sync
@ -22,7 +22,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import (
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.tasks.temporal_graph.models import Timestamp from cognee.modules.engine.models.Timestamp import Timestamp
logger = get_logger() logger = get_logger()
@ -146,15 +146,21 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to initialize Kuzu database: {e}") logger.error(f"Failed to initialize Kuzu database: {e}")
raise e raise e
def _get_connection(self) -> Connection:
"""Get the connection to the Kuzu database."""
if not self.connection:
raise RuntimeError("Kuzu database connection not initialized")
return self.connection
async def push_to_s3(self) -> None: async def push_to_s3(self) -> None:
if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"): if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
s3_file_storage = S3FileStorage("") s3_file_storage = S3FileStorage("")
if self.connection: if self._get_connection():
async with self.KUZU_ASYNC_LOCK: async with self.KUZU_ASYNC_LOCK:
self.connection.execute("CHECKPOINT;") self._get_connection().execute("CHECKPOINT;")
s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True) s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
@ -167,7 +173,9 @@ class KuzuAdapter(GraphDBInterface):
except FileNotFoundError: except FileNotFoundError:
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}") logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]: async def query(
self, query: str, params: Optional[Dict[str, Any]] = None
) -> List[Tuple[Any, ...]]:
""" """
Execute a Kuzu query asynchronously with automatic reconnection. Execute a Kuzu query asynchronously with automatic reconnection.
@ -190,23 +198,32 @@ class KuzuAdapter(GraphDBInterface):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
params = params or {} params = params or {}
def blocking_query(): def blocking_query() -> List[Tuple[Any, ...]]:
try: try:
if not self.connection: if not self._get_connection():
logger.debug("Reconnecting to Kuzu database...") logger.debug("Reconnecting to Kuzu database...")
self._initialize_connection() self._initialize_connection()
result = self.connection.execute(query, params) if not self._get_connection():
raise RuntimeError("Failed to establish database connection")
result = self._get_connection().execute(query, params)
rows = [] rows = []
while result.has_next(): if not isinstance(result, list):
row = result.get_next() result = [result]
processed_rows = []
for val in row: # Handle QueryResult vs List[QueryResult] union type
if hasattr(val, "as_py"): for single_result in result:
val = val.as_py() while single_result.has_next():
processed_rows.append(val) row = single_result.get_next()
rows.append(tuple(processed_rows)) processed_rows = []
for val in row:
if hasattr(val, "as_py"):
val = val.as_py()
processed_rows.append(val)
rows.append(tuple(processed_rows))
return rows return rows
except Exception as e: except Exception as e:
logger.error(f"Query execution failed: {str(e)}") logger.error(f"Query execution failed: {str(e)}")
@ -215,7 +232,7 @@ class KuzuAdapter(GraphDBInterface):
return await loop.run_in_executor(self.executor, blocking_query) return await loop.run_in_executor(self.executor, blocking_query)
@asynccontextmanager @asynccontextmanager
async def get_session(self): async def get_session(self) -> AsyncGenerator[Optional[Connection], None]:
""" """
Get a database session. Get a database session.
@ -224,7 +241,7 @@ class KuzuAdapter(GraphDBInterface):
and on exit performs cleanup if necessary. and on exit performs cleanup if necessary.
""" """
try: try:
yield self.connection yield self._get_connection()
finally: finally:
pass pass
@ -255,7 +272,7 @@ class KuzuAdapter(GraphDBInterface):
def _edge_query_and_params( def _edge_query_and_params(
self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any] self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any]
) -> Tuple[str, dict]: ) -> Tuple[str, Dict[str, Any]]:
"""Build the edge creation query and parameters.""" """Build the edge creation query and parameters."""
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f") now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
query = """ query = """
@ -305,7 +322,9 @@ class KuzuAdapter(GraphDBInterface):
result = await self.query(query_str, {"id": node_id}) result = await self.query(query_str, {"id": node_id})
return result[0][0] if result else False return result[0][0] if result else False
async def add_node(self, node: DataPoint) -> None: async def add_node(
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
) -> None:
""" """
Add a single node to the graph if it doesn't exist. Add a single node to the graph if it doesn't exist.
@ -319,20 +338,32 @@ class KuzuAdapter(GraphDBInterface):
- node (DataPoint): The node to be added, represented as a DataPoint. - node (DataPoint): The node to be added, represented as a DataPoint.
""" """
try: try:
properties = node.model_dump() if hasattr(node, "model_dump") else vars(node) if isinstance(node, str):
# Handle string node ID with properties parameter
node_properties = properties or {}
core_properties = {
"id": node,
"name": str(node_properties.get("name", "")),
"type": str(node_properties.get("type", "")),
}
# Use the passed properties, excluding core fields
other_properties = {
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
}
else:
# Handle DataPoint object
node_properties = node.model_dump()
core_properties = {
"id": str(node_properties.get("id", "")),
"name": str(node_properties.get("name", "")),
"type": str(node_properties.get("type", "")),
}
# Remove core fields from other properties
other_properties = {
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
}
# Extract core fields with defaults if not present core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)
core_properties = {
"id": str(properties.get("id", "")),
"name": str(properties.get("name", "")),
"type": str(properties.get("type", "")),
}
# Remove core fields from other properties
for key in core_properties:
properties.pop(key, None)
core_properties["properties"] = json.dumps(properties, cls=JSONEncoder)
# Add timestamps for new node # Add timestamps for new node
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f") now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
@ -360,7 +391,7 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to add node: {e}") logger.error(f"Failed to add node: {e}")
raise raise
@record_graph_changes @record_graph_changes # type: ignore
async def add_nodes(self, nodes: List[DataPoint]) -> None: async def add_nodes(self, nodes: List[DataPoint]) -> None:
""" """
Add multiple nodes to the graph in a batch operation. Add multiple nodes to the graph in a batch operation.
@ -568,7 +599,9 @@ class KuzuAdapter(GraphDBInterface):
) )
return result[0][0] if result else False return result[0][0] if result else False
async def has_edges(self, edges: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]: async def has_edges(
self, edges: List[Tuple[str, str, str, Dict[str, Any]]]
) -> List[Tuple[str, str, str, Dict[str, Any]]]:
""" """
Check if multiple edges exist in a batch operation. Check if multiple edges exist in a batch operation.
@ -599,7 +632,7 @@ class KuzuAdapter(GraphDBInterface):
"to_id": str(to_node), # Ensure string type "to_id": str(to_node), # Ensure string type
"relationship_name": str(edge_label), # Ensure string type "relationship_name": str(edge_label), # Ensure string type
} }
for from_node, to_node, edge_label in edges for from_node, to_node, edge_label, _ in edges
] ]
# Batch check query with direct string comparison # Batch check query with direct string comparison
@ -615,9 +648,21 @@ class KuzuAdapter(GraphDBInterface):
results = await self.query(query, {"edges": edge_params}) results = await self.query(query, {"edges": edge_params})
# Convert results back to tuples and ensure string types # Convert results back to tuples and ensure string types
existing_edges = [(str(row[0]), str(row[1]), str(row[2])) for row in results] # Find the original edge properties for each existing edge
# TODO: get review on this
existing_edges = []
for row in results:
from_id, to_id, rel_name = str(row[0]), str(row[1]), str(row[2])
# Find the original properties from the input edges
original_props = {}
for orig_from, orig_to, orig_rel, orig_props in edges:
if orig_from == from_id and orig_to == to_id and orig_rel == rel_name:
original_props = orig_props
break
existing_edges.append((from_id, to_id, rel_name, original_props))
logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked") logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked")
# TODO: otherwise, we can just return dummy properties since they are not used apparently
return existing_edges return existing_edges
except Exception as e: except Exception as e:
@ -626,10 +671,10 @@ class KuzuAdapter(GraphDBInterface):
async def add_edge( async def add_edge(
self, self,
from_node: str, source_id: str,
to_node: str, target_id: str,
relationship_name: str, relationship_name: str,
edge_properties: Dict[str, Any] = {}, properties: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
""" """
Add an edge between two nodes. Add an edge between two nodes.
@ -641,23 +686,23 @@ class KuzuAdapter(GraphDBInterface):
Parameters: Parameters:
----------- -----------
- from_node (str): The identifier of the source node from which the edge originates. - source_id (str): The identifier of the source node from which the edge originates.
- to_node (str): The identifier of the target node to which the edge points. - target_id (str): The identifier of the target node to which the edge points.
- relationship_name (str): The label of the edge to be created, representing the - relationship_name (str): The label of the edge to be created, representing the
relationship name. relationship name.
- edge_properties (Dict[str, Any]): A dictionary containing properties for the edge. - properties (Optional[Dict[str, Any]]): A dictionary containing properties for the edge.
(default {}) (default None)
""" """
try: try:
query, params = self._edge_query_and_params( query, params = self._edge_query_and_params(
from_node, to_node, relationship_name, edge_properties source_id, target_id, relationship_name, properties or {}
) )
await self.query(query, params) await self.query(query, params)
except Exception as e: except Exception as e:
logger.error(f"Failed to add edge: {e}") logger.error(f"Failed to add edge: {e}")
raise raise
@record_graph_changes @record_graph_changes # type: ignore
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None: async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
""" """
Add multiple edges in a batch operation. Add multiple edges in a batch operation.
@ -712,7 +757,7 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to add edges in batch: {e}") logger.error(f"Failed to add edges in batch: {e}")
raise raise
async def get_edges(self, node_id: str) -> List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
""" """
Get all edges connected to a node. Get all edges connected to a node.
@ -727,9 +772,8 @@ class KuzuAdapter(GraphDBInterface):
Returns: Returns:
-------- --------
- List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: A list of tuples where each - List[Tuple[str, str, str, Dict[str, Any]]]: A list of tuples where each
tuple contains (source_node, relationship_name, target_node), with source_node and tuple contains (source_id, relationship_name, target_id, edge_properties).
target_node as dictionaries of node properties.
""" """
query_str = """ query_str = """
MATCH (n:Node)-[r]-(m:Node) MATCH (n:Node)-[r]-(m:Node)
@ -750,12 +794,14 @@ class KuzuAdapter(GraphDBInterface):
""" """
try: try:
results = await self.query(query_str, {"node_id": node_id}) results = await self.query(query_str, {"node_id": node_id})
edges = [] edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
for row in results: for row in results:
if row and len(row) == 3: if row and len(row) == 3:
source_node = self._parse_node_properties(row[0]) source_node = self._parse_node_properties(row[0])
relationship_name = row[1]
target_node = self._parse_node_properties(row[2]) target_node = self._parse_node_properties(row[2])
edges.append((source_node, row[1], target_node)) # TODO: any edge properties we can add? Adding empty to avoid modifying query without reason
edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, wihle typing expects nodes to be strings
return edges return edges
except Exception as e: except Exception as e:
logger.error(f"Failed to get edges for node {node_id}: {e}") logger.error(f"Failed to get edges for node {node_id}: {e}")
@ -977,7 +1023,7 @@ class KuzuAdapter(GraphDBInterface):
return [] return []
async def get_connections( async def get_connections(
self, node_id: str self, node_id: Union[str, UUID]
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]: ) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
""" """
Get all nodes connected to a given node. Get all nodes connected to a given node.
@ -1019,7 +1065,9 @@ class KuzuAdapter(GraphDBInterface):
} }
""" """
try: try:
results = await self.query(query_str, {"node_id": node_id}) # Convert UUID to string if needed
node_id_str = str(node_id)
results = await self.query(query_str, {"node_id": node_id_str})
edges = [] edges = []
for row in results: for row in results:
if row and len(row) == 3: if row and len(row) == 3:
@ -1177,7 +1225,7 @@ class KuzuAdapter(GraphDBInterface):
async def get_nodeset_subgraph( async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str] self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]: ) -> Tuple[List[Tuple[int, Dict[str, Any]]], List[Tuple[int, int, str, Dict[str, Any]]]]:
""" """
Get subgraph for a set of nodes based on type and names. Get subgraph for a set of nodes based on type and names.
@ -1225,9 +1273,9 @@ class KuzuAdapter(GraphDBInterface):
RETURN n.id, n.name, n.type, n.properties RETURN n.id, n.name, n.type, n.properties
""" """
node_rows = await self.query(nodes_query, {"ids": all_ids}) node_rows = await self.query(nodes_query, {"ids": all_ids})
nodes: List[Tuple[str, dict]] = [] nodes: List[Tuple[str, Dict[str, Any]]] = []
for node_id, name, typ, props in node_rows: for node_id, name, typ, props in node_rows:
data = {"id": node_id, "name": name, "type": typ} data: Dict[str, Any] = {"id": node_id, "name": name, "type": typ}
if props: if props:
try: try:
data.update(json.loads(props)) data.update(json.loads(props))
@ -1241,22 +1289,22 @@ class KuzuAdapter(GraphDBInterface):
RETURN a.id, b.id, r.relationship_name, r.properties RETURN a.id, b.id, r.relationship_name, r.properties
""" """
edge_rows = await self.query(edges_query, {"ids": all_ids}) edge_rows = await self.query(edges_query, {"ids": all_ids})
edges: List[Tuple[str, str, str, dict]] = [] edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
for from_id, to_id, rel_type, props in edge_rows: for from_id, to_id, rel_type, props in edge_rows:
data = {} edge_data: Dict[str, Any] = {}
if props: if props:
try: try:
data = json.loads(props) edge_data = json.loads(props)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}") logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
edges.append((from_id, to_id, rel_type, data)) edges.append((from_id, to_id, rel_type, edge_data))
return nodes, edges return nodes, edges # type: ignore # Interface expects int IDs but string IDs are more natural for graph DBs
async def get_filtered_graph_data( async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]] self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
): ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
""" """
Get filtered nodes and relationships based on attributes. Get filtered nodes and relationships based on attributes.
@ -1299,7 +1347,7 @@ class KuzuAdapter(GraphDBInterface):
) )
return ([n[0] for n in nodes], [e[0] for e in edges]) return ([n[0] for n in nodes], [e[0] for e in edges])
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]: async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
""" """
Get metrics on graph structure and connectivity. Get metrics on graph structure and connectivity.
@ -1322,8 +1370,8 @@ class KuzuAdapter(GraphDBInterface):
try: try:
# Get basic graph data # Get basic graph data
nodes, edges = await self.get_model_independent_graph_data() nodes, edges = await self.get_model_independent_graph_data()
num_nodes = len(nodes[0]["nodes"]) if nodes else 0 num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string?
num_edges = len(edges[0]["elements"]) if edges else 0 num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string?
# Calculate mandatory metrics # Calculate mandatory metrics
mandatory_metrics = { mandatory_metrics = {
@ -1483,8 +1531,8 @@ class KuzuAdapter(GraphDBInterface):
It raises exceptions for failures occurring during deletion processes. It raises exceptions for failures occurring during deletion processes.
""" """
try: try:
if self.connection: if self._get_connection():
self.connection.close() self._get_connection().close()
self.connection = None self.connection = None
if self.db: if self.db:
self.db.close() self.db.close()
@ -1515,7 +1563,7 @@ class KuzuAdapter(GraphDBInterface):
occur during file deletions or initializations carefully. occur during file deletions or initializations carefully.
""" """
try: try:
if self.connection: if self._get_connection():
self.connection = None self.connection = None
if self.db: if self.db:
self.db.close() self.db.close()
@ -1531,20 +1579,30 @@ class KuzuAdapter(GraphDBInterface):
# Reinitialize the database # Reinitialize the database
self._initialize_connection() self._initialize_connection()
if not self._get_connection():
raise RuntimeError("Failed to establish database connection")
# Verify the database is empty # Verify the database is empty
result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)") result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)")
count = result.get_next()[0] if result.has_next() else 0 if not isinstance(result, list):
result = [result]
for single_result in result:
_next = single_result.get_next()
if not isinstance(_next, list):
raise RuntimeError("Expected list of results")
count = _next[0] if _next else 0
if count > 0: if count > 0:
logger.warning( logger.warning(
f"Database still contains {count} nodes after clearing, forcing deletion" f"Database still contains {count} nodes after clearing, forcing deletion"
) )
self.connection.execute("MATCH (n:Node) DETACH DELETE n") self._get_connection().execute("MATCH (n:Node) DETACH DELETE n")
logger.info("Database cleared successfully") logger.info("Database cleared successfully")
except Exception as e: except Exception as e:
logger.error(f"Error during database clearing: {e}") logger.error(f"Error during database clearing: {e}")
raise raise
async def get_document_subgraph(self, data_id: str): async def get_document_subgraph(self, data_id: str) -> Optional[Dict[str, Any]]:
""" """
Get all nodes that should be deleted when removing a document. Get all nodes that should be deleted when removing a document.
@ -1616,7 +1674,7 @@ class KuzuAdapter(GraphDBInterface):
"orphan_types": result[0][4], "orphan_types": result[0][4],
} }
async def get_degree_one_nodes(self, node_type: str): async def get_degree_one_nodes(self, node_type: str) -> List[Dict[str, Any]]:
""" """
Get all nodes that have only one connection. Get all nodes that have only one connection.
@ -1769,8 +1827,8 @@ class KuzuAdapter(GraphDBInterface):
ids: List[str] = [] ids: List[str] = []
if time_from and time_to: if time_from and time_to:
time_from = date_to_int(time_from) time_from_int = date_to_int(time_from)
time_to = date_to_int(time_to) time_to_int = date_to_int(time_to)
cypher = f""" cypher = f"""
MATCH (n:Node) MATCH (n:Node)
@ -1782,13 +1840,13 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64) ELSE CAST(t_str AS INT64)
END AS t END AS t
WHERE t >= {time_from} WHERE t >= {time_from_int}
AND t <= {time_to} AND t <= {time_to_int}
RETURN n.id as id RETURN n.id as id
""" """
elif time_from: elif time_from:
time_from = date_to_int(time_from) time_from_int = date_to_int(time_from)
cypher = f""" cypher = f"""
MATCH (n:Node) MATCH (n:Node)
@ -1800,12 +1858,12 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64) ELSE CAST(t_str AS INT64)
END AS t END AS t
WHERE t >= {time_from} WHERE t >= {time_from_int}
RETURN n.id as id RETURN n.id as id
""" """
elif time_to: elif time_to:
time_to = date_to_int(time_to) time_to_int = date_to_int(time_to)
cypher = f""" cypher = f"""
MATCH (n:Node) MATCH (n:Node)
@ -1817,12 +1875,12 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64) ELSE CAST(t_str AS INT64)
END AS t END AS t
WHERE t <= {time_to} WHERE t <= {time_to_int}
RETURN n.id as id RETURN n.id as id
""" """
else: else:
return ids return ", ".join(f"'{uid}'" for uid in ids)
time_nodes = await self.query(cypher) time_nodes = await self.query(cypher)
time_ids_list = [item[0] for item in time_nodes] time_ids_list = [item[0] for item in time_nodes]

View file

@ -2,7 +2,7 @@
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
import json import json
from typing import Dict, Any, List, Optional, Tuple from typing import Dict, Any, List, Optional, Tuple, Union
import aiohttp import aiohttp
from uuid import UUID from uuid import UUID
@ -14,7 +14,7 @@ logger = get_logger()
class UUIDEncoder(json.JSONEncoder): class UUIDEncoder(json.JSONEncoder):
"""Custom JSON encoder that handles UUID objects.""" """Custom JSON encoder that handles UUID objects."""
def default(self, obj): def default(self, obj: Union[UUID, Any]) -> Any:
if isinstance(obj, UUID): if isinstance(obj, UUID):
return str(obj) return str(obj)
return super().default(obj) return super().default(obj)
@ -36,7 +36,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
self.api_url = api_url self.api_url = api_url
self.username = username self.username = username
self.password = password self.password = password
self._session = None self._session: Optional[aiohttp.ClientSession] = None
self._schema_initialized = False self._schema_initialized = False
async def _get_session(self) -> aiohttp.ClientSession: async def _get_session(self) -> aiohttp.ClientSession:
@ -45,13 +45,13 @@ class RemoteKuzuAdapter(KuzuAdapter):
self._session = aiohttp.ClientSession() self._session = aiohttp.ClientSession()
return self._session return self._session
async def close(self): async def close(self) -> None:
"""Close the adapter and its session.""" """Close the adapter and its session."""
if self._session and not self._session.closed: if self._session and not self._session.closed:
await self._session.close() await self._session.close()
self._session = None self._session = None
async def _make_request(self, endpoint: str, data: dict) -> dict: async def _make_request(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]:
"""Make a request to the Kuzu API.""" """Make a request to the Kuzu API."""
url = f"{self.api_url}{endpoint}" url = f"{self.api_url}{endpoint}"
session = await self._get_session() session = await self._get_session()
@ -73,13 +73,15 @@ class RemoteKuzuAdapter(KuzuAdapter):
status=response.status, status=response.status,
message=error_detail, message=error_detail,
) )
return await response.json() return await response.json() # type: ignore
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
logger.error(f"API request failed: {str(e)}") logger.error(f"API request failed: {str(e)}")
logger.error(f"Request data: {data}") logger.error(f"Request data: {data}")
raise raise
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]: async def query(
self, query: str, params: Optional[dict[str, Any]] = None
) -> List[Tuple[Any, ...]]:
"""Execute a Kuzu query via the REST API.""" """Execute a Kuzu query via the REST API."""
try: try:
# Initialize schema if needed # Initialize schema if needed
@ -126,7 +128,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
logger.error(f"Failed to check schema: {e}") logger.error(f"Failed to check schema: {e}")
return False return False
async def _create_schema(self): async def _create_schema(self) -> None:
"""Create the required schema tables.""" """Create the required schema tables."""
try: try:
# Create Node table if it doesn't exist # Create Node table if it doesn't exist
@ -180,7 +182,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
logger.error(f"Failed to create schema: {e}") logger.error(f"Failed to create schema: {e}")
raise raise
async def _initialize_schema(self): async def _initialize_schema(self) -> None:
"""Initialize the database schema if it doesn't exist.""" """Initialize the database schema if it doesn't exist."""
if self._schema_initialized: if self._schema_initialized:
return return

View file

@ -8,11 +8,11 @@ from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError from neo4j.exceptions import Neo4jError
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Optional, Any, List, Dict, Type, Tuple from typing import Optional, Any, List, Dict, Type, Tuple, Union, AsyncGenerator
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.tasks.temporal_graph.models import Timestamp from cognee.modules.engine.models.Timestamp import Timestamp
from cognee.shared.logging_utils import get_logger, ERROR from cognee.shared.logging_utils import get_logger, ERROR
from cognee.infrastructure.databases.graph.graph_db_interface import ( from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface, GraphDBInterface,
@ -79,14 +79,14 @@ class Neo4jAdapter(GraphDBInterface):
) )
@asynccontextmanager @asynccontextmanager
async def get_session(self) -> AsyncSession: async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
""" """
Get a session for database operations. Get a session for database operations.
""" """
async with self.driver.session(database=self.graph_database_name) as session: async with self.driver.session(database=self.graph_database_name) as session:
yield session yield session
@deadlock_retry() @deadlock_retry() # type: ignore
async def query( async def query(
self, self,
query: str, query: str,
@ -112,6 +112,7 @@ class Neo4jAdapter(GraphDBInterface):
async with self.get_session() as session: async with self.get_session() as session:
result = await session.run(query, parameters=params) result = await session.run(query, parameters=params)
data = await result.data() data = await result.data()
# TODO: why we don't return List[Dict[str, Any]]?
return data return data
except Neo4jError as error: except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info=True) logger.error("Neo4j query error: %s", error, exc_info=True)
@ -141,21 +142,29 @@ class Neo4jAdapter(GraphDBInterface):
) )
return results[0]["node_exists"] if len(results) > 0 else False return results[0]["node_exists"] if len(results) > 0 else False
async def add_node(self, node: DataPoint): async def add_node(
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
) -> None:
""" """
Add a new node to the database based on the provided DataPoint object. Add a new node to the database based on the provided DataPoint object or string ID.
Parameters: Parameters:
----------- -----------
- node (DataPoint): An instance of DataPoint representing the node to add. - node (Union[DataPoint, str]): An instance of DataPoint or string ID representing the node to add.
- properties (Optional[Dict[str, Any]]): Properties to set on the node when node is a string ID.
Returns:
--------
The result of the query execution, typically the ID of the added node.
""" """
serialized_properties = self.serialize_properties(node.model_dump()) if isinstance(node, str):
# TODO: this was not handled in the original code, check if it is correct
# Handle string node ID with properties parameter
node_id = node
node_label = "Node" # Default label for string nodes
serialized_properties = self.serialize_properties(properties or {})
else:
# Handle DataPoint object
node_id = str(node.id)
node_label = type(node).__name__
serialized_properties = self.serialize_properties(node.model_dump())
query = dedent( query = dedent(
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}}) f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
@ -167,16 +176,16 @@ class Neo4jAdapter(GraphDBInterface):
) )
params = { params = {
"node_id": str(node.id), "node_id": node_id,
"node_label": type(node).__name__, "node_label": node_label,
"properties": serialized_properties, "properties": serialized_properties,
} }
return await self.query(query, params) await self.query(query, params)
@record_graph_changes @record_graph_changes # type: ignore
@override_distributed(queued_add_nodes) @override_distributed(queued_add_nodes) # type: ignore
async def add_nodes(self, nodes: list[DataPoint]) -> None: async def add_nodes(self, nodes: List[DataPoint]) -> None:
""" """
Add multiple nodes to the database in a single query. Add multiple nodes to the database in a single query.
@ -201,7 +210,7 @@ class Neo4jAdapter(GraphDBInterface):
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
""" """
nodes = [ node_params = [
{ {
"node_id": str(node.id), "node_id": str(node.id),
"label": type(node).__name__, "label": type(node).__name__,
@ -210,10 +219,9 @@ class Neo4jAdapter(GraphDBInterface):
for node in nodes for node in nodes
] ]
results = await self.query(query, dict(nodes=nodes)) await self.query(query, dict(nodes=node_params))
return results
async def extract_node(self, node_id: str): async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]:
""" """
Retrieve a single node from the database by its ID. Retrieve a single node from the database by its ID.
@ -231,7 +239,7 @@ class Neo4jAdapter(GraphDBInterface):
return results[0] if len(results) > 0 else None return results[0] if len(results) > 0 else None
async def extract_nodes(self, node_ids: List[str]): async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
""" """
Retrieve multiple nodes from the database by their IDs. Retrieve multiple nodes from the database by their IDs.
@ -256,7 +264,7 @@ class Neo4jAdapter(GraphDBInterface):
return [result["node"] for result in results] return [result["node"] for result in results]
async def delete_node(self, node_id: str): async def delete_node(self, node_id: str) -> None:
""" """
Remove a node from the database identified by its ID. Remove a node from the database identified by its ID.
@ -273,7 +281,7 @@ class Neo4jAdapter(GraphDBInterface):
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node" query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
params = {"node_id": node_id} params = {"node_id": node_id}
return await self.query(query, params) await self.query(query, params)
async def delete_nodes(self, node_ids: list[str]) -> None: async def delete_nodes(self, node_ids: list[str]) -> None:
""" """
@ -296,18 +304,18 @@ class Neo4jAdapter(GraphDBInterface):
params = {"node_ids": node_ids} params = {"node_ids": node_ids}
return await self.query(query, params) await self.query(query, params)
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool: async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
""" """
Check if an edge exists between two nodes with the specified IDs and edge label. Check if an edge exists between two nodes with the specified IDs and edge label.
Parameters: Parameters:
----------- -----------
- from_node (UUID): The ID of the node from which the edge originates. - source_id (str): The ID of the node from which the edge originates.
- to_node (UUID): The ID of the node to which the edge points. - target_id (str): The ID of the node to which the edge points.
- edge_label (str): The label of the edge to check for existence. - relationship_name (str): The label of the edge to check for existence.
Returns: Returns:
-------- --------
@ -315,27 +323,28 @@ class Neo4jAdapter(GraphDBInterface):
- bool: True if the edge exists, otherwise False. - bool: True if the edge exists, otherwise False.
""" """
query = f""" query = f"""
MATCH (from_node: `{BASE_LABEL}`)-[:`{edge_label}`]->(to_node: `{BASE_LABEL}`) MATCH (from_node: `{BASE_LABEL}`)-[:`{relationship_name}`]->(to_node: `{BASE_LABEL}`)
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id WHERE from_node.id = $source_id AND to_node.id = $target_id
RETURN COUNT(relationship) > 0 AS edge_exists RETURN COUNT(relationship) > 0 AS edge_exists
""" """
params = { params = {
"from_node_id": str(from_node), "source_id": str(source_id),
"to_node_id": str(to_node), "target_id": str(target_id),
} }
edge_exists = await self.query(query, params) edge_exists = await self.query(query, params)
assert isinstance(edge_exists, bool), "Edge existence check should return a boolean"
return edge_exists return edge_exists
async def has_edges(self, edges): async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[bool]:
""" """
Check if multiple edges exist based on provided edge criteria. Check if multiple edges exist based on provided edge criteria.
Parameters: Parameters:
----------- -----------
- edges: A list of edge specifications to check for existence. - edges: A list of edge specifications to check for existence. (source_id, target_id, relationship_name, properties)
Returns: Returns:
-------- --------
@ -369,29 +378,24 @@ class Neo4jAdapter(GraphDBInterface):
async def add_edge( async def add_edge(
self, self,
from_node: UUID, source_id: str,
to_node: UUID, target_id: str,
relationship_name: str, relationship_name: str,
edge_properties: Optional[Dict[str, Any]] = {}, properties: Optional[Dict[str, Any]] = None,
): ) -> None:
""" """
Create a new edge between two nodes with specified properties. Create a new edge between two nodes with specified properties.
Parameters: Parameters:
----------- -----------
- from_node (UUID): The ID of the source node of the edge. - source_id (str): The ID of the source node of the edge.
- to_node (UUID): The ID of the target node of the edge. - target_id (str): The ID of the target node of the edge.
- relationship_name (str): The type/label of the edge to create. - relationship_name (str): The type/label of the edge to create.
- edge_properties (Optional[Dict[str, Any]]): A dictionary of properties to assign - properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
to the edge. (default {}) to the edge. (default None)
Returns:
--------
The result of the query execution, typically indicating the created edge.
""" """
serialized_properties = self.serialize_properties(edge_properties) serialized_properties = self.serialize_properties(properties or {})
query = dedent( query = dedent(
f"""\ f"""\
@ -405,13 +409,13 @@ class Neo4jAdapter(GraphDBInterface):
) )
params = { params = {
"from_node": str(from_node), "from_node": str(source_id), # Adding str as callsites may still be passing UUID
"to_node": str(to_node), "to_node": str(target_id),
"relationship_name": relationship_name, "relationship_name": relationship_name,
"properties": serialized_properties, "properties": serialized_properties,
} }
return await self.query(query, params) await self.query(query, params)
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]: def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
""" """
@ -445,9 +449,9 @@ class Neo4jAdapter(GraphDBInterface):
return flattened return flattened
@record_graph_changes @record_graph_changes # type: ignore
@override_distributed(queued_add_edges) @override_distributed(queued_add_edges) # type: ignore
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None: async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
""" """
Add multiple edges between nodes in a single query. Add multiple edges between nodes in a single query.
@ -478,10 +482,10 @@ class Neo4jAdapter(GraphDBInterface):
) YIELD rel ) YIELD rel
RETURN rel""" RETURN rel"""
edges = [ edge_params = [
{ {
"from_node": str(edge[0]), "from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
"to_node": str(edge[1]), "to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
"relationship_name": edge[2], "relationship_name": edge[2],
"properties": self._flatten_edge_properties( "properties": self._flatten_edge_properties(
{ {
@ -495,13 +499,12 @@ class Neo4jAdapter(GraphDBInterface):
] ]
try: try:
results = await self.query(query, dict(edges=edges)) await self.query(query, dict(edges=edge_params))
return results
except Neo4jError as error: except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info=True) logger.error("Neo4j query error: %s", error, exc_info=True)
raise error raise error
async def get_edges(self, node_id: str): async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
""" """
Retrieve all edges connected to a specified node. Retrieve all edges connected to a specified node.

View file

@ -1,12 +1,13 @@
import json import json
import asyncio import asyncio
from uuid import UUID from uuid import UUID
from typing import List, Optional from typing import List, Optional, Dict, Any
from chromadb import AsyncHttpClient, Settings from chromadb import AsyncHttpClient, Settings
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.storage.utils import get_own_properties from cognee.modules.storage.utils import get_own_properties
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData
from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
@ -35,9 +36,9 @@ class IndexSchema(DataPoint):
text: str text: str
metadata: dict = {"index_fields": ["text"]} metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
def model_dump(self): def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
""" """
Serialize the instance data for storage. Serialize the instance data for storage.
@ -49,11 +50,11 @@ class IndexSchema(DataPoint):
A dictionary containing serialized data processed for ChromaDB storage. A dictionary containing serialized data processed for ChromaDB storage.
""" """
data = super().model_dump() data = super().model_dump(**kwargs)
return process_data_for_chroma(data) return process_data_for_chroma(data)
def process_data_for_chroma(data): def process_data_for_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
""" """
Convert complex data types to a format suitable for ChromaDB storage. Convert complex data types to a format suitable for ChromaDB storage.
@ -73,7 +74,7 @@ def process_data_for_chroma(data):
A dictionary containing the processed key-value pairs suitable for ChromaDB storage. A dictionary containing the processed key-value pairs suitable for ChromaDB storage.
""" """
processed_data = {} processed_data: Dict[str, Any] = {}
for key, value in data.items(): for key, value in data.items():
if isinstance(value, UUID): if isinstance(value, UUID):
processed_data[key] = str(value) processed_data[key] = str(value)
@ -90,7 +91,7 @@ def process_data_for_chroma(data):
return processed_data return processed_data
def restore_data_from_chroma(data): def restore_data_from_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
""" """
Restore original data structure from ChromaDB storage format. Restore original data structure from ChromaDB storage format.
@ -152,8 +153,8 @@ class ChromaDBAdapter(VectorDBInterface):
""" """
name = "ChromaDB" name = "ChromaDB"
url: str url: str | None
api_key: str api_key: str | None
connection: AsyncHttpClient = None connection: AsyncHttpClient = None
def __init__( def __init__(
@ -216,7 +217,9 @@ class ChromaDBAdapter(VectorDBInterface):
collections = await self.get_collection_names() collections = await self.get_collection_names()
return collection_name in collections return collection_name in collections
async def create_collection(self, collection_name: str, payload_schema=None): async def create_collection(
self, collection_name: str, payload_schema: Optional[Any] = None
) -> None:
""" """
Create a new collection in ChromaDB if it does not already exist. Create a new collection in ChromaDB if it does not already exist.
@ -254,7 +257,7 @@ class ChromaDBAdapter(VectorDBInterface):
client = await self.get_connection() client = await self.get_connection()
return await client.get_collection(collection_name) return await client.get_collection(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
""" """
Create and upsert data points into the specified collection in ChromaDB. Create and upsert data points into the specified collection in ChromaDB.
@ -282,7 +285,7 @@ class ChromaDBAdapter(VectorDBInterface):
ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
) )
async def create_vector_index(self, index_name: str, index_property_name: str): async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
""" """
Create a vector index as a ChromaDB collection based on provided names. Create a vector index as a ChromaDB collection based on provided names.
@ -296,7 +299,7 @@ class ChromaDBAdapter(VectorDBInterface):
async def index_data_points( async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint] self, index_name: str, index_property_name: str, data_points: list[DataPoint]
): ) -> None:
""" """
Index the provided data points based on the specified index property in ChromaDB. Index the provided data points based on the specified index property in ChromaDB.
@ -315,10 +318,11 @@ class ChromaDBAdapter(VectorDBInterface):
text=getattr(data_point, data_point.metadata["index_fields"][0]), text=getattr(data_point, data_point.metadata["index_fields"][0]),
) )
for data_point in data_points for data_point in data_points
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
], ],
) )
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
""" """
Retrieve data points by their IDs from a ChromaDB collection. Retrieve data points by their IDs from a ChromaDB collection.
@ -350,12 +354,12 @@ class ChromaDBAdapter(VectorDBInterface):
async def search( async def search(
self, self,
collection_name: str, collection_name: str,
query_text: str = None, query_text: Optional[str] = None,
query_vector: List[float] = None, query_vector: Optional[List[float]] = None,
limit: int = 15, limit: int = 15,
with_vector: bool = False, with_vector: bool = False,
normalized: bool = True, normalized: bool = True,
): ) -> List[ScoredResult]:
""" """
Search for items in a collection using either a text or a vector query. Search for items in a collection using either a text or a vector query.
@ -437,7 +441,7 @@ class ChromaDBAdapter(VectorDBInterface):
query_texts: List[str], query_texts: List[str],
limit: int = 5, limit: int = 5,
with_vectors: bool = False, with_vectors: bool = False,
): ) -> List[List[ScoredResult]]:
""" """
Perform multiple searches in a single request for efficiency, returning results for each Perform multiple searches in a single request for efficiency, returning results for each
query. query.
@ -507,7 +511,7 @@ class ChromaDBAdapter(VectorDBInterface):
return all_results return all_results
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> bool:
""" """
Remove data points from a collection based on their IDs. Remove data points from a collection based on their IDs.
@ -528,7 +532,7 @@ class ChromaDBAdapter(VectorDBInterface):
await collection.delete(ids=data_point_ids) await collection.delete(ids=data_point_ids)
return True return True
async def prune(self): async def prune(self) -> bool:
""" """
Delete all collections in the ChromaDB database. Delete all collections in the ChromaDB database.
@ -538,12 +542,12 @@ class ChromaDBAdapter(VectorDBInterface):
Returns True upon successful deletion of all collections. Returns True upon successful deletion of all collections.
""" """
client = await self.get_connection() client = await self.get_connection()
collections = await self.list_collections() collection_names = await self.get_collection_names()
for collection_name in collections: for collection_name in collection_names:
await client.delete_collection(collection_name) await client.delete_collection(collection_name)
return True return True
async def get_collection_names(self): async def get_collection_names(self) -> List[str]:
""" """
Retrieve the names of all collections in the ChromaDB database. Retrieve the names of all collections in the ChromaDB database.

View file

@ -1,12 +1,25 @@
import asyncio import asyncio
from os import path from os import path
from uuid import UUID
import lancedb import lancedb
from pydantic import BaseModel from pydantic import BaseModel
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints from typing import (
Generic,
List,
Optional,
TypeVar,
Union,
get_args,
get_origin,
get_type_hints,
Dict,
Any,
)
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData
from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.files.storage import get_file_storage from cognee.infrastructure.files.storage import get_file_storage
from cognee.modules.storage.utils import copy_model, get_own_properties from cognee.modules.storage.utils import copy_model, get_own_properties
@ -30,16 +43,16 @@ class IndexSchema(DataPoint):
to include 'text'. to include 'text'.
""" """
id: str id: UUID
text: str text: str
metadata: dict = {"index_fields": ["text"]} metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
class LanceDBAdapter(VectorDBInterface): class LanceDBAdapter(VectorDBInterface):
name = "LanceDB" name = "LanceDB"
url: str url: Optional[str]
api_key: str api_key: Optional[str]
connection: lancedb.AsyncConnection = None connection: lancedb.AsyncConnection = None
def __init__( def __init__(
@ -53,7 +66,7 @@ class LanceDBAdapter(VectorDBInterface):
self.embedding_engine = embedding_engine self.embedding_engine = embedding_engine
self.VECTOR_DB_LOCK = asyncio.Lock() self.VECTOR_DB_LOCK = asyncio.Lock()
async def get_connection(self): async def get_connection(self) -> lancedb.AsyncConnection:
""" """
Establishes and returns a connection to the LanceDB. Establishes and returns a connection to the LanceDB.
@ -107,12 +120,11 @@ class LanceDBAdapter(VectorDBInterface):
collection_names = await connection.table_names() collection_names = await connection.table_names()
return collection_name in collection_names return collection_name in collection_names
async def create_collection(self, collection_name: str, payload_schema: BaseModel): async def create_collection(
self, collection_name: str, payload_schema: Optional[Any] = None
) -> None:
vector_size = self.embedding_engine.get_vector_size() vector_size = self.embedding_engine.get_vector_size()
payload_schema = self.get_data_point_schema(payload_schema)
data_point_types = get_type_hints(payload_schema)
class LanceDataPoint(LanceModel): class LanceDataPoint(LanceModel):
""" """
Represents a data point in the Lance model with an ID, vector, and associated payload. Represents a data point in the Lance model with an ID, vector, and associated payload.
@ -123,28 +135,28 @@ class LanceDBAdapter(VectorDBInterface):
- payload: Additional data or metadata associated with the data point. - payload: Additional data or metadata associated with the data point.
""" """
id: data_point_types["id"] id: UUID
vector: Vector(vector_size) vector: Vector[vector_size]
payload: payload_schema payload: Dict[str, Any]
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
async with self.VECTOR_DB_LOCK: async with self.VECTOR_DB_LOCK:
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
connection = await self.get_connection() connection = await self.get_connection()
return await connection.create_table( await connection.create_table(
name=collection_name, name=collection_name,
schema=LanceDataPoint, schema=LanceDataPoint,
exist_ok=True, exist_ok=True,
) )
async def get_collection(self, collection_name: str): async def get_collection(self, collection_name: str) -> Any:
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!") raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
connection = await self.get_connection() connection = await self.get_connection()
return await connection.open_table(collection_name) return await connection.open_table(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
payload_schema = type(data_points[0]) payload_schema = type(data_points[0])
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
@ -175,14 +187,14 @@ class LanceDBAdapter(VectorDBInterface):
""" """
id: IdType id: IdType
vector: Vector(vector_size) vector: Vector[vector_size]
payload: PayloadSchema payload: PayloadSchema
def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint: def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any:
properties = get_own_properties(data_point) properties = get_own_properties(data_point)
properties["id"] = str(properties["id"]) properties["id"] = str(properties["id"])
return LanceDataPoint[str, self.get_data_point_schema(type(data_point))]( return LanceDataPoint(
id=str(data_point.id), id=str(data_point.id),
vector=vector, vector=vector,
payload=properties, payload=properties,
@ -201,7 +213,7 @@ class LanceDBAdapter(VectorDBInterface):
.execute(lance_data_points) .execute(lance_data_points)
) )
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: list[str]) -> List[ScoredResult]:
collection = await self.get_collection(collection_name) collection = await self.get_collection(collection_name)
if len(data_point_ids) == 1: if len(data_point_ids) == 1:
@ -221,12 +233,12 @@ class LanceDBAdapter(VectorDBInterface):
async def search( async def search(
self, self,
collection_name: str, collection_name: str,
query_text: str = None, query_text: Optional[str] = None,
query_vector: List[float] = None, query_vector: Optional[List[float]] = None,
limit: int = 15, limit: int = 15,
with_vector: bool = False, with_vector: bool = False,
normalized: bool = True, normalized: bool = True,
): ) -> List[ScoredResult]:
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise MissingQueryParameterError() raise MissingQueryParameterError()
@ -264,9 +276,9 @@ class LanceDBAdapter(VectorDBInterface):
self, self,
collection_name: str, collection_name: str,
query_texts: List[str], query_texts: List[str],
limit: int = None, limit: Optional[int] = None,
with_vectors: bool = False, with_vectors: bool = False,
): ) -> List[List[ScoredResult]]:
query_vectors = await self.embedding_engine.embed_text(query_texts) query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather( return await asyncio.gather(
@ -274,40 +286,41 @@ class LanceDBAdapter(VectorDBInterface):
self.search( self.search(
collection_name=collection_name, collection_name=collection_name,
query_vector=query_vector, query_vector=query_vector,
limit=limit, limit=limit or 15,
with_vector=with_vectors, with_vector=with_vectors,
) )
for query_vector in query_vectors for query_vector in query_vectors
] ]
) )
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> None:
collection = await self.get_collection(collection_name) collection = await self.get_collection(collection_name)
# Delete one at a time to avoid commit conflicts # Delete one at a time to avoid commit conflicts
for data_point_id in data_point_ids: for data_point_id in data_point_ids:
await collection.delete(f"id = '{data_point_id}'") await collection.delete(f"id = '{data_point_id}'")
async def create_vector_index(self, index_name: str, index_property_name: str): async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
await self.create_collection( await self.create_collection(
f"{index_name}_{index_property_name}", payload_schema=IndexSchema f"{index_name}_{index_property_name}", payload_schema=IndexSchema
) )
async def index_data_points( async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint] self, index_name: str, index_property_name: str, data_points: List[DataPoint]
): ) -> None:
await self.create_data_points( await self.create_data_points(
f"{index_name}_{index_property_name}", f"{index_name}_{index_property_name}",
[ [
IndexSchema( IndexSchema(
id=str(data_point.id), id=data_point.id,
text=getattr(data_point, data_point.metadata["index_fields"][0]), text=getattr(data_point, data_point.metadata["index_fields"][0]),
) )
for data_point in data_points for data_point in data_points
if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0
], ],
) )
async def prune(self): async def prune(self) -> None:
connection = await self.get_connection() connection = await self.get_connection()
collection_names = await connection.table_names() collection_names = await connection.table_names()
@ -316,12 +329,15 @@ class LanceDBAdapter(VectorDBInterface):
await collection.delete("id IS NOT NULL") await collection.delete("id IS NOT NULL")
await connection.drop_table(collection_name) await connection.drop_table(collection_name)
if self.url.startswith("/"): if self.url and self.url.startswith("/"):
db_dir_path = path.dirname(self.url) db_dir_path = path.dirname(self.url)
db_file_name = path.basename(self.url) db_file_name = path.basename(self.url)
await get_file_storage(db_dir_path).remove_all(db_file_name) await get_file_storage(db_dir_path).remove_all(db_file_name)
def get_data_point_schema(self, model_type: BaseModel): def get_data_point_schema(self, model_type: Optional[Any]) -> Any:
if model_type is None:
return DataPoint
related_models_fields = [] related_models_fields = []
for field_name, field_config in model_type.model_fields.items(): for field_name, field_config in model_type.model_fields.items():

View file

@ -1,4 +1,4 @@
from typing import Any, Dict from typing import Any, Dict, List, Optional
from uuid import UUID from uuid import UUID
from pydantic import BaseModel from pydantic import BaseModel
@ -14,8 +14,10 @@ class ScoredResult(BaseModel):
better outcome. better outcome.
- payload (Dict[str, Any]): Additional information related to the score, stored as - payload (Dict[str, Any]): Additional information related to the score, stored as
key-value pairs in a dictionary. key-value pairs in a dictionary.
- vector (Optional[List[float]]): Optional vector embedding associated with the result.
""" """
id: UUID id: UUID
score: float # Lower score is better score: float # Lower score is better
payload: Dict[str, Any] payload: Dict[str, Any]
vector: Optional[List[float]] = None

View file

@ -1,9 +1,9 @@
import asyncio import asyncio
from typing import List, Optional, get_type_hints from typing import List, Optional, get_type_hints, Dict, Any
from sqlalchemy.inspection import inspect from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import JSON, Column, Table, select, delete, MetaData from sqlalchemy import JSON, Table, select, delete, MetaData
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.exc import ProgrammingError from sqlalchemy.exc import ProgrammingError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@ -12,6 +12,7 @@ from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationE
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData as DataPointMetaData
from cognee.infrastructure.engine.utils import parse_id from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
@ -42,7 +43,7 @@ class IndexSchema(DataPoint):
text: str text: str
metadata: dict = {"index_fields": ["text"]} metadata: DataPointMetaData = {"index_fields": ["text"], "type": "IndexSchema"}
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
@ -122,8 +123,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
stop=stop_after_attempt(5), stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=2, min=1, max=6), wait=wait_exponential(multiplier=2, min=1, max=6),
) )
async def create_collection(self, collection_name: str, payload_schema=None): async def create_collection(
data_point_types = get_type_hints(DataPoint) self, collection_name: str, payload_schema: Optional[Any] = None
) -> None:
vector_size = self.embedding_engine.get_vector_size() vector_size = self.embedding_engine.get_vector_size()
async with self.VECTOR_DB_LOCK: async with self.VECTOR_DB_LOCK:
@ -147,29 +149,31 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name __tablename__ = collection_name
__table_args__ = {"extend_existing": True} __table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key # PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True) id: Mapped[str] = mapped_column(primary_key=True)
payload = Column(JSON) payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
vector = Column(self.Vector(vector_size)) vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
def __init__(self, id, payload, vector): def __init__(
self, id: str, payload: Dict[str, Any], vector: List[float]
) -> None:
self.id = id self.id = id
self.payload = payload self.payload = payload
self.vector = vector self.vector = vector
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0: if len(Base.metadata.tables.keys()) > 0:
await connection.run_sync( from sqlalchemy import Table
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
) table: Table = PGVectorDataPoint.__table__ # type: ignore
await connection.run_sync(Base.metadata.create_all, tables=[table])
@retry( @retry(
retry=retry_if_exception_type(DeadlockDetectedError), retry=retry_if_exception_type(DeadlockDetectedError),
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6), wait=wait_exponential(multiplier=2, min=1, max=6),
) )
@override_distributed(queued_add_data_points) @override_distributed(queued_add_data_points) # type: ignore
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
data_point_types = get_type_hints(DataPoint)
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
await self.create_collection( await self.create_collection(
collection_name=collection_name, collection_name=collection_name,
@ -196,11 +200,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name __tablename__ = collection_name
__table_args__ = {"extend_existing": True} __table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key # PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True) id: Mapped[str] = mapped_column(primary_key=True)
payload = Column(JSON) payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
vector = Column(self.Vector(vector_size)) vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
def __init__(self, id, payload, vector): def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
self.id = id self.id = id
self.payload = payload self.payload = payload
self.vector = vector self.vector = vector
@ -225,13 +229,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# else: # else:
pgvector_data_points.append( pgvector_data_points.append(
PGVectorDataPoint( PGVectorDataPoint(
id=data_point.id, id=str(data_point.id),
vector=data_vectors[data_index], vector=data_vectors[data_index],
payload=serialize_data(data_point.model_dump()), payload=serialize_data(data_point.model_dump()),
) )
) )
def to_dict(obj): def to_dict(obj: Any) -> Dict[str, Any]:
return { return {
column.key: getattr(obj, column.key) column.key: getattr(obj, column.key)
for column in inspect(obj).mapper.column_attrs for column in inspect(obj).mapper.column_attrs
@ -245,12 +249,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await session.execute(insert_statement) await session.execute(insert_statement)
await session.commit() await session.commit()
async def create_vector_index(self, index_name: str, index_property_name: str): async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
await self.create_collection(f"{index_name}_{index_property_name}") await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points( async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint] self, index_name: str, index_property_name: str, data_points: List[DataPoint]
): ) -> None:
await self.create_data_points( await self.create_data_points(
f"{index_name}_{index_property_name}", f"{index_name}_{index_property_name}",
[ [
@ -262,11 +266,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
], ],
) )
async def get_table(self, collection_name: str) -> Table: async def get_table(self, table_name: str, schema_name: Optional[str] = None) -> Table:
""" """
Dynamically loads a table using the given collection name Dynamically loads a table using the given table name
with an async engine. with an async engine. Schema parameter is ignored for vector collections.
""" """
collection_name = table_name
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
# Create a MetaData instance to load table information # Create a MetaData instance to load table information
metadata = MetaData() metadata = MetaData()
@ -279,15 +284,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
f"Collection '{collection_name}' not found!", f"Collection '{collection_name}' not found!",
) )
async def retrieve(self, collection_name: str, data_point_ids: List[str]): async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
# Get PGVectorDataPoint Table from database # Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name) PGVectorDataPoint = await self.get_table(collection_name)
async with self.get_async_session() as session: async with self.get_async_session() as session:
results = await session.execute( query_result = await session.execute(
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids)) select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
) )
results = results.all() results = query_result.all()
return [ return [
ScoredResult(id=parse_id(result.id), payload=result.payload, score=0) ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)
@ -311,9 +316,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Get PGVectorDataPoint Table from database # Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name) PGVectorDataPoint = await self.get_table(collection_name)
# NOTE: This needs to be initialized in case search doesn't return a value
closest_items = []
# Use async session to connect to the database # Use async session to connect to the database
async with self.get_async_session() as session: async with self.get_async_session() as session:
query = select( query = select(
@ -325,12 +327,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query = query.limit(limit) query = query.limit(limit)
# Find closest vectors to query_vector # Find closest vectors to query_vector
closest_items = await session.execute(query) query_results = await session.execute(query)
vector_list = [] vector_list = []
# Extract distances and find min/max for normalization # Extract distances and find min/max for normalization
for vector in closest_items.all(): for vector in query_results.all():
vector_list.append( vector_list.append(
{ {
"id": parse_id(str(vector.id)), "id": parse_id(str(vector.id)),
@ -349,7 +351,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Create and return ScoredResult objects # Create and return ScoredResult objects
return [ return [
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score")) ScoredResult(id=row["id"], payload=row["payload"] or {}, score=row["score"])
for row in vector_list for row in vector_list
] ]
@ -357,9 +359,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
self, self,
collection_name: str, collection_name: str,
query_texts: List[str], query_texts: List[str],
limit: int = None, limit: Optional[int] = None,
with_vectors: bool = False, with_vectors: bool = False,
): ) -> List[List[ScoredResult]]:
query_vectors = await self.embedding_engine.embed_text(query_texts) query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather( return await asyncio.gather(
@ -367,14 +369,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
self.search( self.search(
collection_name=collection_name, collection_name=collection_name,
query_vector=query_vector, query_vector=query_vector,
limit=limit, limit=limit or 15,
with_vector=with_vectors, with_vector=with_vectors,
) )
for query_vector in query_vectors for query_vector in query_vectors
] ]
) )
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> Any:
async with self.get_async_session() as session: async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database # Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name) PGVectorDataPoint = await self.get_table(collection_name)
@ -384,6 +386,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await session.commit() await session.commit()
return results return results
async def prune(self): async def prune(self) -> None:
# Clean up the database if it was set up as temporary # Clean up the database if it was set up as temporary
await self.delete_database() await self.delete_database()

View file

@ -1,109 +0,0 @@
import os
import pathlib
import cognee
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
logger = get_logger()
async def main():
cognee.config.set_graph_database_provider("memgraph")
data_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_memgraph")
).resolve()
)
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_memgraph")
).resolve()
)
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
dataset_name = "cs_explanations"
explanation_file_path = os.path.join(
pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
)
await cognee.add([explanation_file_path], dataset_name)
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible.
The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly.
Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate.
In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
"""
await cognee.add([text], dataset_name)
await cognee.cognify([dataset_name])
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(query_type=SearchType.CHUNKS, query_text=random_node_name)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted chunks are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(
query_type=SearchType.SUMMARIES, query_text=random_node_name
)
assert len(search_results) != 0, "Query related summaries don't exist."
print("\nExtracted results are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(
query_type=SearchType.NATURAL_LANGUAGE,
query_text=f"Find nodes connected to node with name {random_node_name}",
)
assert len(search_results) != 0, "Query related natural language don't exist."
print("\nExtracted results are:\n")
for result in search_results:
print(f"{result}\n")
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 8, "Search history is not correct."
await cognee.prune.prune_data()
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
from cognee.infrastructure.databases.graph import get_graph_engine
graph_engine = await get_graph_engine()
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 0 and len(edges) == 0, "Memgraph graph database is not empty"
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View file

@ -1,7 +1,7 @@
[mypy] [mypy]
python_version=3.8 python_version=3.10
ignore_missing_imports=false ignore_missing_imports=false
strict_optional=false strict_optional=true
warn_redundant_casts=true warn_redundant_casts=true
disallow_any_generics=true disallow_any_generics=true
disallow_untyped_defs=true disallow_untyped_defs=true
@ -10,6 +10,12 @@ warn_return_any=true
namespace_packages=true namespace_packages=true
warn_unused_ignores=true warn_unused_ignores=true
show_error_codes=true show_error_codes=true
disallow_incomplete_defs=true
disallow_untyped_decorators=true
no_implicit_optional=true
warn_unreachable=true
warn_no_return=true
warn_unused_configs=true
#exclude=reflection/module_cases/* #exclude=reflection/module_cases/*
exclude=docs/examples/archive/*|tests/reflection/module_cases/* exclude=docs/examples/archive/*|tests/reflection/module_cases/*
@ -18,6 +24,22 @@ disallow_untyped_defs=false
warn_return_any=false warn_return_any=false
[mypy-cognee.infrastructure.databases.*]
ignore_missing_imports=true
# Third-party database libraries that lack type stubs
[mypy-chromadb.*]
ignore_missing_imports=true
[mypy-lancedb.*]
ignore_missing_imports=true
[mypy-asyncpg.*]
ignore_missing_imports=true
[mypy-pgvector.*]
ignore_missing_imports=true
[mypy-docs.*] [mypy-docs.*]
disallow_untyped_defs=false disallow_untyped_defs=false

View file

@ -83,16 +83,16 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"import os\n", "import os\n",
"import pathlib\n", "import pathlib\n",
"from cognee import config, add, cognify, search, SearchType, prune, visualize_graph\n", "from cognee import config, add, cognify, search, SearchType, prune, visualize_graph\n",
"from dotenv import load_dotenv" "from dotenv import load_dotenv"
], ]
"outputs": [],
"execution_count": null
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -106,7 +106,9 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [],
"source": [ "source": [
"# load environment variables from file .env\n", "# load environment variables from file .env\n",
"load_dotenv()\n", "load_dotenv()\n",
@ -145,9 +147,7 @@
" \"vector_db_url\": f\"neptune-graph://{graph_identifier}\", # Neptune Analytics endpoint with the format neptune-graph://<GRAPH_ID>\n", " \"vector_db_url\": f\"neptune-graph://{graph_identifier}\", # Neptune Analytics endpoint with the format neptune-graph://<GRAPH_ID>\n",
" }\n", " }\n",
")" ")"
], ]
"outputs": [],
"execution_count": null
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -159,19 +159,19 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# Prune data and system metadata before running, only if we want \"fresh\" state.\n", "# Prune data and system metadata before running, only if we want \"fresh\" state.\n",
"await prune.prune_data()\n", "await prune.prune_data()\n",
"await prune.prune_system(metadata=True)" "await prune.prune_system(metadata=True)"
], ]
"outputs": [],
"execution_count": null
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## Setup data and cognify\n", "## Setup data and cognify\n",
"\n", "\n",
@ -180,7 +180,9 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [],
"source": [ "source": [
"# Add sample text to the dataset\n", "# Add sample text to the dataset\n",
"sample_text_1 = \"\"\"Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune\n", "sample_text_1 = \"\"\"Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune\n",
@ -205,9 +207,7 @@
"\n", "\n",
"# Cognify the text data.\n", "# Cognify the text data.\n",
"await cognify([dataset_name])" "await cognify([dataset_name])"
], ]
"outputs": [],
"execution_count": null
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -215,14 +215,16 @@
"source": [ "source": [
"## Graph Memory visualization\n", "## Graph Memory visualization\n",
"\n", "\n",
"Initialize Memgraph as a Graph Memory store and save to .artefacts/graph_visualization.html\n", "Initialize a Graph Memory store and save to .artefacts/graph_visualization.html\n",
"\n", "\n",
"![visualization](./neptune_analytics_demo.png)" "![visualization](./neptune_analytics_demo.png)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [],
"source": [ "source": [
"# Get a graphistry url (Register for a free account at https://www.graphistry.com)\n", "# Get a graphistry url (Register for a free account at https://www.graphistry.com)\n",
"# url = await render_graph()\n", "# url = await render_graph()\n",
@ -235,9 +237,7 @@
" ).resolve()\n", " ).resolve()\n",
")\n", ")\n",
"await visualize_graph(graph_file_path)" "await visualize_graph(graph_file_path)"
], ]
"outputs": [],
"execution_count": null
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -250,19 +250,19 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [],
"source": [ "source": [
"# Completion query that uses graph data to form context.\n", "# Completion query that uses graph data to form context.\n",
"graph_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.GRAPH_COMPLETION)\n", "graph_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.GRAPH_COMPLETION)\n",
"print(\"\\nGraph completion result is:\")\n", "print(\"\\nGraph completion result is:\")\n",
"print(graph_completion)" "print(graph_completion)"
], ]
"outputs": [],
"execution_count": null
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## SEARCH: RAG Completion\n", "## SEARCH: RAG Completion\n",
"\n", "\n",
@ -271,19 +271,19 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [],
"source": [ "source": [
"# Completion query that uses document chunks to form context.\n", "# Completion query that uses document chunks to form context.\n",
"rag_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.RAG_COMPLETION)\n", "rag_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.RAG_COMPLETION)\n",
"print(\"\\nRAG Completion result is:\")\n", "print(\"\\nRAG Completion result is:\")\n",
"print(rag_completion)" "print(rag_completion)"
], ]
"outputs": [],
"execution_count": null
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## SEARCH: Graph Insights\n", "## SEARCH: Graph Insights\n",
"\n", "\n",
@ -291,8 +291,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# Search graph insights\n", "# Search graph insights\n",
"insights_results = await search(query_text=\"Neptune Analytics\", query_type=SearchType.INSIGHTS)\n", "insights_results = await search(query_text=\"Neptune Analytics\", query_type=SearchType.INSIGHTS)\n",
@ -302,13 +304,11 @@
" tgt_node = result[2].get(\"name\", result[2][\"type\"])\n", " tgt_node = result[2].get(\"name\", result[2][\"type\"])\n",
" relationship = result[1].get(\"relationship_name\", \"__relationship__\")\n", " relationship = result[1].get(\"relationship_name\", \"__relationship__\")\n",
" print(f\"- {src_node} -[{relationship}]-> {tgt_node}\")" " print(f\"- {src_node} -[{relationship}]-> {tgt_node}\")"
], ]
"outputs": [],
"execution_count": null
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## SEARCH: Entity Summaries\n", "## SEARCH: Entity Summaries\n",
"\n", "\n",
@ -316,8 +316,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# Query all summaries related to query.\n", "# Query all summaries related to query.\n",
"summaries = await search(query_text=\"Neptune Analytics\", query_type=SearchType.SUMMARIES)\n", "summaries = await search(query_text=\"Neptune Analytics\", query_type=SearchType.SUMMARIES)\n",
@ -326,13 +328,11 @@
" type = summary[\"type\"]\n", " type = summary[\"type\"]\n",
" text = summary[\"text\"]\n", " text = summary[\"text\"]\n",
" print(f\"- {type}: {text}\")" " print(f\"- {type}: {text}\")"
], ]
"outputs": [],
"execution_count": null
}, },
{ {
"metadata": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"## SEARCH: Chunks\n", "## SEARCH: Chunks\n",
"\n", "\n",
@ -340,8 +340,10 @@
] ]
}, },
{ {
"metadata": {},
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"chunks = await search(query_text=\"Neptune Analytics\", query_type=SearchType.CHUNKS)\n", "chunks = await search(query_text=\"Neptune Analytics\", query_type=SearchType.CHUNKS)\n",
"print(\"\\nChunk results are:\")\n", "print(\"\\nChunk results are:\")\n",
@ -349,9 +351,7 @@
" type = chunk[\"type\"]\n", " type = chunk[\"type\"]\n",
" text = chunk[\"text\"]\n", " text = chunk[\"text\"]\n",
" print(f\"- {type}: {text}\")" " print(f\"- {type}: {text}\")"
], ]
"outputs": [],
"execution_count": null
} }
], ],
"metadata": { "metadata": {

41
tools/check_all_adapters.sh Executable file
View file

@ -0,0 +1,41 @@
#!/bin/bash
# All Database Adapters MyPy Check Script
set -e # Exit on any error
echo "🚀 Running MyPy checks on all database adapters..."
echo ""
# Ensure we're in the project root directory
cd "$(dirname "$0")/.."
# Run all three adapter checks
echo "========================================="
echo "1⃣ VECTOR DATABASE ADAPTERS"
echo "========================================="
./tools/check_vector_adapters.sh
echo ""
echo "========================================="
echo "2⃣ GRAPH DATABASE ADAPTERS"
echo "========================================="
./tools/check_graph_adapters.sh
echo ""
echo "========================================="
echo "3⃣ HYBRID DATABASE ADAPTERS"
echo "========================================="
./tools/check_hybrid_adapters.sh
echo ""
echo "🎉 All Database Adapters MyPy Checks Complete!"
echo ""
echo "🔍 Auto-Discovery Approach:"
echo " • Vector Adapters: cognee/infrastructure/databases/vector/**/*Adapter.py"
echo " • Graph Adapters: cognee/infrastructure/databases/graph/**/*adapter.py"
echo " • Hybrid Adapters: cognee/infrastructure/databases/hybrid/**/*Adapter.py"
echo ""
echo "🎯 Purpose: Enforce that database adapters are properly typed"
echo "🔧 MyPy Configuration: mypy.ini (strict mode enabled)"
echo "🚀 Maintenance-Free: Automatically discovers new adapters"

41
tools/check_graph_adapters.sh Executable file
View file

@ -0,0 +1,41 @@
#!/bin/bash
# Graph Database Adapters MyPy Check Script
set -e # Exit on any error
echo "🔍 Discovering Graph Database Adapters..."
# Ensure we're in the project root directory
cd "$(dirname "$0")/.."
# Activate virtual environment
source .venv/bin/activate
# Find all adapter.py and *adapter.py files in graph database directories, excluding utility files
graph_adapters=$(find cognee/infrastructure/databases/graph -name "*adapter.py" -o -name "adapter.py" | grep -v "use_graph_adapter.py" | sort)
if [ -z "$graph_adapters" ]; then
echo "No graph database adapters found"
exit 0
else
echo "Found graph database adapters:"
echo "$graph_adapters" | sed 's/^/ • /'
echo ""
echo "Running MyPy on graph database adapters..."
# Use while read to properly handle each file
echo "$graph_adapters" | while read -r adapter; do
if [ -n "$adapter" ]; then
echo "Checking: $adapter"
uv run mypy "$adapter" \
--config-file mypy.ini \
--show-error-codes \
--no-error-summary
echo ""
fi
done
fi
echo "✅ Graph Database Adapters MyPy Check Complete!"

41
tools/check_hybrid_adapters.sh Executable file
View file

@ -0,0 +1,41 @@
#!/bin/bash
# Hybrid Database Adapters MyPy Check Script
set -e # Exit on any error
echo "🔍 Discovering Hybrid Database Adapters..."
# Ensure we're in the project root directory
cd "$(dirname "$0")/.."
# Activate virtual environment
source .venv/bin/activate
# Find all *Adapter.py files in hybrid database directories
hybrid_adapters=$(find cognee/infrastructure/databases/hybrid -name "*Adapter.py" -type f | sort)
if [ -z "$hybrid_adapters" ]; then
echo "No hybrid database adapters found"
exit 0
else
echo "Found hybrid database adapters:"
echo "$hybrid_adapters" | sed 's/^/ • /'
echo ""
echo "Running MyPy on hybrid database adapters..."
# Use while read to properly handle each file
echo "$hybrid_adapters" | while read -r adapter; do
if [ -n "$adapter" ]; then
echo "Checking: $adapter"
uv run mypy "$adapter" \
--config-file mypy.ini \
--show-error-codes \
--no-error-summary
echo ""
fi
done
fi
echo "✅ Hybrid Database Adapters MyPy Check Complete!"

41
tools/check_vector_adapters.sh Executable file
View file

@ -0,0 +1,41 @@
#!/bin/bash
# Vector Database Adapters MyPy Check Script
set -e # Exit on any error
echo "🔍 Discovering Vector Database Adapters..."
# Ensure we're in the project root directory
cd "$(dirname "$0")/.."
# Activate virtual environment
source .venv/bin/activate
# Find all *Adapter.py files in vector database directories
vector_adapters=$(find cognee/infrastructure/databases/vector -name "*Adapter.py" -type f | sort)
if [ -z "$vector_adapters" ]; then
echo "No vector database adapters found"
exit 0
else
echo "Found vector database adapters:"
echo "$vector_adapters" | sed 's/^/ • /'
echo ""
echo "Running MyPy on vector database adapters..."
# Use while read to properly handle each file
echo "$vector_adapters" | while read -r adapter; do
if [ -n "$adapter" ]; then
echo "Checking: $adapter"
uv run mypy "$adapter" \
--config-file mypy.ini \
--show-error-codes \
--no-error-summary
echo ""
fi
done
fi
echo "✅ Vector Database Adapters MyPy Check Complete!"