Compare commits

...
Sign in to create a new pull request.

25 commits

Author SHA1 Message Date
Daulet Amirkhanov
5c3a56a7fc ruff format 2025-09-11 14:14:55 +01:00
Daulet Amirkhanov
cd4e2e7063 enable mypy for vector adapters only for now 2025-09-11 14:14:40 +01:00
Daulet Amirkhanov
85e5b4d811 fix remaining mypy errors in LanceDBAdapter 2025-09-11 14:14:06 +01:00
Daulet Amirkhanov
5f4c06efd1 temp 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
e87b77fda6 undo changes for graph engines 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
e68a89f737 ruff format 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
6db69e635e clean up todos from lancedb 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
a79ca4a7a4 move check adapters scripts to /tools and update mypy workflow 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
69acac42e2 ruff check fix 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
c7b0da7aa6 ruff format 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
0e0bf9a00d mypy fix: Fix ChromaDBAdapter mypy errors 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
e6256b90b2 mypy: version Neo4j adapter 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
c8b2e1295d Remove Memgraph and references to it 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
a4c48d1104 mypy: fix RemoteKuzuAdapter mypy errors 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
1b221148c3 kuzu - improve type inference for connection 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
6b2301ff28 mypy: first fix KuzuAdapter mypy errors 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
86f3d46bf5 mypy: fix PGVectorAdapter mypy errors 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
5715998f43 mypy: fix LanceDBAdapter mypy errors 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
c8dbe0ee38 mypy fix: Fix ChromaDBAdapter mypy errors 2025-09-11 14:06:02 +01:00
Daulet Amirkhanov
2992a38acd mypy: ignore missing imports for third party adapter libraries 2025-09-11 14:02:42 +01:00
Daulet Amirkhanov
ee9d7f5b02 adding temporary mypy scripts 2025-09-11 14:02:42 +01:00
Daulet Amirkhanov
85684d2534 make protocols_mypy workflow manually dispatchable 2025-09-11 14:02:42 +01:00
Daulet Amirkhanov
a2180e7c66 Potential fix for code scanning alert no. 150: Workflow does not contain permissions
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-09-11 14:02:42 +01:00
Daulet Amirkhanov
92b20ab6cd refactor: remove old MyPy workflow and add new database adapter MyPy check workflow 2025-09-11 14:02:42 +01:00
Daulet Amirkhanov
9bf5f76169 chore: update mypy and create a GitHub workflow 2025-09-11 14:02:42 +01:00
14 changed files with 436 additions and 1375 deletions

View file

@ -0,0 +1,77 @@
permissions:
contents: read
name: Database Adapter MyPy Check
on:
workflow_dispatch:
push:
branches: [ main, dev ]
paths:
- 'cognee/infrastructure/databases/**'
- 'tools/check_*_adapters.sh'
- 'mypy.ini'
- '.github/workflows/database_protocol_mypy_check.yml'
pull_request:
branches: [ main, dev ]
paths:
- 'cognee/infrastructure/databases/**'
- 'tools/check_*_adapters.sh'
- 'mypy.ini'
- '.github/workflows/database_protocol_mypy_check.yml'
env:
RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'
jobs:
mypy-database-adapters:
name: MyPy Database Adapter Type Check
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Discover and Check Vector Database Adapters
run: ./tools/check_vector_adapters.sh
# Commeting out graph and hybrid adapters for now as we're currently focusing on vector adapters
# - name: Discover and Check Graph Database Adapters
# run: ./tools/check_graph_adapters.sh
# - name: Discover and Check Hybrid Database Adapters
# run: ./tools/check_hybrid_adapters.sh
- name: Protocol Compliance Summary
run: |
echo "✅ Database Adapter MyPy Check Passed!"
echo ""
echo "🔍 Auto-Discovery Approach:"
echo " • Vector Adapters: cognee/infrastructure/databases/vector/**/*Adapter.py"
# echo " • Graph Adapters: cognee/infrastructure/databases/graph/**/*adapter.py"
# echo " • Hybrid Adapters: cognee/infrastructure/databases/hybrid/**/*Adapter.py"
echo ""
echo "🚀 Using Dedicated Scripts:"
echo " • Vector: ./tools/check_vector_adapters.sh"
# echo " • Graph: ./tools/check_graph_adapters.sh"
# echo " • Hybrid: ./tools/check_hybrid_adapters.sh"
echo " • All: ./tools/check_all_adapters.sh"
echo ""
echo "🎯 Purpose: Enforce that database adapters are properly typed"
echo "🔧 MyPy Configuration: mypy.ini (strict mode enabled)"
echo "🚀 Maintenance-Free: Automatically discovers new adapters"
echo ""
echo "⚠️ This workflow FAILS on any type errors to ensure adapter quality."
echo " All database adapters must be properly typed."
echo ""
echo "🛠️ To fix type issues locally, run:"
echo " ./tools/check_all_adapters.sh # Check all adapters"
echo " ./tools/check_vector_adapters.sh # Check vector adapters only"
echo " mypy <adapter_file_path> --config-file mypy.ini # Check specific file"

View file

@ -179,5 +179,5 @@ def create_graph_engine(
raise EnvironmentError(
f"Unsupported graph database provider: {graph_database_provider}. "
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'neptune', 'neptune_analytics'])}"
)

View file

@ -1,12 +1,13 @@
import json
import asyncio
from uuid import UUID
from typing import List, Optional
from typing import List, Optional, Dict, Any
from chromadb import AsyncHttpClient, Settings
from cognee.shared.logging_utils import get_logger
from cognee.modules.storage.utils import get_own_properties
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
@ -35,9 +36,9 @@ class IndexSchema(DataPoint):
text: str
metadata: dict = {"index_fields": ["text"]}
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
def model_dump(self):
def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
"""
Serialize the instance data for storage.
@ -49,11 +50,11 @@ class IndexSchema(DataPoint):
A dictionary containing serialized data processed for ChromaDB storage.
"""
data = super().model_dump()
data = super().model_dump(**kwargs)
return process_data_for_chroma(data)
def process_data_for_chroma(data):
def process_data_for_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert complex data types to a format suitable for ChromaDB storage.
@ -73,7 +74,7 @@ def process_data_for_chroma(data):
A dictionary containing the processed key-value pairs suitable for ChromaDB storage.
"""
processed_data = {}
processed_data: Dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, UUID):
processed_data[key] = str(value)
@ -90,7 +91,7 @@ def process_data_for_chroma(data):
return processed_data
def restore_data_from_chroma(data):
def restore_data_from_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
"""
Restore original data structure from ChromaDB storage format.
@ -152,8 +153,8 @@ class ChromaDBAdapter(VectorDBInterface):
"""
name = "ChromaDB"
url: str
api_key: str
url: str | None
api_key: str | None
connection: AsyncHttpClient = None
def __init__(
@ -216,7 +217,9 @@ class ChromaDBAdapter(VectorDBInterface):
collections = await self.get_collection_names()
return collection_name in collections
async def create_collection(self, collection_name: str, payload_schema=None):
async def create_collection(
self, collection_name: str, payload_schema: Optional[Any] = None
) -> None:
"""
Create a new collection in ChromaDB if it does not already exist.
@ -254,7 +257,7 @@ class ChromaDBAdapter(VectorDBInterface):
client = await self.get_connection()
return await client.get_collection(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
"""
Create and upsert data points into the specified collection in ChromaDB.
@ -282,7 +285,7 @@ class ChromaDBAdapter(VectorDBInterface):
ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
)
async def create_vector_index(self, index_name: str, index_property_name: str):
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
"""
Create a vector index as a ChromaDB collection based on provided names.
@ -296,7 +299,7 @@ class ChromaDBAdapter(VectorDBInterface):
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
) -> None:
"""
Index the provided data points based on the specified index property in ChromaDB.
@ -315,10 +318,11 @@ class ChromaDBAdapter(VectorDBInterface):
text=getattr(data_point, data_point.metadata["index_fields"][0]),
)
for data_point in data_points
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
],
)
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
"""
Retrieve data points by their IDs from a ChromaDB collection.
@ -350,12 +354,12 @@ class ChromaDBAdapter(VectorDBInterface):
async def search(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 15,
with_vector: bool = False,
normalized: bool = True,
):
) -> List[ScoredResult]:
"""
Search for items in a collection using either a text or a vector query.
@ -437,7 +441,7 @@ class ChromaDBAdapter(VectorDBInterface):
query_texts: List[str],
limit: int = 5,
with_vectors: bool = False,
):
) -> List[List[ScoredResult]]:
"""
Perform multiple searches in a single request for efficiency, returning results for each
query.
@ -507,7 +511,7 @@ class ChromaDBAdapter(VectorDBInterface):
return all_results
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> bool:
"""
Remove data points from a collection based on their IDs.
@ -528,7 +532,7 @@ class ChromaDBAdapter(VectorDBInterface):
await collection.delete(ids=data_point_ids)
return True
async def prune(self):
async def prune(self) -> bool:
"""
Delete all collections in the ChromaDB database.
@ -538,12 +542,12 @@ class ChromaDBAdapter(VectorDBInterface):
Returns True upon successful deletion of all collections.
"""
client = await self.get_connection()
collections = await client.list_collections()
for collection_name in collections:
collection_names = await self.get_collection_names()
for collection_name in collection_names:
await client.delete_collection(collection_name)
return True
async def get_collection_names(self):
async def get_collection_names(self) -> List[str]:
"""
Retrieve the names of all collections in the ChromaDB database.

View file

@ -1,12 +1,26 @@
import asyncio
import json
from os import path
from uuid import UUID
import lancedb
from pydantic import BaseModel
from lancedb.pydantic import LanceModel, Vector
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
from typing import (
Generic,
List,
Optional,
TypeVar,
Union,
get_args,
get_origin,
get_type_hints,
Dict,
Any,
)
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.files.storage import get_file_storage
from cognee.modules.storage.utils import copy_model, get_own_properties
@ -30,16 +44,16 @@ class IndexSchema(DataPoint):
to include 'text'.
"""
id: str
id: UUID
text: str
metadata: dict = {"index_fields": ["text"]}
metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}
class LanceDBAdapter(VectorDBInterface):
name = "LanceDB"
url: str
api_key: str
url: Optional[str]
api_key: Optional[str]
connection: lancedb.AsyncConnection = None
def __init__(
@ -53,7 +67,7 @@ class LanceDBAdapter(VectorDBInterface):
self.embedding_engine = embedding_engine
self.VECTOR_DB_LOCK = asyncio.Lock()
async def get_connection(self):
async def get_connection(self) -> lancedb.AsyncConnection:
"""
Establishes and returns a connection to the LanceDB.
@ -107,12 +121,11 @@ class LanceDBAdapter(VectorDBInterface):
collection_names = await connection.table_names()
return collection_name in collection_names
async def create_collection(self, collection_name: str, payload_schema: BaseModel):
async def create_collection(
self, collection_name: str, payload_schema: Optional[Any] = None
) -> None:
vector_size = self.embedding_engine.get_vector_size()
payload_schema = self.get_data_point_schema(payload_schema)
data_point_types = get_type_hints(payload_schema)
class LanceDataPoint(LanceModel):
"""
Represents a data point in the Lance model with an ID, vector, and associated payload.
@ -123,28 +136,28 @@ class LanceDBAdapter(VectorDBInterface):
- payload: Additional data or metadata associated with the data point.
"""
id: data_point_types["id"]
vector: Vector(vector_size)
payload: payload_schema
id: str
vector: Vector(vector_size) # type: ignore
payload: str # JSON string for LanceDB compatibility
if not await self.has_collection(collection_name):
async with self.VECTOR_DB_LOCK:
if not await self.has_collection(collection_name):
connection = await self.get_connection()
return await connection.create_table(
await connection.create_table(
name=collection_name,
schema=LanceDataPoint,
exist_ok=True,
)
async def get_collection(self, collection_name: str):
async def get_collection(self, collection_name: str) -> Any:
if not await self.has_collection(collection_name):
raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")
connection = await self.get_connection()
return await connection.open_table(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
payload_schema = type(data_points[0])
if not await self.has_collection(collection_name):
@ -161,11 +174,9 @@ class LanceDBAdapter(VectorDBInterface):
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
)
IdType = TypeVar("IdType")
PayloadSchema = TypeVar("PayloadSchema")
vector_size = self.embedding_engine.get_vector_size()
class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]):
class LanceDataPoint(LanceModel):
"""
Represents a data point in the Lance model with an ID, vector, and payload.
@ -174,15 +185,15 @@ class LanceDBAdapter(VectorDBInterface):
to the Lance data structure.
"""
id: IdType
vector: Vector(vector_size)
payload: PayloadSchema
id: str
vector: Vector(vector_size) # type: ignore
payload: str # JSON string for LanceDB compatibility
def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any:
properties = get_own_properties(data_point)
properties["id"] = str(properties["id"])
return LanceDataPoint[str, self.get_data_point_schema(type(data_point))](
return LanceDataPoint(
id=str(data_point.id),
vector=vector,
payload=properties,
@ -201,7 +212,7 @@ class LanceDBAdapter(VectorDBInterface):
.execute(lance_data_points)
)
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
async def retrieve(self, collection_name: str, data_point_ids: list[str]) -> List[ScoredResult]:
collection = await self.get_collection(collection_name)
if len(data_point_ids) == 1:
@ -212,7 +223,7 @@ class LanceDBAdapter(VectorDBInterface):
return [
ScoredResult(
id=parse_id(result["id"]),
payload=result["payload"],
payload=json.loads(result["payload"]),
score=0,
)
for result in results.to_dict("index").values()
@ -221,12 +232,12 @@ class LanceDBAdapter(VectorDBInterface):
async def search(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 15,
with_vector: bool = False,
normalized: bool = True,
):
) -> List[ScoredResult]:
if query_text is None and query_vector is None:
raise MissingQueryParameterError()
@ -254,7 +265,7 @@ class LanceDBAdapter(VectorDBInterface):
return [
ScoredResult(
id=parse_id(result["id"]),
payload=result["payload"],
payload=json.loads(result["payload"]),
score=normalized_values[value_index],
)
for value_index, result in enumerate(result_values)
@ -264,9 +275,9 @@ class LanceDBAdapter(VectorDBInterface):
self,
collection_name: str,
query_texts: List[str],
limit: int = None,
limit: Optional[int] = None,
with_vectors: bool = False,
):
) -> List[List[ScoredResult]]:
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
@ -274,40 +285,41 @@ class LanceDBAdapter(VectorDBInterface):
self.search(
collection_name=collection_name,
query_vector=query_vector,
limit=limit,
limit=limit or 15,
with_vector=with_vectors,
)
for query_vector in query_vectors
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> None:
collection = await self.get_collection(collection_name)
# Delete one at a time to avoid commit conflicts
for data_point_id in data_point_ids:
await collection.delete(f"id = '{data_point_id}'")
async def create_vector_index(self, index_name: str, index_property_name: str):
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
await self.create_collection(
f"{index_name}_{index_property_name}", payload_schema=IndexSchema
)
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
self, index_name: str, index_property_name: str, data_points: List[DataPoint]
) -> None:
await self.create_data_points(
f"{index_name}_{index_property_name}",
[
IndexSchema(
id=str(data_point.id),
id=data_point.id,
text=getattr(data_point, data_point.metadata["index_fields"][0]),
)
for data_point in data_points
if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0
],
)
async def prune(self):
async def prune(self) -> None:
connection = await self.get_connection()
collection_names = await connection.table_names()
@ -316,12 +328,15 @@ class LanceDBAdapter(VectorDBInterface):
await collection.delete("id IS NOT NULL")
await connection.drop_table(collection_name)
if self.url.startswith("/"):
if self.url and self.url.startswith("/"):
db_dir_path = path.dirname(self.url)
db_file_name = path.basename(self.url)
await get_file_storage(db_dir_path).remove_all(db_file_name)
def get_data_point_schema(self, model_type: BaseModel):
def get_data_point_schema(self, model_type: Optional[Any]) -> Any:
if model_type is None:
return DataPoint
related_models_fields = []
for field_name, field_config in model_type.model_fields.items():

View file

@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel
@ -14,8 +14,10 @@ class ScoredResult(BaseModel):
better outcome.
- payload (Dict[str, Any]): Additional information related to the score, stored as
key-value pairs in a dictionary.
- vector (Optional[List[float]]): Optional vector embedding associated with the result.
"""
id: UUID
score: float # Lower score is better
payload: Dict[str, Any]
vector: Optional[List[float]] = None

View file

@ -1,9 +1,9 @@
import asyncio
from typing import List, Optional, get_type_hints
from typing import List, Optional, get_type_hints, Dict, Any
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
from sqlalchemy import JSON, Table, select, delete, MetaData
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.exc import ProgrammingError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@ -12,6 +12,7 @@ from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationE
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData as DataPointMetaData
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine
@ -42,7 +43,7 @@ class IndexSchema(DataPoint):
text: str
metadata: dict = {"index_fields": ["text"]}
metadata: DataPointMetaData = {"index_fields": ["text"], "type": "IndexSchema"}
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
@ -122,8 +123,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def create_collection(self, collection_name: str, payload_schema=None):
data_point_types = get_type_hints(DataPoint)
async def create_collection(
self, collection_name: str, payload_schema: Optional[Any] = None
) -> None:
vector_size = self.embedding_engine.get_vector_size()
async with self.VECTOR_DB_LOCK:
@ -147,29 +149,31 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
id: Mapped[str] = mapped_column(primary_key=True)
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
def __init__(
self, id: str, payload: Dict[str, Any], vector: List[float]
) -> None:
self.id = id
self.payload = payload
self.vector = vector
async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0:
await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
)
from sqlalchemy import Table
table: Table = PGVectorDataPoint.__table__ # type: ignore
await connection.run_sync(Base.metadata.create_all, tables=[table])
@retry(
retry=retry_if_exception_type(DeadlockDetectedError),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
@override_distributed(queued_add_data_points)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
data_point_types = get_type_hints(DataPoint)
@override_distributed(queued_add_data_points) # type: ignore
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name=collection_name,
@ -196,11 +200,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
id: Mapped[str] = mapped_column(primary_key=True)
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
self.id = id
self.payload = payload
self.vector = vector
@ -225,13 +229,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# else:
pgvector_data_points.append(
PGVectorDataPoint(
id=data_point.id,
id=str(data_point.id),
vector=data_vectors[data_index],
payload=serialize_data(data_point.model_dump()),
)
)
def to_dict(obj):
def to_dict(obj: Any) -> Dict[str, Any]:
return {
column.key: getattr(obj, column.key)
for column in inspect(obj).mapper.column_attrs
@ -245,12 +249,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await session.execute(insert_statement)
await session.commit()
async def create_vector_index(self, index_name: str, index_property_name: str):
async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
self, index_name: str, index_property_name: str, data_points: List[DataPoint]
) -> None:
await self.create_data_points(
f"{index_name}_{index_property_name}",
[
@ -262,11 +266,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
],
)
async def get_table(self, collection_name: str) -> Table:
async def get_table(self, table_name: str, schema_name: Optional[str] = None) -> Table:
"""
Dynamically loads a table using the given collection name
with an async engine.
Dynamically loads a table using the given table name
with an async engine. Schema parameter is ignored for vector collections.
"""
collection_name = table_name
async with self.engine.begin() as connection:
# Create a MetaData instance to load table information
metadata = MetaData()
@ -279,15 +284,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
f"Collection '{collection_name}' not found!",
)
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
async with self.get_async_session() as session:
results = await session.execute(
query_result = await session.execute(
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
results = results.all()
results = query_result.all()
return [
ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)
@ -311,9 +316,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
# NOTE: This needs to be initialized in case search doesn't return a value
closest_items = []
# Use async session to connect to the database
async with self.get_async_session() as session:
query = select(
@ -325,12 +327,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query = query.limit(limit)
# Find closest vectors to query_vector
closest_items = await session.execute(query)
query_results = await session.execute(query)
vector_list = []
# Extract distances and find min/max for normalization
for vector in closest_items.all():
for vector in query_results.all():
vector_list.append(
{
"id": parse_id(str(vector.id)),
@ -349,7 +351,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Create and return ScoredResult objects
return [
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
ScoredResult(id=row["id"], payload=row["payload"] or {}, score=row["score"])
for row in vector_list
]
@ -357,9 +359,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
self,
collection_name: str,
query_texts: List[str],
limit: int = None,
limit: Optional[int] = None,
with_vectors: bool = False,
):
) -> List[List[ScoredResult]]:
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
@ -367,14 +369,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
self.search(
collection_name=collection_name,
query_vector=query_vector,
limit=limit,
limit=limit or 15,
with_vector=with_vectors,
)
for query_vector in query_vectors
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> Any:
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
@ -384,6 +386,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await session.commit()
return results
async def prune(self):
async def prune(self) -> None:
# Clean up the database if it was set up as temporary
await self.delete_database()

View file

@ -1,109 +0,0 @@
import os
import pathlib
import cognee
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
logger = get_logger()
async def main():
cognee.config.set_graph_database_provider("memgraph")
data_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_memgraph")
).resolve()
)
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(
pathlib.Path(
os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_memgraph")
).resolve()
)
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
dataset_name = "cs_explanations"
explanation_file_path = os.path.join(
pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
)
await cognee.add([explanation_file_path], dataset_name)
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible.
The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly.
Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate.
In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
"""
await cognee.add([text], dataset_name)
await cognee.cognify([dataset_name])
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(
query_type=SearchType.INSIGHTS, query_text=random_node_name
)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted sentences are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(query_type=SearchType.CHUNKS, query_text=random_node_name)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted chunks are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(
query_type=SearchType.SUMMARIES, query_text=random_node_name
)
assert len(search_results) != 0, "Query related summaries don't exist."
print("\nExtracted results are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(
query_type=SearchType.NATURAL_LANGUAGE,
query_text=f"Find nodes connected to node with name {random_node_name}",
)
assert len(search_results) != 0, "Query related natural language don't exist."
print("\nExtracted results are:\n")
for result in search_results:
print(f"{result}\n")
user = await get_default_user()
history = await get_history(user.id)
assert len(history) == 8, "Search history is not correct."
await cognee.prune.prune_data()
data_root_directory = get_storage_config()["data_root_directory"]
assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
from cognee.infrastructure.databases.graph import get_graph_engine
graph_engine = await get_graph_engine()
nodes, edges = await graph_engine.get_graph_data()
assert len(nodes) == 0 and len(edges) == 0, "Memgraph graph database is not empty"
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View file

@ -1,7 +1,7 @@
[mypy]
python_version=3.8
python_version=3.10
ignore_missing_imports=false
strict_optional=false
strict_optional=true
warn_redundant_casts=true
disallow_any_generics=true
disallow_untyped_defs=true
@ -10,6 +10,12 @@ warn_return_any=true
namespace_packages=true
warn_unused_ignores=true
show_error_codes=true
disallow_incomplete_defs=true
disallow_untyped_decorators=true
no_implicit_optional=true
warn_unreachable=true
warn_no_return=true
warn_unused_configs=true
#exclude=reflection/module_cases/*
exclude=docs/examples/archive/*|tests/reflection/module_cases/*
@ -18,6 +24,22 @@ disallow_untyped_defs=false
warn_return_any=false
[mypy-cognee.infrastructure.databases.*]
ignore_missing_imports=true
# Third-party database libraries that lack type stubs
[mypy-chromadb.*]
ignore_missing_imports=true
[mypy-lancedb.*]
ignore_missing_imports=true
[mypy-asyncpg.*]
ignore_missing_imports=true
[mypy-pgvector.*]
ignore_missing_imports=true
[mypy-docs.*]
disallow_untyped_defs=false

View file

@ -83,16 +83,16 @@
]
},
{
"metadata": {},
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pathlib\n",
"from cognee import config, add, cognify, search, SearchType, prune, visualize_graph\n",
"from dotenv import load_dotenv"
],
"outputs": [],
"execution_count": null
]
},
{
"cell_type": "markdown",
@ -106,7 +106,9 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load environment variables from file .env\n",
"load_dotenv()\n",
@ -145,9 +147,7 @@
" \"vector_db_url\": f\"neptune-graph://{graph_identifier}\", # Neptune Analytics endpoint with the format neptune-graph://<GRAPH_ID>\n",
" }\n",
")"
],
"outputs": [],
"execution_count": null
]
},
{
"cell_type": "markdown",
@ -159,19 +159,19 @@
]
},
{
"metadata": {},
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Prune data and system metadata before running, only if we want \"fresh\" state.\n",
"await prune.prune_data()\n",
"await prune.prune_system(metadata=True)"
],
"outputs": [],
"execution_count": null
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup data and cognify\n",
"\n",
@ -180,7 +180,9 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Add sample text to the dataset\n",
"sample_text_1 = \"\"\"Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune\n",
@ -205,9 +207,7 @@
"\n",
"# Cognify the text data.\n",
"await cognify([dataset_name])"
],
"outputs": [],
"execution_count": null
]
},
{
"cell_type": "markdown",
@ -215,14 +215,16 @@
"source": [
"## Graph Memory visualization\n",
"\n",
"Initialize Memgraph as a Graph Memory store and save to .artefacts/graph_visualization.html\n",
"Initialize a Graph Memory store and save to .artefacts/graph_visualization.html\n",
"\n",
"![visualization](./neptune_analytics_demo.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get a graphistry url (Register for a free account at https://www.graphistry.com)\n",
"# url = await render_graph()\n",
@ -235,9 +237,7 @@
" ).resolve()\n",
")\n",
"await visualize_graph(graph_file_path)"
],
"outputs": [],
"execution_count": null
]
},
{
"cell_type": "markdown",
@ -250,19 +250,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Completion query that uses graph data to form context.\n",
"graph_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.GRAPH_COMPLETION)\n",
"print(\"\\nGraph completion result is:\")\n",
"print(graph_completion)"
],
"outputs": [],
"execution_count": null
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## SEARCH: RAG Completion\n",
"\n",
@ -271,19 +271,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Completion query that uses document chunks to form context.\n",
"rag_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.RAG_COMPLETION)\n",
"print(\"\\nRAG Completion result is:\")\n",
"print(rag_completion)"
],
"outputs": [],
"execution_count": null
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## SEARCH: Graph Insights\n",
"\n",
@ -291,8 +291,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Search graph insights\n",
"insights_results = await search(query_text=\"Neptune Analytics\", query_type=SearchType.INSIGHTS)\n",
@ -302,13 +304,11 @@
" tgt_node = result[2].get(\"name\", result[2][\"type\"])\n",
" relationship = result[1].get(\"relationship_name\", \"__relationship__\")\n",
" print(f\"- {src_node} -[{relationship}]-> {tgt_node}\")"
],
"outputs": [],
"execution_count": null
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## SEARCH: Entity Summaries\n",
"\n",
@ -316,8 +316,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Query all summaries related to query.\n",
"summaries = await search(query_text=\"Neptune Analytics\", query_type=SearchType.SUMMARIES)\n",
@ -326,13 +328,11 @@
" type = summary[\"type\"]\n",
" text = summary[\"text\"]\n",
" print(f\"- {type}: {text}\")"
],
"outputs": [],
"execution_count": null
]
},
{
"metadata": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## SEARCH: Chunks\n",
"\n",
@ -340,8 +340,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chunks = await search(query_text=\"Neptune Analytics\", query_type=SearchType.CHUNKS)\n",
"print(\"\\nChunk results are:\")\n",
@ -349,9 +351,7 @@
" type = chunk[\"type\"]\n",
" text = chunk[\"text\"]\n",
" print(f\"- {type}: {text}\")"
],
"outputs": [],
"execution_count": null
]
}
],
"metadata": {

41
tools/check_all_adapters.sh Executable file
View file

@ -0,0 +1,41 @@
#!/bin/bash
# All Database Adapters MyPy Check Script
set -e # Exit on any error
echo "🚀 Running MyPy checks on all database adapters..."
echo ""
# Ensure we're in the project root directory
cd "$(dirname "$0")/.."
# Run all three adapter checks
echo "========================================="
echo "1⃣ VECTOR DATABASE ADAPTERS"
echo "========================================="
./tools/check_vector_adapters.sh
echo ""
echo "========================================="
echo "2⃣ GRAPH DATABASE ADAPTERS"
echo "========================================="
./tools/check_graph_adapters.sh
echo ""
echo "========================================="
echo "3⃣ HYBRID DATABASE ADAPTERS"
echo "========================================="
./tools/check_hybrid_adapters.sh
echo ""
echo "🎉 All Database Adapters MyPy Checks Complete!"
echo ""
echo "🔍 Auto-Discovery Approach:"
echo " • Vector Adapters: cognee/infrastructure/databases/vector/**/*Adapter.py"
echo " • Graph Adapters: cognee/infrastructure/databases/graph/**/*adapter.py"
echo " • Hybrid Adapters: cognee/infrastructure/databases/hybrid/**/*Adapter.py"
echo ""
echo "🎯 Purpose: Enforce that database adapters are properly typed"
echo "🔧 MyPy Configuration: mypy.ini (strict mode enabled)"
echo "🚀 Maintenance-Free: Automatically discovers new adapters"

41
tools/check_graph_adapters.sh Executable file
View file

@ -0,0 +1,41 @@
#!/bin/bash
# Graph Database Adapters MyPy Check Script
set -e # Exit on any error
echo "🔍 Discovering Graph Database Adapters..."
# Ensure we're in the project root directory
cd "$(dirname "$0")/.."
# Activate virtual environment
source .venv/bin/activate
# Find all adapter.py and *adapter.py files in graph database directories, excluding utility files
graph_adapters=$(find cognee/infrastructure/databases/graph -name "*adapter.py" -o -name "adapter.py" | grep -v "use_graph_adapter.py" | sort)
if [ -z "$graph_adapters" ]; then
echo "No graph database adapters found"
exit 0
else
echo "Found graph database adapters:"
echo "$graph_adapters" | sed 's/^/ • /'
echo ""
echo "Running MyPy on graph database adapters..."
# Use while read to properly handle each file
echo "$graph_adapters" | while read -r adapter; do
if [ -n "$adapter" ]; then
echo "Checking: $adapter"
uv run mypy "$adapter" \
--config-file mypy.ini \
--show-error-codes \
--no-error-summary
echo ""
fi
done
fi
echo "✅ Graph Database Adapters MyPy Check Complete!"

41
tools/check_hybrid_adapters.sh Executable file
View file

@ -0,0 +1,41 @@
#!/bin/bash
# Hybrid Database Adapters MyPy Check Script
set -e # Exit on any error
echo "🔍 Discovering Hybrid Database Adapters..."
# Ensure we're in the project root directory
cd "$(dirname "$0")/.."
# Activate virtual environment
source .venv/bin/activate
# Find all *Adapter.py files in hybrid database directories
hybrid_adapters=$(find cognee/infrastructure/databases/hybrid -name "*Adapter.py" -type f | sort)
if [ -z "$hybrid_adapters" ]; then
echo "No hybrid database adapters found"
exit 0
else
echo "Found hybrid database adapters:"
echo "$hybrid_adapters" | sed 's/^/ • /'
echo ""
echo "Running MyPy on hybrid database adapters..."
# Use while read to properly handle each file
echo "$hybrid_adapters" | while read -r adapter; do
if [ -n "$adapter" ]; then
echo "Checking: $adapter"
uv run mypy "$adapter" \
--config-file mypy.ini \
--show-error-codes \
--no-error-summary
echo ""
fi
done
fi
echo "✅ Hybrid Database Adapters MyPy Check Complete!"

41
tools/check_vector_adapters.sh Executable file
View file

@ -0,0 +1,41 @@
#!/bin/bash
# Vector Database Adapters MyPy Check Script
set -e # Exit on any error
echo "🔍 Discovering Vector Database Adapters..."
# Ensure we're in the project root directory
cd "$(dirname "$0")/.."
# Activate virtual environment
source .venv/bin/activate
# Find all *Adapter.py files in vector database directories
vector_adapters=$(find cognee/infrastructure/databases/vector -name "*Adapter.py" -type f | sort)
if [ -z "$vector_adapters" ]; then
echo "No vector database adapters found"
exit 0
else
echo "Found vector database adapters:"
echo "$vector_adapters" | sed 's/^/ • /'
echo ""
echo "Running MyPy on vector database adapters..."
# Use while read to properly handle each file
echo "$vector_adapters" | while read -r adapter; do
if [ -n "$adapter" ]; then
echo "Checking: $adapter"
uv run mypy "$adapter" \
--config-file mypy.ini \
--show-error-codes \
--no-error-summary
echo ""
fi
done
fi
echo "✅ Vector Database Adapters MyPy Check Complete!"