Compare commits


20 commits

Author  SHA1  Message  Date
Daulet Amirkhanov  a3d2e7bd2a  ruff format  2025-09-07 20:51:24 +01:00
Daulet Amirkhanov  ed83f76f25  clean up todos from lancedb  2025-09-07 20:48:58 +01:00
Daulet Amirkhanov  3a65160839  move check adapters scripts to /tools and update mypy workflow  2025-09-07 20:45:17 +01:00
Daulet Amirkhanov  1d282f3f83  ruff check fix  2025-09-07 20:39:39 +01:00
Daulet Amirkhanov  1a5914e0ab  ruff format  2025-09-07 20:39:39 +01:00
Daulet Amirkhanov  d2a2ade643  mypy fix: Fix ChromaDBAdapter mypy errors  2025-09-07 20:39:39 +01:00
Daulet Amirkhanov  0fb962e29a  mypy: version Neo4j adapter  2025-09-07 20:39:13 +01:00
Daulet Amirkhanov  b9cd847e9d  Remove Memgraph and references to it  2025-09-07 20:38:37 +01:00
Daulet Amirkhanov  aa74672aeb  mypy: fix RemoteKuzuAdapter mypy errors  2025-09-07 20:38:26 +01:00
Daulet Amirkhanov  9a2bf0f137  kuzu - improve type inference for connection  2025-09-07 20:38:14 +01:00
Daulet Amirkhanov  b59087841b  mypy: first fix KuzuAdapter mypy errors  2025-09-07 20:38:14 +01:00
Daulet Amirkhanov  deaf3debbf  mypy: fix PGVectorAdapter mypy errors  2025-09-07 20:37:39 +01:00
Daulet Amirkhanov  eebca89855  mypy: fix LanceDBAdapter mypy errors  2025-09-07 20:37:09 +01:00
Daulet Amirkhanov  4ae41fede3  mypy fix: Fix ChromaDBAdapter mypy errors  2025-09-03 18:10:01 +01:00
Daulet Amirkhanov  26f5ab4f0f  mypy: ignore missing imports for third party adapter libraries  2025-09-03 17:34:20 +01:00
Daulet Amirkhanov  baffd9187e  adding temporary mypy scripts  2025-09-03 17:33:59 +01:00
Daulet Amirkhanov  85f37a6ee5  make protocols_mypy workflow manually dispatchable  2025-09-03 17:03:40 +01:00
Daulet Amirkhanov  aa686cefe8  Potential fix for code scanning alert no. 150: Workflow does not contain permissions (Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>)  2025-09-03 16:59:09 +01:00
Daulet Amirkhanov  636d38c018  refactor: remove old MyPy workflow and add new database adapter MyPy check workflow  2025-09-03 16:59:09 +01:00
Daulet Amirkhanov  7d80701381  chore: update mypy and create a GitHub workflow  2025-09-03 16:59:09 +01:00
18 changed files with 648 additions and 1524 deletions


@ -0,0 +1,76 @@
permissions:
  contents: read

name: Database Adapter MyPy Check

on:
  workflow_dispatch:
  push:
    branches: [ main, dev ]
    paths:
      - 'cognee/infrastructure/databases/**'
      - 'tools/check_*_adapters.sh'
      - 'mypy.ini'
      - '.github/workflows/database_protocol_mypy_check.yml'
  pull_request:
    branches: [ main, dev ]
    paths:
      - 'cognee/infrastructure/databases/**'
      - 'tools/check_*_adapters.sh'
      - 'mypy.ini'
      - '.github/workflows/database_protocol_mypy_check.yml'

env:
  RUNTIME__LOG_LEVEL: ERROR
  ENV: 'dev'

jobs:
  mypy-database-adapters:
    name: MyPy Database Adapter Type Check
    runs-on: ubuntu-22.04
    steps:
      - name: Check out repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'
      - name: Discover and Check Vector Database Adapters
        run: ./tools/check_vector_adapters.sh
      - name: Discover and Check Graph Database Adapters
        run: ./tools/check_graph_adapters.sh
      - name: Discover and Check Hybrid Database Adapters
        run: ./tools/check_hybrid_adapters.sh
      - name: Protocol Compliance Summary
        run: |
          echo "✅ Database Adapter MyPy Check Passed!"
          echo ""
          echo "🔍 Auto-Discovery Approach:"
          echo " • Vector Adapters: cognee/infrastructure/databases/vector/**/*Adapter.py"
          echo " • Graph Adapters: cognee/infrastructure/databases/graph/**/*adapter.py"
          echo " • Hybrid Adapters: cognee/infrastructure/databases/hybrid/**/*Adapter.py"
          echo ""
          echo "🚀 Using Dedicated Scripts:"
          echo " • Vector: ./tools/check_vector_adapters.sh"
          echo " • Graph: ./tools/check_graph_adapters.sh"
          echo " • Hybrid: ./tools/check_hybrid_adapters.sh"
          echo " • All: ./tools/check_all_adapters.sh"
          echo ""
          echo "🎯 Purpose: Enforce that database adapters are properly typed"
          echo "🔧 MyPy Configuration: mypy.ini (strict mode enabled)"
          echo "🚀 Maintenance-Free: Automatically discovers new adapters"
          echo ""
          echo "⚠️ This workflow FAILS on any type errors to ensure adapter quality."
          echo "   All database adapters must be properly typed."
          echo ""
          echo "🛠️ To fix type issues locally, run:"
          echo " ./tools/check_all_adapters.sh # Check all adapters"
          echo " ./tools/check_vector_adapters.sh # Check vector adapters only"
          echo " mypy <adapter_file_path> --config-file mypy.ini # Check specific file"


@ -118,7 +118,7 @@ class HealthChecker:
# Test basic operation with actual graph query
if hasattr(engine, "execute"):
# For SQL-like graph DBs (Neo4j, Memgraph)
# For SQL-like graph DBs (Neo4j)
await engine.execute("MATCH () RETURN count(*) LIMIT 1")
elif hasattr(engine, "query"):
# For other graph engines


@ -179,5 +179,5 @@ def create_graph_engine(
raise EnvironmentError(
f"Unsupported graph database provider: {graph_database_provider}. "
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'neptune', 'neptune_analytics'])}"
)


@ -10,7 +10,7 @@ from kuzu.database import Database
from datetime import datetime, timezone
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Union, Optional, Tuple, Type
from typing import Dict, Any, List, Union, Optional, Tuple, Type, AsyncGenerator
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.run_sync import run_sync
@ -22,7 +22,7 @@ from cognee.infrastructure.databases.graph.graph_db_interface import (
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.tasks.temporal_graph.models import Timestamp
from cognee.modules.engine.models.Timestamp import Timestamp
logger = get_logger()
@ -146,15 +146,21 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to initialize Kuzu database: {e}")
raise e
def _get_connection(self) -> Connection:
"""Get the connection to the Kuzu database."""
if not self.connection:
raise RuntimeError("Kuzu database connection not initialized")
return self.connection
async def push_to_s3(self) -> None:
if os.getenv("STORAGE_BACKEND", "").lower() == "s3" and hasattr(self, "temp_graph_file"):
from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
s3_file_storage = S3FileStorage("")
if self.connection:
if self._get_connection():
async with self.KUZU_ASYNC_LOCK:
self.connection.execute("CHECKPOINT;")
self._get_connection().execute("CHECKPOINT;")
s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
@ -167,7 +173,9 @@ class KuzuAdapter(GraphDBInterface):
except FileNotFoundError:
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
async def query(
self, query: str, params: Optional[Dict[str, Any]] = None
) -> List[Tuple[Any, ...]]:
"""
Execute a Kuzu query asynchronously with automatic reconnection.
@ -190,23 +198,32 @@ class KuzuAdapter(GraphDBInterface):
loop = asyncio.get_running_loop()
params = params or {}
def blocking_query():
def blocking_query() -> List[Tuple[Any, ...]]:
try:
if not self.connection:
if not self._get_connection():
logger.debug("Reconnecting to Kuzu database...")
self._initialize_connection()
result = self.connection.execute(query, params)
if not self._get_connection():
raise RuntimeError("Failed to establish database connection")
result = self._get_connection().execute(query, params)
rows = []
while result.has_next():
row = result.get_next()
processed_rows = []
for val in row:
if hasattr(val, "as_py"):
val = val.as_py()
processed_rows.append(val)
rows.append(tuple(processed_rows))
if not isinstance(result, list):
result = [result]
# Handle QueryResult vs List[QueryResult] union type
for single_result in result:
while single_result.has_next():
row = single_result.get_next()
processed_rows = []
for val in row:
if hasattr(val, "as_py"):
val = val.as_py()
processed_rows.append(val)
rows.append(tuple(processed_rows))
return rows
except Exception as e:
logger.error(f"Query execution failed: {str(e)}")
@ -215,7 +232,7 @@ class KuzuAdapter(GraphDBInterface):
return await loop.run_in_executor(self.executor, blocking_query)
@asynccontextmanager
async def get_session(self):
async def get_session(self) -> AsyncGenerator[Optional[Connection], None]:
"""
Get a database session.
@ -224,7 +241,7 @@ class KuzuAdapter(GraphDBInterface):
and on exit performs cleanup if necessary.
"""
try:
yield self.connection
yield self._get_connection()
finally:
pass
@ -255,7 +272,7 @@ class KuzuAdapter(GraphDBInterface):
def _edge_query_and_params(
self, from_node: str, to_node: str, relationship_name: str, properties: Dict[str, Any]
) -> Tuple[str, dict]:
) -> Tuple[str, Dict[str, Any]]:
"""Build the edge creation query and parameters."""
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
query = """
@ -305,7 +322,9 @@ class KuzuAdapter(GraphDBInterface):
result = await self.query(query_str, {"id": node_id})
return result[0][0] if result else False
async def add_node(self, node: DataPoint) -> None:
async def add_node(
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
) -> None:
"""
Add a single node to the graph if it doesn't exist.
@ -319,20 +338,32 @@ class KuzuAdapter(GraphDBInterface):
- node (DataPoint): The node to be added, represented as a DataPoint.
"""
try:
properties = node.model_dump() if hasattr(node, "model_dump") else vars(node)
if isinstance(node, str):
# Handle string node ID with properties parameter
node_properties = properties or {}
core_properties = {
"id": node,
"name": str(node_properties.get("name", "")),
"type": str(node_properties.get("type", "")),
}
# Use the passed properties, excluding core fields
other_properties = {
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
}
else:
# Handle DataPoint object
node_properties = node.model_dump()
core_properties = {
"id": str(node_properties.get("id", "")),
"name": str(node_properties.get("name", "")),
"type": str(node_properties.get("type", "")),
}
# Remove core fields from other properties
other_properties = {
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
}
# Extract core fields with defaults if not present
core_properties = {
"id": str(properties.get("id", "")),
"name": str(properties.get("name", "")),
"type": str(properties.get("type", "")),
}
# Remove core fields from other properties
for key in core_properties:
properties.pop(key, None)
core_properties["properties"] = json.dumps(properties, cls=JSONEncoder)
core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)
# Add timestamps for new node
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")
@ -360,7 +391,7 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to add node: {e}")
raise
@record_graph_changes
@record_graph_changes # type: ignore
async def add_nodes(self, nodes: List[DataPoint]) -> None:
"""
Add multiple nodes to the graph in a batch operation.
@ -568,7 +599,9 @@ class KuzuAdapter(GraphDBInterface):
)
return result[0][0] if result else False
async def has_edges(self, edges: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
async def has_edges(
self, edges: List[Tuple[str, str, str, Dict[str, Any]]]
) -> List[Tuple[str, str, str, Dict[str, Any]]]:
"""
Check if multiple edges exist in a batch operation.
@ -599,7 +632,7 @@ class KuzuAdapter(GraphDBInterface):
"to_id": str(to_node), # Ensure string type
"relationship_name": str(edge_label), # Ensure string type
}
for from_node, to_node, edge_label in edges
for from_node, to_node, edge_label, _ in edges
]
# Batch check query with direct string comparison
@ -615,9 +648,21 @@ class KuzuAdapter(GraphDBInterface):
results = await self.query(query, {"edges": edge_params})
# Convert results back to tuples and ensure string types
existing_edges = [(str(row[0]), str(row[1]), str(row[2])) for row in results]
# Find the original edge properties for each existing edge
# TODO: get review on this
existing_edges = []
for row in results:
from_id, to_id, rel_name = str(row[0]), str(row[1]), str(row[2])
# Find the original properties from the input edges
original_props = {}
for orig_from, orig_to, orig_rel, orig_props in edges:
if orig_from == from_id and orig_to == to_id and orig_rel == rel_name:
original_props = orig_props
break
existing_edges.append((from_id, to_id, rel_name, original_props))
logger.debug(f"Found {len(existing_edges)} existing edges out of {len(edges)} checked")
# TODO: otherwise, we can just return dummy properties since they are not used apparently
return existing_edges
except Exception as e:
@ -626,10 +671,10 @@ class KuzuAdapter(GraphDBInterface):
async def add_edge(
self,
from_node: str,
to_node: str,
source_id: str,
target_id: str,
relationship_name: str,
edge_properties: Dict[str, Any] = {},
properties: Optional[Dict[str, Any]] = None,
) -> None:
"""
Add an edge between two nodes.
@ -641,23 +686,23 @@ class KuzuAdapter(GraphDBInterface):
Parameters:
-----------
- from_node (str): The identifier of the source node from which the edge originates.
- to_node (str): The identifier of the target node to which the edge points.
- source_id (str): The identifier of the source node from which the edge originates.
- target_id (str): The identifier of the target node to which the edge points.
- relationship_name (str): The label of the edge to be created, representing the
relationship name.
- edge_properties (Dict[str, Any]): A dictionary containing properties for the edge.
(default {})
- properties (Optional[Dict[str, Any]]): A dictionary containing properties for the edge.
(default None)
"""
try:
query, params = self._edge_query_and_params(
from_node, to_node, relationship_name, edge_properties
source_id, target_id, relationship_name, properties or {}
)
await self.query(query, params)
except Exception as e:
logger.error(f"Failed to add edge: {e}")
raise
@record_graph_changes
@record_graph_changes # type: ignore
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
"""
Add multiple edges in a batch operation.
@ -712,7 +757,7 @@ class KuzuAdapter(GraphDBInterface):
logger.error(f"Failed to add edges in batch: {e}")
raise
async def get_edges(self, node_id: str) -> List[Tuple[Dict[str, Any], str, Dict[str, Any]]]:
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
"""
Get all edges connected to a node.
@ -727,9 +772,8 @@ class KuzuAdapter(GraphDBInterface):
Returns:
--------
- List[Tuple[Dict[str, Any], str, Dict[str, Any]]]: A list of tuples where each
tuple contains (source_node, relationship_name, target_node), with source_node and
target_node as dictionaries of node properties.
- List[Tuple[str, str, str, Dict[str, Any]]]: A list of tuples where each
tuple contains (source_id, relationship_name, target_id, edge_properties).
"""
query_str = """
MATCH (n:Node)-[r]-(m:Node)
@ -750,12 +794,14 @@ class KuzuAdapter(GraphDBInterface):
"""
try:
results = await self.query(query_str, {"node_id": node_id})
edges = []
edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
for row in results:
if row and len(row) == 3:
source_node = self._parse_node_properties(row[0])
relationship_name = row[1]
target_node = self._parse_node_properties(row[2])
edges.append((source_node, row[1], target_node))
# TODO: any edge properties we can add? Adding empty to avoid modifying query without reason
edges.append((source_node, relationship_name, target_node, {}))  # type: ignore # currently each node is a dict, while typing expects nodes to be strings
return edges
except Exception as e:
logger.error(f"Failed to get edges for node {node_id}: {e}")
@ -977,7 +1023,7 @@ class KuzuAdapter(GraphDBInterface):
return []
async def get_connections(
self, node_id: str
self, node_id: Union[str, UUID]
) -> List[Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]]:
"""
Get all nodes connected to a given node.
@ -1019,7 +1065,9 @@ class KuzuAdapter(GraphDBInterface):
}
"""
try:
results = await self.query(query_str, {"node_id": node_id})
# Convert UUID to string if needed
node_id_str = str(node_id)
results = await self.query(query_str, {"node_id": node_id_str})
edges = []
for row in results:
if row and len(row) == 3:
@ -1177,7 +1225,7 @@ class KuzuAdapter(GraphDBInterface):
async def get_nodeset_subgraph(
self, node_type: Type[Any], node_name: List[str]
) -> Tuple[List[Tuple[str, dict]], List[Tuple[str, str, str, dict]]]:
) -> Tuple[List[Tuple[int, Dict[str, Any]]], List[Tuple[int, int, str, Dict[str, Any]]]]:
"""
Get subgraph for a set of nodes based on type and names.
@ -1225,9 +1273,9 @@ class KuzuAdapter(GraphDBInterface):
RETURN n.id, n.name, n.type, n.properties
"""
node_rows = await self.query(nodes_query, {"ids": all_ids})
nodes: List[Tuple[str, dict]] = []
nodes: List[Tuple[str, Dict[str, Any]]] = []
for node_id, name, typ, props in node_rows:
data = {"id": node_id, "name": name, "type": typ}
data: Dict[str, Any] = {"id": node_id, "name": name, "type": typ}
if props:
try:
data.update(json.loads(props))
@ -1241,22 +1289,22 @@ class KuzuAdapter(GraphDBInterface):
RETURN a.id, b.id, r.relationship_name, r.properties
"""
edge_rows = await self.query(edges_query, {"ids": all_ids})
edges: List[Tuple[str, str, str, dict]] = []
edges: List[Tuple[str, str, str, Dict[str, Any]]] = []
for from_id, to_id, rel_type, props in edge_rows:
data = {}
edge_data: Dict[str, Any] = {}
if props:
try:
data = json.loads(props)
edge_data = json.loads(props)
except json.JSONDecodeError:
logger.warning(f"Failed to parse JSON props for edge {from_id}->{to_id}")
edges.append((from_id, to_id, rel_type, data))
edges.append((from_id, to_id, rel_type, edge_data))
return nodes, edges
return nodes, edges # type: ignore # Interface expects int IDs but string IDs are more natural for graph DBs
async def get_filtered_graph_data(
self, attribute_filters: List[Dict[str, List[Union[str, int]]]]
):
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""
Get filtered nodes and relationships based on attributes.
@ -1299,7 +1347,7 @@ class KuzuAdapter(GraphDBInterface):
)
return ([n[0] for n in nodes], [e[0] for e in edges])
async def get_graph_metrics(self, include_optional=False) -> Dict[str, Any]:
async def get_graph_metrics(self, include_optional: bool = False) -> Dict[str, Any]:
"""
Get metrics on graph structure and connectivity.
@ -1322,8 +1370,8 @@ class KuzuAdapter(GraphDBInterface):
try:
# Get basic graph data
nodes, edges = await self.get_model_independent_graph_data()
num_nodes = len(nodes[0]["nodes"]) if nodes else 0
num_edges = len(edges[0]["elements"]) if edges else 0
num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string?
num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string?
# Calculate mandatory metrics
mandatory_metrics = {
@ -1483,8 +1531,8 @@ class KuzuAdapter(GraphDBInterface):
It raises exceptions for failures occurring during deletion processes.
"""
try:
if self.connection:
self.connection.close()
if self._get_connection():
self._get_connection().close()
self.connection = None
if self.db:
self.db.close()
@ -1515,7 +1563,7 @@ class KuzuAdapter(GraphDBInterface):
occur during file deletions or initializations carefully.
"""
try:
if self.connection:
if self._get_connection():
self.connection = None
if self.db:
self.db.close()
@ -1531,20 +1579,30 @@ class KuzuAdapter(GraphDBInterface):
# Reinitialize the database
self._initialize_connection()
if not self._get_connection():
raise RuntimeError("Failed to establish database connection")
# Verify the database is empty
result = self.connection.execute("MATCH (n:Node) RETURN COUNT(n)")
count = result.get_next()[0] if result.has_next() else 0
result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)")
if not isinstance(result, list):
result = [result]
for single_result in result:
_next = single_result.get_next()
if not isinstance(_next, list):
raise RuntimeError("Expected list of results")
count = _next[0] if _next else 0
if count > 0:
logger.warning(
f"Database still contains {count} nodes after clearing, forcing deletion"
)
self.connection.execute("MATCH (n:Node) DETACH DELETE n")
self._get_connection().execute("MATCH (n:Node) DETACH DELETE n")
logger.info("Database cleared successfully")
except Exception as e:
logger.error(f"Error during database clearing: {e}")
raise
async def get_document_subgraph(self, data_id: str):
async def get_document_subgraph(self, data_id: str) -> Optional[Dict[str, Any]]:
"""
Get all nodes that should be deleted when removing a document.
@ -1616,7 +1674,7 @@ class KuzuAdapter(GraphDBInterface):
"orphan_types": result[0][4],
}
async def get_degree_one_nodes(self, node_type: str):
async def get_degree_one_nodes(self, node_type: str) -> List[Dict[str, Any]]:
"""
Get all nodes that have only one connection.
@ -1769,8 +1827,8 @@ class KuzuAdapter(GraphDBInterface):
ids: List[str] = []
if time_from and time_to:
time_from = date_to_int(time_from)
time_to = date_to_int(time_to)
time_from_int = date_to_int(time_from)
time_to_int = date_to_int(time_to)
cypher = f"""
MATCH (n:Node)
@ -1782,13 +1840,13 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64)
END AS t
WHERE t >= {time_from}
AND t <= {time_to}
WHERE t >= {time_from_int}
AND t <= {time_to_int}
RETURN n.id as id
"""
elif time_from:
time_from = date_to_int(time_from)
time_from_int = date_to_int(time_from)
cypher = f"""
MATCH (n:Node)
@ -1800,12 +1858,12 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64)
END AS t
WHERE t >= {time_from}
WHERE t >= {time_from_int}
RETURN n.id as id
"""
elif time_to:
time_to = date_to_int(time_to)
time_to_int = date_to_int(time_to)
cypher = f"""
MATCH (n:Node)
@ -1817,12 +1875,12 @@ class KuzuAdapter(GraphDBInterface):
WHEN t_str IS NULL OR t_str = '' THEN NULL
ELSE CAST(t_str AS INT64)
END AS t
WHERE t <= {time_to}
WHERE t <= {time_to_int}
RETURN n.id as id
"""
else:
return ids
return ", ".join(f"'{uid}'" for uid in ids)
time_nodes = await self.query(cypher)
time_ids_list = [item[0] for item in time_nodes]


@ -2,7 +2,7 @@
from cognee.shared.logging_utils import get_logger
import json
from typing import Dict, Any, List, Optional, Tuple
from typing import Dict, Any, List, Optional, Tuple, Union
import aiohttp
from uuid import UUID
@ -14,7 +14,7 @@ logger = get_logger()
class UUIDEncoder(json.JSONEncoder):
"""Custom JSON encoder that handles UUID objects."""
def default(self, obj):
def default(self, obj: Union[UUID, Any]) -> Any:
if isinstance(obj, UUID):
return str(obj)
return super().default(obj)
@ -36,7 +36,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
self.api_url = api_url
self.username = username
self.password = password
self._session = None
self._session: Optional[aiohttp.ClientSession] = None
self._schema_initialized = False
async def _get_session(self) -> aiohttp.ClientSession:
@ -45,13 +45,13 @@ class RemoteKuzuAdapter(KuzuAdapter):
self._session = aiohttp.ClientSession()
return self._session
async def close(self):
async def close(self) -> None:
"""Close the adapter and its session."""
if self._session and not self._session.closed:
await self._session.close()
self._session = None
async def _make_request(self, endpoint: str, data: dict) -> dict:
async def _make_request(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]:
"""Make a request to the Kuzu API."""
url = f"{self.api_url}{endpoint}"
session = await self._get_session()
@ -73,13 +73,15 @@ class RemoteKuzuAdapter(KuzuAdapter):
status=response.status,
message=error_detail,
)
return await response.json()
return await response.json() # type: ignore
except aiohttp.ClientError as e:
logger.error(f"API request failed: {str(e)}")
logger.error(f"Request data: {data}")
raise
async def query(self, query: str, params: Optional[dict] = None) -> List[Tuple]:
async def query(
self, query: str, params: Optional[dict[str, Any]] = None
) -> List[Tuple[Any, ...]]:
"""Execute a Kuzu query via the REST API."""
try:
# Initialize schema if needed
@ -126,7 +128,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
logger.error(f"Failed to check schema: {e}")
return False
async def _create_schema(self):
async def _create_schema(self) -> None:
"""Create the required schema tables."""
try:
# Create Node table if it doesn't exist
@ -180,7 +182,7 @@ class RemoteKuzuAdapter(KuzuAdapter):
logger.error(f"Failed to create schema: {e}")
raise
async def _initialize_schema(self):
async def _initialize_schema(self) -> None:
"""Initialize the database schema if it doesn't exist."""
if self._schema_initialized:
return


@ -8,11 +8,11 @@ from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError
from contextlib import asynccontextmanager
from typing import Optional, Any, List, Dict, Type, Tuple
from typing import Optional, Any, List, Dict, Type, Tuple, Union, AsyncGenerator
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
from cognee.tasks.temporal_graph.models import Timestamp
from cognee.modules.engine.models.Timestamp import Timestamp
from cognee.shared.logging_utils import get_logger, ERROR
from cognee.infrastructure.databases.graph.graph_db_interface import (
GraphDBInterface,
@ -79,14 +79,14 @@ class Neo4jAdapter(GraphDBInterface):
)
@asynccontextmanager
async def get_session(self) -> AsyncSession:
async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
"""
Get a session for database operations.
"""
async with self.driver.session(database=self.graph_database_name) as session:
yield session
@deadlock_retry()
@deadlock_retry() # type: ignore
async def query(
self,
query: str,
@ -112,6 +112,7 @@ class Neo4jAdapter(GraphDBInterface):
async with self.get_session() as session:
result = await session.run(query, parameters=params)
data = await result.data()
# TODO: why don't we return List[Dict[str, Any]]?
return data
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info=True)
@ -141,21 +142,29 @@ class Neo4jAdapter(GraphDBInterface):
)
return results[0]["node_exists"] if len(results) > 0 else False
async def add_node(self, node: DataPoint):
async def add_node(
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
) -> None:
"""
Add a new node to the database based on the provided DataPoint object.
Add a new node to the database based on the provided DataPoint object or string ID.
Parameters:
-----------
- node (DataPoint): An instance of DataPoint representing the node to add.
Returns:
--------
The result of the query execution, typically the ID of the added node.
- node (Union[DataPoint, str]): An instance of DataPoint or string ID representing the node to add.
- properties (Optional[Dict[str, Any]]): Properties to set on the node when node is a string ID.
"""
serialized_properties = self.serialize_properties(node.model_dump())
if isinstance(node, str):
# TODO: this was not handled in the original code, check if it is correct
# Handle string node ID with properties parameter
node_id = node
node_label = "Node" # Default label for string nodes
serialized_properties = self.serialize_properties(properties or {})
else:
# Handle DataPoint object
node_id = str(node.id)
node_label = type(node).__name__
serialized_properties = self.serialize_properties(node.model_dump())
query = dedent(
f"""MERGE (node: `{BASE_LABEL}`{{id: $node_id}})
@ -167,16 +176,16 @@ class Neo4jAdapter(GraphDBInterface):
)
params = {
"node_id": str(node.id),
"node_label": type(node).__name__,
"node_id": node_id,
"node_label": node_label,
"properties": serialized_properties,
}
return await self.query(query, params)
await self.query(query, params)
@record_graph_changes
@override_distributed(queued_add_nodes)
async def add_nodes(self, nodes: list[DataPoint]) -> None:
@record_graph_changes # type: ignore
@override_distributed(queued_add_nodes) # type: ignore
async def add_nodes(self, nodes: List[DataPoint]) -> None:
"""
Add multiple nodes to the database in a single query.
@ -201,7 +210,7 @@ class Neo4jAdapter(GraphDBInterface):
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
"""
nodes = [
node_params = [
{
"node_id": str(node.id),
"label": type(node).__name__,
@ -210,10 +219,9 @@ class Neo4jAdapter(GraphDBInterface):
for node in nodes
]
results = await self.query(query, dict(nodes=nodes))
return results
await self.query(query, dict(nodes=node_params))
async def extract_node(self, node_id: str):
async def extract_node(self, node_id: str) -> Optional[Dict[str, Any]]:
"""
Retrieve a single node from the database by its ID.
@ -231,7 +239,7 @@ class Neo4jAdapter(GraphDBInterface):
return results[0] if len(results) > 0 else None
async def extract_nodes(self, node_ids: List[str]):
async def extract_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
"""
Retrieve multiple nodes from the database by their IDs.
@ -256,7 +264,7 @@ class Neo4jAdapter(GraphDBInterface):
return [result["node"] for result in results]
async def delete_node(self, node_id: str):
async def delete_node(self, node_id: str) -> None:
"""
Remove a node from the database identified by its ID.
@ -273,7 +281,7 @@ class Neo4jAdapter(GraphDBInterface):
query = f"MATCH (node: `{BASE_LABEL}`{{id: $node_id}}) DETACH DELETE node"
params = {"node_id": node_id}
return await self.query(query, params)
await self.query(query, params)
async def delete_nodes(self, node_ids: list[str]) -> None:
"""
@ -296,18 +304,18 @@ class Neo4jAdapter(GraphDBInterface):
params = {"node_ids": node_ids}
return await self.query(query, params)
await self.query(query, params)
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
async def has_edge(self, source_id: str, target_id: str, relationship_name: str) -> bool:
"""
Check if an edge exists between two nodes with the specified IDs and edge label.
Parameters:
-----------
- from_node (UUID): The ID of the node from which the edge originates.
- to_node (UUID): The ID of the node to which the edge points.
- edge_label (str): The label of the edge to check for existence.
- source_id (str): The ID of the node from which the edge originates.
- target_id (str): The ID of the node to which the edge points.
- relationship_name (str): The label of the edge to check for existence.
Returns:
--------
@ -315,27 +323,28 @@ class Neo4jAdapter(GraphDBInterface):
- bool: True if the edge exists, otherwise False.
"""
query = f"""
MATCH (from_node: `{BASE_LABEL}`)-[:`{edge_label}`]->(to_node: `{BASE_LABEL}`)
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id
MATCH (from_node: `{BASE_LABEL}`)-[:`{relationship_name}`]->(to_node: `{BASE_LABEL}`)
WHERE from_node.id = $source_id AND to_node.id = $target_id
RETURN COUNT(relationship) > 0 AS edge_exists
"""
params = {
"from_node_id": str(from_node),
"to_node_id": str(to_node),
"source_id": str(source_id),
"target_id": str(target_id),
}
edge_exists = await self.query(query, params)
assert isinstance(edge_exists, bool), "Edge existence check should return a boolean"
return edge_exists
async def has_edges(self, edges):
async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[bool]:
"""
Check if multiple edges exist based on provided edge criteria.
Parameters:
-----------
- edges: A list of edge specifications to check for existence.
- edges: A list of edge specifications to check for existence. (source_id, target_id, relationship_name, properties)
Returns:
--------
@ -369,29 +378,24 @@ class Neo4jAdapter(GraphDBInterface):
async def add_edge(
self,
from_node: UUID,
to_node: UUID,
source_id: str,
target_id: str,
relationship_name: str,
edge_properties: Optional[Dict[str, Any]] = {},
):
properties: Optional[Dict[str, Any]] = None,
) -> None:
"""
Create a new edge between two nodes with specified properties.
Parameters:
-----------
- from_node (UUID): The ID of the source node of the edge.
- to_node (UUID): The ID of the target node of the edge.
- source_id (str): The ID of the source node of the edge.
- target_id (str): The ID of the target node of the edge.
- relationship_name (str): The type/label of the edge to create.
- edge_properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
to the edge. (default {})
Returns:
--------
The result of the query execution, typically indicating the created edge.
- properties (Optional[Dict[str, Any]]): A dictionary of properties to assign
to the edge. (default None)
"""
serialized_properties = self.serialize_properties(edge_properties)
serialized_properties = self.serialize_properties(properties or {})
query = dedent(
f"""\
@ -405,13 +409,13 @@ class Neo4jAdapter(GraphDBInterface):
)
params = {
"from_node": str(from_node),
"to_node": str(to_node),
"from_node": str(source_id), # Adding str as callsites may still be passing UUID
"to_node": str(target_id),
"relationship_name": relationship_name,
"properties": serialized_properties,
}
return await self.query(query, params)
await self.query(query, params)
def _flatten_edge_properties(self, properties: Dict[str, Any]) -> Dict[str, Any]:
"""
@ -445,9 +449,9 @@ class Neo4jAdapter(GraphDBInterface):
return flattened
@record_graph_changes
@override_distributed(queued_add_edges)
async def add_edges(self, edges: list[tuple[str, str, str, dict[str, Any]]]) -> None:
@record_graph_changes # type: ignore
@override_distributed(queued_add_edges) # type: ignore
async def add_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> None:
"""
Add multiple edges between nodes in a single query.
@ -478,10 +482,10 @@ class Neo4jAdapter(GraphDBInterface):
) YIELD rel
RETURN rel"""
edges = [
edge_params = [
{
"from_node": str(edge[0]),
"to_node": str(edge[1]),
"from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
"to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
"relationship_name": edge[2],
"properties": self._flatten_edge_properties(
{
@ -495,13 +499,12 @@ class Neo4jAdapter(GraphDBInterface):
]
try:
results = await self.query(query, dict(edges=edges))
return results
await self.query(query, dict(edges=edge_params))
except Neo4jError as error:
logger.error("Neo4j query error: %s", error, exc_info=True)
raise error
async def get_edges(self, node_id: str):
async def get_edges(self, node_id: str) -> List[Tuple[str, str, str, Dict[str, Any]]]:
"""
Retrieve all edges connected to a specified node.


@ -1,12 +1,13 @@
import json
import asyncio
from uuid import UUID
from typing import List, Optional
from typing import List, Optional, Dict, Any
from chromadb import AsyncHttpClient, Settings
from cognee.shared.logging_utils import get_logger
from cognee.modules.storage.utils import get_own_properties
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
@ -35,9 +36,9 @@ class IndexSchema(DataPoint):
text: str
metadata: dict = {"index_fields": ["text"]}
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
def model_dump(self):
def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
"""
Serialize the instance data for storage.
@ -49,11 +50,11 @@ class IndexSchema(DataPoint):
A dictionary containing serialized data processed for ChromaDB storage.
"""
data = super().model_dump()
data = super().model_dump(**kwargs)
return process_data_for_chroma(data)
def process_data_for_chroma(data):
def process_data_for_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert complex data types to a format suitable for ChromaDB storage.
@ -73,7 +74,7 @@ def process_data_for_chroma(data):
A dictionary containing the processed key-value pairs suitable for ChromaDB storage.
"""
processed_data = {}
processed_data: Dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, UUID):
processed_data[key] = str(value)
@ -90,7 +91,7 @@ def process_data_for_chroma(data):
return processed_data
def restore_data_from_chroma(data):
def restore_data_from_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
"""
Restore original data structure from ChromaDB storage format.
@ -152,8 +153,8 @@ class ChromaDBAdapter(VectorDBInterface):
"""
name = "ChromaDB"
url: str
api_key: str
url: str | None
api_key: str | None
connection: AsyncHttpClient = None
def __init__(
@ -216,7 +217,9 @@ class ChromaDBAdapter(VectorDBInterface):
collections = await self.get_collection_names()
return collection_name in collections
async def create_collection(self, collection_name: str, payload_schema=None):
async def create_collection(
self, collection_name: str, payload_schema: Optional[Any] = None
) -> None:
"""
Create a new collection in ChromaDB if it does not already exist.
@ -254,7 +257,7 @@ class ChromaDBAdapter(VectorDBInterface):
client = await self.get_connection()
return await client.get_collection(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
"""
Create and upsert data points into the specified collection in ChromaDB.
@ -282,7 +285,7 @@ class ChromaDBAdapter(VectorDBInterface):
ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
)
async def create_vector_index(self, index_name: str, index_property_name: str):
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
"""
Create a vector index as a ChromaDB collection based on provided names.
@ -296,7 +299,7 @@ class ChromaDBAdapter(VectorDBInterface):
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
) -> None:
"""
Index the provided data points based on the specified index property in ChromaDB.
@ -315,10 +318,11 @@ class ChromaDBAdapter(VectorDBInterface):
text=getattr(data_point, data_point.metadata["index_fields"][0]),
)
for data_point in data_points
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
],
)
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
"""
Retrieve data points by their IDs from a ChromaDB collection.
@ -350,12 +354,12 @@ class ChromaDBAdapter(VectorDBInterface):
async def search(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 15,
with_vector: bool = False,
normalized: bool = True,
):
) -> List[ScoredResult]:
"""
Search for items in a collection using either a text or a vector query.
@ -437,7 +441,7 @@ class ChromaDBAdapter(VectorDBInterface):
query_texts: List[str],
limit: int = 5,
with_vectors: bool = False,
):
) -> List[List[ScoredResult]]:
"""
Perform multiple searches in a single request for efficiency, returning results for each
query.
@ -507,7 +511,7 @@ class ChromaDBAdapter(VectorDBInterface):
return all_results
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> bool:
"""
Remove data points from a collection based on their IDs.
@ -528,7 +532,7 @@ class ChromaDBAdapter(VectorDBInterface):
await collection.delete(ids=data_point_ids)
return True
async def prune(self):
async def prune(self) -> bool:
"""
Delete all collections in the ChromaDB database.
@ -538,12 +542,12 @@ class ChromaDBAdapter(VectorDBInterface):
Returns True upon successful deletion of all collections.
"""
client = await self.get_connection()
collections = await self.list_collections()
for collection_name in collections:
collection_names = await self.get_collection_names()
for collection_name in collection_names:
await client.delete_collection(collection_name)
return True
async def get_collection_names(self):
async def get_collection_names(self) -> List[str]:
"""
Retrieve the names of all collections in the ChromaDB database.


@ -1,12 +1,25 @@
import asyncio
from os import path
from uuid import UUID
import lancedb
from pydantic import BaseModel
from lancedb.pydantic import LanceModel, Vector
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
from typing import (
Generic,
List,
Optional,
TypeVar,
Union,
get_args,
get_origin,
get_type_hints,
Dict,
Any,
)
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.files.storage import get_file_storage
from cognee.modules.storage.utils import copy_model, get_own_properties
@ -30,16 +43,16 @@ class IndexSchema(DataPoint):
to include 'text'.
"""
id: str
id: UUID
text: str
metadata: dict = {"index_fields": ["text"]}
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
class LanceDBAdapter(VectorDBInterface):
name = "LanceDB"
url: str
api_key: str
url: Optional[str]
api_key: Optional[str]
connection: lancedb.AsyncConnection = None
def __init__(
@ -53,7 +66,7 @@ class LanceDBAdapter(VectorDBInterface):
self.embedding_engine = embedding_engine
self.VECTOR_DB_LOCK = asyncio.Lock()
async def get_connection(self):
async def get_connection(self) -> lancedb.AsyncConnection:
"""
Establishes and returns a connection to the LanceDB.
@ -107,12 +120,11 @@ class LanceDBAdapter(VectorDBInterface):
collection_names = await connection.table_names()
return collection_name in collection_names
async def create_collection(self, collection_name: str, payload_schema: BaseModel):
async def create_collection(
self, collection_name: str, payload_schema: Optional[Any] = None
) -> None:
vector_size = self.embedding_engine.get_vector_size()
payload_schema = self.get_data_point_schema(payload_schema)
data_point_types = get_type_hints(payload_schema)
class LanceDataPoint(LanceModel):
"""
Represents a data point in the Lance model with an ID, vector, and associated payload.
@ -123,28 +135,28 @@ class LanceDBAdapter(VectorDBInterface):
- payload: Additional data or metadata associated with the data point.
"""
id: data_point_types["id"]
vector: Vector(vector_size)
payload: payload_schema
id: UUID
vector: Vector[vector_size]
payload: Dict[str, Any]
if not await self.has_collection(collection_name):
async with self.VECTOR_DB_LOCK:
if not await self.has_collection(collection_name):
connection = await self.get_connection()
return await connection.create_table(
await connection.create_table(
name=collection_name,
schema=LanceDataPoint,
exist_ok=True,
)
async def get_collection(self, collection_name: str):
async def get_collection(self, collection_name: str) -> Any:
if not await self.has_collection(collection_name):
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
connection = await self.get_connection()
return await connection.open_table(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
payload_schema = type(data_points[0])
if not await self.has_collection(collection_name):
@ -175,14 +187,14 @@ class LanceDBAdapter(VectorDBInterface):
"""
id: IdType
vector: Vector(vector_size)
vector: Vector[vector_size]
payload: PayloadSchema
def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any:
properties = get_own_properties(data_point)
properties["id"] = str(properties["id"])
return LanceDataPoint[str, self.get_data_point_schema(type(data_point))](
return LanceDataPoint(
id=str(data_point.id),
vector=vector,
payload=properties,
@ -201,7 +213,7 @@ class LanceDBAdapter(VectorDBInterface):
.execute(lance_data_points)
)
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
async def retrieve(self, collection_name: str, data_point_ids: list[str]) -> List[ScoredResult]:
collection = await self.get_collection(collection_name)
if len(data_point_ids) == 1:
@ -221,12 +233,12 @@ class LanceDBAdapter(VectorDBInterface):
async def search(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 15,
with_vector: bool = False,
normalized: bool = True,
):
) -> List[ScoredResult]:
if query_text is None and query_vector is None:
raise MissingQueryParameterError()
@ -264,9 +276,9 @@ class LanceDBAdapter(VectorDBInterface):
self,
collection_name: str,
query_texts: List[str],
limit: int = None,
limit: Optional[int] = None,
with_vectors: bool = False,
):
) -> List[List[ScoredResult]]:
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
@ -274,40 +286,41 @@ class LanceDBAdapter(VectorDBInterface):
self.search(
collection_name=collection_name,
query_vector=query_vector,
limit=limit,
limit=limit or 15,
with_vector=with_vectors,
)
for query_vector in query_vectors
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> None:
collection = await self.get_collection(collection_name)
# Delete one at a time to avoid commit conflicts
for data_point_id in data_point_ids:
await collection.delete(f"id = '{data_point_id}'")
async def create_vector_index(self, index_name: str, index_property_name: str):
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
await self.create_collection(
f"{index_name}_{index_property_name}", payload_schema=IndexSchema
)
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
self, index_name: str, index_property_name: str, data_points: List[DataPoint]
) -> None:
await self.create_data_points(
f"{index_name}_{index_property_name}",
[
IndexSchema(
id=str(data_point.id),
id=data_point.id,
text=getattr(data_point, data_point.metadata["index_fields"][0]),
)
for data_point in data_points
if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0
],
)
async def prune(self):
async def prune(self) -> None:
connection = await self.get_connection()
collection_names = await connection.table_names()
@ -316,12 +329,15 @@ class LanceDBAdapter(VectorDBInterface):
await collection.delete("id IS NOT NULL")
await connection.drop_table(collection_name)
if self.url.startswith("/"):
if self.url and self.url.startswith("/"):
db_dir_path = path.dirname(self.url)
db_file_name = path.basename(self.url)
await get_file_storage(db_dir_path).remove_all(db_file_name)
def get_data_point_schema(self, model_type: BaseModel):
def get_data_point_schema(self, model_type: Optional[Any]) -> Any:
if model_type is None:
return DataPoint
related_models_fields = []
for field_name, field_config in model_type.model_fields.items():


@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel
@ -14,8 +14,10 @@ class ScoredResult(BaseModel):
better outcome.
- payload (Dict[str, Any]): Additional information related to the score, stored as
key-value pairs in a dictionary.
- vector (Optional[List[float]]): Optional vector embedding associated with the result.
"""
id: UUID
score: float # Lower score is better
payload: Dict[str, Any]
vector: Optional[List[float]] = None


@ -1,9 +1,9 @@
import asyncio
from typing import List, Optional, get_type_hints
from typing import List, Optional, get_type_hints, Dict, Any
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
from sqlalchemy import JSON, Table, select, delete, MetaData
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.exc import ProgrammingError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@ -12,6 +12,7 @@ from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationE
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData as DataPointMetaData
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine
@ -42,7 +43,7 @@ class IndexSchema(DataPoint):
text: str
metadata: dict = {"index_fields": ["text"]}
metadata: DataPointMetaData = {"index_fields": ["text"], "type": "IndexSchema"}
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
@ -122,8 +123,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def create_collection(self, collection_name: str, payload_schema=None):
data_point_types = get_type_hints(DataPoint)
async def create_collection(
self, collection_name: str, payload_schema: Optional[Any] = None
) -> None:
vector_size = self.embedding_engine.get_vector_size()
async with self.VECTOR_DB_LOCK:
@ -147,29 +149,31 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
id: Mapped[str] = mapped_column(primary_key=True)
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
def __init__(
self, id: str, payload: Dict[str, Any], vector: List[float]
) -> None:
self.id = id
self.payload = payload
self.vector = vector
async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0:
await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
)
from sqlalchemy import Table
table: Table = PGVectorDataPoint.__table__ # type: ignore
await connection.run_sync(Base.metadata.create_all, tables=[table])
@retry(
retry=retry_if_exception_type(DeadlockDetectedError),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
@override_distributed(queued_add_data_points)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
data_point_types = get_type_hints(DataPoint)
@override_distributed(queued_add_data_points) # type: ignore
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name=collection_name,
@ -196,11 +200,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
id: Mapped[str] = mapped_column(primary_key=True)
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
self.id = id
self.payload = payload
self.vector = vector
@ -225,13 +229,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# else:
pgvector_data_points.append(
PGVectorDataPoint(
id=data_point.id,
id=str(data_point.id),
vector=data_vectors[data_index],
payload=serialize_data(data_point.model_dump()),
)
)
def to_dict(obj):
def to_dict(obj: Any) -> Dict[str, Any]:
return {
column.key: getattr(obj, column.key)
for column in inspect(obj).mapper.column_attrs
@ -245,12 +249,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await session.execute(insert_statement)
await session.commit()
async def create_vector_index(self, index_name: str, index_property_name: str):
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
self, index_name: str, index_property_name: str, data_points: List[DataPoint]
) -> None:
await self.create_data_points(
f"{index_name}_{index_property_name}",
[
@ -262,11 +266,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
],
)
async def get_table(self, collection_name: str) -> Table:
async def get_table(self, table_name: str, schema_name: Optional[str] = None) -> Table:
"""
Dynamically loads a table using the given collection name
with an async engine.
Dynamically loads a table using the given table name
with an async engine. Schema parameter is ignored for vector collections.
"""
collection_name = table_name
async with self.engine.begin() as connection:
# Create a MetaData instance to load table information
metadata = MetaData()
@ -279,15 +284,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
f"Collection '{collection_name}' not found!",
)
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
async with self.get_async_session() as session:
results = await session.execute(
query_result = await session.execute(
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
results = results.all()
results = query_result.all()
return [
ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)
@ -311,9 +316,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
# NOTE: This needs to be initialized in case search doesn't return a value
closest_items = []
# Use async session to connect to the database
async with self.get_async_session() as session:
query = select(
@ -325,12 +327,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query = query.limit(limit)
# Find closest vectors to query_vector
closest_items = await session.execute(query)
query_results = await session.execute(query)
vector_list = []
# Extract distances and find min/max for normalization
for vector in closest_items.all():
for vector in query_results.all():
vector_list.append(
{
"id": parse_id(str(vector.id)),
@ -349,7 +351,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Create and return ScoredResult objects
return [
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
ScoredResult(id=row["id"], payload=row["payload"] or {}, score=row["score"])
for row in vector_list
]
@ -357,9 +359,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
self,
collection_name: str,
query_texts: List[str],
limit: int = None,
limit: Optional[int] = None,
with_vectors: bool = False,
):
) -> List[List[ScoredResult]]:
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
@ -367,14 +369,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
self.search(
collection_name=collection_name,
query_vector=query_vector,
limit=limit,
limit=limit or 15,
with_vector=with_vectors,
)
for query_vector in query_vectors
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> Any:
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
@@ -384,6 +386,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await session.commit()
return results
async def prune(self):
async def prune(self) -> None:
# Clean up the database if it was set up as temporary
await self.delete_database()
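
The retyped methods above can be exercised with a short sketch like the following (illustration only, not part of this diff). It assumes the adapter is obtained through cognee's get_vector_engine() factory and that an "Entity_name" collection already exists, as in the Memgraph test removed below; the ids passed to retrieve() are simply whatever the search returned.

# Illustration only: exercising the retyped PGVectorAdapter surface.
import asyncio

from cognee.infrastructure.databases.vector import get_vector_engine


async def demo() -> None:
    engine = get_vector_engine()  # assumed to resolve to the configured adapter, e.g. PGVectorAdapter

    # search() accepts a plain text query and returns ScoredResult objects.
    hits = await engine.search("Entity_name", "Quantum computer")
    for hit in hits:
        print(hit.id, hit.score, hit.payload)

    # retrieve() is now annotated as -> List[ScoredResult].
    points = await engine.retrieve("Entity_name", [str(hit.id) for hit in hits])
    print(f"retrieved {len(points)} data points")


if __name__ == "__main__":
    asyncio.run(demo())
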


@@ -1,109 +0,0 @@
import os
import pathlib
import cognee
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
logger = get_logger()
async def main():
cognee.config.set_graph_database_provider("memgraph")
data_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_memgraph")
).resolve()
)
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_memgraph")
).resolve()
)
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
dataset_name = "cs_explanations"
explanation_file_path = os.path.join(
pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
)
await cognee.add([explanation_file_path], dataset_name)
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible.
The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly.
Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate.
In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
"""
await cognee.add([text], dataset_name)
await cognee.cognify([dataset_name])
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(query_type=SearchType.CHUNKS, query_text=random_node_name)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted chunks are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(
query_type=SearchType.SUMMARIES, query_text=random_node_name
)
assert len(search_results) != 0, "Query related summaries don't exist."
print("\nExtracted results are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(
query_type=SearchType.NATURAL_LANGUAGE,
query_text=f"Find nodes connected to node with name {random_node_name}",
)
assert len(search_results) != 0, "Query related natural language don't exist."
print("\nExtracted results are:\n")
for result in search_results:
print(f"{result}\n")
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 8, "Search history is not correct."
await cognee.prune.prune_data()
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
from cognee.infrastructure.databases.graph import get_graph_engine
graph_engine = await get_graph_engine()
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 0 and len(edges) == 0, "Memgraph graph database is not empty"
if __name__ == "__main__":
import asyncio
asyncio.run(main())

mypy.ini

@@ -1,7 +1,7 @@
[mypy]
python_version=3.8
python_version=3.10
ignore_missing_imports=false
strict_optional=false
strict_optional=true
warn_redundant_casts=true
disallow_any_generics=true
disallow_untyped_defs=true
@@ -10,6 +10,12 @@ warn_return_any=true
namespace_packages=true
warn_unused_ignores=true
show_error_codes=true
disallow_incomplete_defs=true
disallow_untyped_decorators=true
no_implicit_optional=true
warn_unreachable=true
warn_no_return=true
warn_unused_configs=true
#exclude=reflection/module_cases/*
exclude=docs/examples/archive/*|tests/reflection/module_cases/*
@@ -18,6 +24,22 @@ disallow_untyped_defs=false
warn_return_any=false
[mypy-cognee.infrastructure.databases.*]
ignore_missing_imports=true
# Third-party database libraries that lack type stubs
[mypy-chromadb.*]
ignore_missing_imports=true
[mypy-lancedb.*]
ignore_missing_imports=true
[mypy-asyncpg.*]
ignore_missing_imports=true
[mypy-pgvector.*]
ignore_missing_imports=true
[mypy-docs.*]
disallow_untyped_defs=false
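
The new strictness flags are what drive signature changes like limit: int = None becoming Optional[int] = None in the adapter diff above. A small illustration (not repository code) of what the configuration now rejects versus accepts:

# Illustration only: effect of no_implicit_optional / strict_optional under this config.
from typing import Optional


def page_size_bad(limit: int = None) -> int:  # mypy error: implicit Optional default is no longer allowed
    return limit


def page_size_ok(limit: Optional[int] = None) -> int:
    # With strict_optional, the Optional value must be narrowed or given a fallback
    # before it can be returned as int, mirroring the `limit or 15` fallback used above.
    return limit or 15
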


@@ -83,16 +83,16 @@
]
},
{
"metadata": {},
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pathlib\n",
"from cognee import config, add, cognify, search, SearchType, prune, visualize_graph\n",
"from dotenv import load_dotenv"
],
"outputs": [],
"execution_count": null
]
},
{
"cell_type": "markdown",
@@ -106,7 +106,9 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load environment variables from file .env\n",
"load_dotenv()\n",
@@ -145,9 +147,7 @@
" \"vector_db_url\": f\"neptune-graph://{graph_identifier}\", # Neptune Analytics endpoint with the format neptune-graph://<GRAPH_ID>\n",
" }\n",
")"
],
"outputs": [],
"execution_count": null
]
},
{
"cell_type": "markdown",
@ -159,19 +159,19 @@
]
},
{
"metadata": {},
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Prune data and system metadata before running, only if we want \"fresh\" state.\n",
"await prune.prune_data()\n",
"await prune.prune_system(metadata=True)"
],
"outputs": [],
"execution_count": null
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup data and cognify\n",
"\n",
@@ -180,7 +180,9 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Add sample text to the dataset\n",
"sample_text_1 = \"\"\"Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune\n",
@@ -205,9 +207,7 @@
"\n",
"# Cognify the text data.\n",
"await cognify([dataset_name])"
],
"outputs": [],
"execution_count": null
]
},
{
"cell_type": "markdown",
@@ -215,14 +215,16 @@
"source": [
"## Graph Memory visualization\n",
"\n",
"Initialize Memgraph as a Graph Memory store and save to .artefacts/graph_visualization.html\n",
"Initialize a Graph Memory store and save to .artefacts/graph_visualization.html\n",
"\n",
"![visualization](./neptune_analytics_demo.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get a graphistry url (Register for a free account at https://www.graphistry.com)\n",
"# url = await render_graph()\n",
@@ -235,9 +237,7 @@
" ).resolve()\n",
")\n",
"await visualize_graph(graph_file_path)"
],
"outputs": [],
"execution_count": null
]
},
{
"cell_type": "markdown",
@@ -250,19 +250,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Completion query that uses graph data to form context.\n",
"graph_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.GRAPH_COMPLETION)\n",
"print(\"\\nGraph completion result is:\")\n",
"print(graph_completion)"
],
"outputs": [],
"execution_count": null
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## SEARCH: RAG Completion\n",
"\n",
@@ -271,19 +271,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Completion query that uses document chunks to form context.\n",
"rag_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.RAG_COMPLETION)\n",
"print(\"\\nRAG Completion result is:\")\n",
"print(rag_completion)"
],
"outputs": [],
"execution_count": null
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## SEARCH: Graph Insights\n",
"\n",
@@ -291,8 +291,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Search graph insights\n",
"insights_results = await search(query_text=\"Neptune Analytics\", query_type=SearchType.INSIGHTS)\n",
@@ -302,13 +304,11 @@
" tgt_node = result[2].get(\"name\", result[2][\"type\"])\n",
" relationship = result[1].get(\"relationship_name\", \"__relationship__\")\n",
" print(f\"- {src_node} -[{relationship}]-> {tgt_node}\")"
],
"outputs": [],
"execution_count": null
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## SEARCH: Entity Summaries\n",
"\n",
@@ -316,8 +316,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Query all summaries related to query.\n",
"summaries = await search(query_text=\"Neptune Analytics\", query_type=SearchType.SUMMARIES)\n",
@@ -326,13 +328,11 @@
" type = summary[\"type\"]\n",
" text = summary[\"text\"]\n",
" print(f\"- {type}: {text}\")"
],
"outputs": [],
"execution_count": null
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## SEARCH: Chunks\n",
"\n",
@@ -340,8 +340,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chunks = await search(query_text=\"Neptune Analytics\", query_type=SearchType.CHUNKS)\n",
"print(\"\\nChunk results are:\")\n",
@@ -349,9 +351,7 @@
" type = chunk[\"type\"]\n",
" text = chunk[\"text\"]\n",
" print(f\"- {type}: {text}\")"
],
"outputs": [],
"execution_count": null
]
}
],
"metadata": {

41 tools/check_all_adapters.sh Executable file

@@ -0,0 +1,41 @@
#!/bin/bash
# All Database Adapters MyPy Check Script
set -e # Exit on any error
echo "🚀 Running MyPy checks on all database adapters..."
echo ""
# Ensure we're in the project root directory
cd "$(dirname "$0")/.."
# Run all three adapter checks
echo "========================================="
echo "1⃣ VECTOR DATABASE ADAPTERS"
echo "========================================="
./tools/check_vector_adapters.sh
echo ""
echo "========================================="
echo "2⃣ GRAPH DATABASE ADAPTERS"
echo "========================================="
./tools/check_graph_adapters.sh
echo ""
echo "========================================="
echo "3⃣ HYBRID DATABASE ADAPTERS"
echo "========================================="
./tools/check_hybrid_adapters.sh
echo ""
echo "🎉 All Database Adapters MyPy Checks Complete!"
echo ""
echo "🔍 Auto-Discovery Approach:"
echo " • Vector Adapters: cognee/infrastructure/databases/vector/**/*Adapter.py"
echo " • Graph Adapters: cognee/infrastructure/databases/graph/**/*adapter.py"
echo " • Hybrid Adapters: cognee/infrastructure/databases/hybrid/**/*Adapter.py"
echo ""
echo "🎯 Purpose: Enforce that database adapters are properly typed"
echo "🔧 MyPy Configuration: mypy.ini (strict mode enabled)"
echo "🚀 Maintenance-Free: Automatically discovers new adapters"
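
If the same auto-discovery is ever needed outside the shell scripts, it reduces to a few glob calls. The sketch below mirrors the patterns listed above (vector and hybrid adapters use CamelCase *Adapter.py names, graph adapters use lowercase *adapter.py, with use_graph_adapter.py excluded); it is not a file added by this PR.

# Illustration only: the adapter auto-discovery expressed in Python.
from pathlib import Path


def discover_adapters(root: str = "cognee/infrastructure/databases") -> dict[str, list[Path]]:
    base = Path(root)
    return {
        "vector": sorted(base.glob("vector/**/*Adapter.py")),
        "graph": sorted(
            path for path in base.glob("graph/**/*adapter.py") if path.name != "use_graph_adapter.py"
        ),
        "hybrid": sorted(base.glob("hybrid/**/*Adapter.py")),
    }


if __name__ == "__main__":
    for kind, adapters in discover_adapters().items():
        print(f"{kind}: {[str(path) for path in adapters]}")
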

41 tools/check_graph_adapters.sh Executable file

@@ -0,0 +1,41 @@
#!/bin/bash
# Graph Database Adapters MyPy Check Script
set -e # Exit on any error
echo "🔍 Discovering Graph Database Adapters..."
# Ensure we're in the project root directory
cd "$(dirname "$0")/.."
# Activate virtual environment
source .venv/bin/activate
# Find all adapter.py and *adapter.py files in graph database directories, excluding utility files
graph_adapters=$(find cognee/infrastructure/databases/graph -name "*adapter.py" -o -name "adapter.py" | grep -v "use_graph_adapter.py" | sort)
if [ -z "$graph_adapters" ]; then
echo "No graph database adapters found"
exit 0
else
echo "Found graph database adapters:"
echo "$graph_adapters" | sed 's/^/ • /'
echo ""
echo "Running MyPy on graph database adapters..."
# Use while read to properly handle each file
echo "$graph_adapters" | while read -r adapter; do
if [ -n "$adapter" ]; then
echo "Checking: $adapter"
uv run mypy "$adapter" \
--config-file mypy.ini \
--show-error-codes \
--no-error-summary
echo ""
fi
done
fi
echo "✅ Graph Database Adapters MyPy Check Complete!"

41 tools/check_hybrid_adapters.sh Executable file

@@ -0,0 +1,41 @@
#!/bin/bash
# Hybrid Database Adapters MyPy Check Script
set -e # Exit on any error
echo "🔍 Discovering Hybrid Database Adapters..."
# Ensure we're in the project root directory
cd "$(dirname "$0")/.."
# Activate virtual environment
source .venv/bin/activate
# Find all *Adapter.py files in hybrid database directories
hybrid_adapters=$(find cognee/infrastructure/databases/hybrid -name "*Adapter.py" -type f | sort)
if [ -z "$hybrid_adapters" ]; then
echo "No hybrid database adapters found"
exit 0
else
echo "Found hybrid database adapters:"
echo "$hybrid_adapters" | sed 's/^/ • /'
echo ""
echo "Running MyPy on hybrid database adapters..."
# Use while read to properly handle each file
echo "$hybrid_adapters" | while read -r adapter; do
if [ -n "$adapter" ]; then
echo "Checking: $adapter"
uv run mypy "$adapter" \
--config-file mypy.ini \
--show-error-codes \
--no-error-summary
echo ""
fi
done
fi
echo "✅ Hybrid Database Adapters MyPy Check Complete!"

41 tools/check_vector_adapters.sh Executable file

@@ -0,0 +1,41 @@
#!/bin/bash
# Vector Database Adapters MyPy Check Script
set -e # Exit on any error
echo "🔍 Discovering Vector Database Adapters..."
# Ensure we're in the project root directory
cd "$(dirname "$0")/.."
# Activate virtual environment
source .venv/bin/activate
# Find all *Adapter.py files in vector database directories
vector_adapters=$(find cognee/infrastructure/databases/vector -name "*Adapter.py" -type f | sort)
if [ -z "$vector_adapters" ]; then
echo "No vector database adapters found"
exit 0
else
echo "Found vector database adapters:"
echo "$vector_adapters" | sed 's/^/ • /'
echo ""
echo "Running MyPy on vector database adapters..."
# Use while read to properly handle each file
echo "$vector_adapters" | while read -r adapter; do
if [ -n "$adapter" ]; then
echo "Checking: $adapter"
uv run mypy "$adapter" \
--config-file mypy.ini \
--show-error-codes \
--no-error-summary
echo ""
fi
done
fi
echo "✅ Vector Database Adapters MyPy Check Complete!"