Compare commits
25 commits
main...chore/upda

| Author | SHA1 | Date |
|---|---|---|
|  | 5c3a56a7fc |  |
|  | cd4e2e7063 |  |
|  | 85e5b4d811 |  |
|  | 5f4c06efd1 |  |
|  | e87b77fda6 |  |
|  | e68a89f737 |  |
|  | 6db69e635e |  |
|  | a79ca4a7a4 |  |
|  | 69acac42e2 |  |
|  | c7b0da7aa6 |  |
|  | 0e0bf9a00d |  |
|  | e6256b90b2 |  |
|  | c8b2e1295d |  |
|  | a4c48d1104 |  |
|  | 1b221148c3 |  |
|  | 6b2301ff28 |  |
|  | 86f3d46bf5 |  |
|  | 5715998f43 |  |
|  | c8dbe0ee38 |  |
|  | 2992a38acd |  |
|  | ee9d7f5b02 |  |
|  | 85684d2534 |  |
|  | a2180e7c66 |  |
|  | 92b20ab6cd |  |
|  | 9bf5f76169 |  |
14 changed files with 436 additions and 1375 deletions
77  .github/workflows/database_protocol_mypy_check.yml  (vendored, Normal file)

@@ -0,0 +1,77 @@
permissions:
  contents: read
name: Database Adapter MyPy Check

on:
  workflow_dispatch:
  push:
    branches: [ main, dev ]
    paths:
      - 'cognee/infrastructure/databases/**'
      - 'tools/check_*_adapters.sh'
      - 'mypy.ini'
      - '.github/workflows/database_protocol_mypy_check.yml'
  pull_request:
    branches: [ main, dev ]
    paths:
      - 'cognee/infrastructure/databases/**'
      - 'tools/check_*_adapters.sh'
      - 'mypy.ini'
      - '.github/workflows/database_protocol_mypy_check.yml'

env:
  RUNTIME__LOG_LEVEL: ERROR
  ENV: 'dev'

jobs:
  mypy-database-adapters:
    name: MyPy Database Adapter Type Check
    runs-on: ubuntu-22.04

    steps:
      - name: Check out repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Cognee Setup
        uses: ./.github/actions/cognee_setup
        with:
          python-version: '3.11.x'

      - name: Discover and Check Vector Database Adapters
        run: ./tools/check_vector_adapters.sh

      # Commenting out graph and hybrid adapters for now as we're currently focusing on vector adapters
      # - name: Discover and Check Graph Database Adapters
      #   run: ./tools/check_graph_adapters.sh

      # - name: Discover and Check Hybrid Database Adapters
      #   run: ./tools/check_hybrid_adapters.sh

      - name: Protocol Compliance Summary
        run: |
          echo "✅ Database Adapter MyPy Check Passed!"
          echo ""
          echo "🔍 Auto-Discovery Approach:"
          echo " • Vector Adapters: cognee/infrastructure/databases/vector/**/*Adapter.py"
          # echo " • Graph Adapters: cognee/infrastructure/databases/graph/**/*adapter.py"
          # echo " • Hybrid Adapters: cognee/infrastructure/databases/hybrid/**/*Adapter.py"
          echo ""
          echo "🚀 Using Dedicated Scripts:"
          echo " • Vector: ./tools/check_vector_adapters.sh"
          # echo " • Graph: ./tools/check_graph_adapters.sh"
          # echo " • Hybrid: ./tools/check_hybrid_adapters.sh"
          echo " • All: ./tools/check_all_adapters.sh"
          echo ""
          echo "🎯 Purpose: Enforce that database adapters are properly typed"
          echo "🔧 MyPy Configuration: mypy.ini (strict mode enabled)"
          echo "🚀 Maintenance-Free: Automatically discovers new adapters"
          echo ""
          echo "⚠️ This workflow FAILS on any type errors to ensure adapter quality."
          echo " All database adapters must be properly typed."
          echo ""
          echo "🛠️ To fix type issues locally, run:"
          echo " ./tools/check_all_adapters.sh # Check all adapters"
          echo " ./tools/check_vector_adapters.sh # Check vector adapters only"
          echo " mypy <adapter_file_path> --config-file mypy.ini # Check specific file"
@@ -179,5 +179,5 @@ def create_graph_engine(

    raise EnvironmentError(
        f"Unsupported graph database provider: {graph_database_provider}. "
        f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
        f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'neptune', 'neptune_analytics'])}"
    )
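Note: the hunk above only trims 'memgraph' from the advertised provider list. For readers who don't have the surrounding factory in front of them, a minimal sketch of the dispatch-and-error pattern it belongs to could look like the following; the registry contents and factory lambdas are illustrative assumptions, not the actual create_graph_engine implementation.

```python
from typing import Any, Callable, Dict

# Hypothetical registry of provider factories; the real create_graph_engine builds
# its engines differently — this only illustrates the dispatch/error pattern.
supported_databases: Dict[str, Callable[..., Any]] = {
    "neo4j": lambda **kwargs: ...,
    "kuzu": lambda **kwargs: ...,
}


def create_graph_engine(graph_database_provider: str, **kwargs: Any) -> Any:
    factory = supported_databases.get(graph_database_provider)
    if factory is None:
        raise EnvironmentError(
            f"Unsupported graph database provider: {graph_database_provider}. "
            f"Supported providers are: {', '.join(supported_databases.keys())}"
        )
    return factory(**kwargs)
```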
File diff suppressed because it is too large.
@@ -1,12 +1,13 @@
import json
import asyncio
from uuid import UUID
from typing import List, Optional
from typing import List, Optional, Dict, Any
from chromadb import AsyncHttpClient, Settings

from cognee.shared.logging_utils import get_logger
from cognee.modules.storage.utils import get_own_properties
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult

@@ -35,9 +36,9 @@ class IndexSchema(DataPoint):

    text: str

    metadata: dict = {"index_fields": ["text"]}
    metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}

    def model_dump(self):
    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
        """
        Serialize the instance data for storage.

@@ -49,11 +50,11 @@ class IndexSchema(DataPoint):

        A dictionary containing serialized data processed for ChromaDB storage.
        """
        data = super().model_dump()
        data = super().model_dump(**kwargs)
        return process_data_for_chroma(data)


def process_data_for_chroma(data):
def process_data_for_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert complex data types to a format suitable for ChromaDB storage.

@@ -73,7 +74,7 @@ def process_data_for_chroma(data):

    A dictionary containing the processed key-value pairs suitable for ChromaDB storage.
    """
    processed_data = {}
    processed_data: Dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, UUID):
            processed_data[key] = str(value)

@@ -90,7 +91,7 @@ def process_data_for_chroma(data):
    return processed_data


def restore_data_from_chroma(data):
def restore_data_from_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Restore original data structure from ChromaDB storage format.

@@ -152,8 +153,8 @@ class ChromaDBAdapter(VectorDBInterface):
    """

    name = "ChromaDB"
    url: str
    api_key: str
    url: str | None
    api_key: str | None
    connection: AsyncHttpClient = None

    def __init__(

@@ -216,7 +217,9 @@ class ChromaDBAdapter(VectorDBInterface):
        collections = await self.get_collection_names()
        return collection_name in collections

    async def create_collection(self, collection_name: str, payload_schema=None):
    async def create_collection(
        self, collection_name: str, payload_schema: Optional[Any] = None
    ) -> None:
        """
        Create a new collection in ChromaDB if it does not already exist.

@@ -254,7 +257,7 @@ class ChromaDBAdapter(VectorDBInterface):
        client = await self.get_connection()
        return await client.get_collection(collection_name)

    async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
        """
        Create and upsert data points into the specified collection in ChromaDB.

@@ -282,7 +285,7 @@ class ChromaDBAdapter(VectorDBInterface):
            ids=ids, embeddings=embeddings, metadatas=metadatas, documents=texts
        )

    async def create_vector_index(self, index_name: str, index_property_name: str):
    async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
        """
        Create a vector index as a ChromaDB collection based on provided names.

@@ -296,7 +299,7 @@ class ChromaDBAdapter(VectorDBInterface):

    async def index_data_points(
        self, index_name: str, index_property_name: str, data_points: list[DataPoint]
    ):
    ) -> None:
        """
        Index the provided data points based on the specified index property in ChromaDB.

@@ -315,10 +318,11 @@ class ChromaDBAdapter(VectorDBInterface):
                    text=getattr(data_point, data_point.metadata["index_fields"][0]),
                )
                for data_point in data_points
                if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
            ],
        )

    async def retrieve(self, collection_name: str, data_point_ids: list[str]):
    async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
        """
        Retrieve data points by their IDs from a ChromaDB collection.

@@ -350,12 +354,12 @@ class ChromaDBAdapter(VectorDBInterface):
    async def search(
        self,
        collection_name: str,
        query_text: str = None,
        query_vector: List[float] = None,
        query_text: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        limit: int = 15,
        with_vector: bool = False,
        normalized: bool = True,
    ):
    ) -> List[ScoredResult]:
        """
        Search for items in a collection using either a text or a vector query.

@@ -437,7 +441,7 @@ class ChromaDBAdapter(VectorDBInterface):
        query_texts: List[str],
        limit: int = 5,
        with_vectors: bool = False,
    ):
    ) -> List[List[ScoredResult]]:
        """
        Perform multiple searches in a single request for efficiency, returning results for each
        query.

@@ -507,7 +511,7 @@ class ChromaDBAdapter(VectorDBInterface):

        return all_results

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
    async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> bool:
        """
        Remove data points from a collection based on their IDs.

@@ -528,7 +532,7 @@ class ChromaDBAdapter(VectorDBInterface):
        await collection.delete(ids=data_point_ids)
        return True

    async def prune(self):
    async def prune(self) -> bool:
        """
        Delete all collections in the ChromaDB database.

@@ -538,12 +542,12 @@ class ChromaDBAdapter(VectorDBInterface):
        Returns True upon successful deletion of all collections.
        """
        client = await self.get_connection()
        collections = await client.list_collections()
        for collection_name in collections:
        collection_names = await self.get_collection_names()
        for collection_name in collection_names:
            await client.delete_collection(collection_name)
        return True

    async def get_collection_names(self):
    async def get_collection_names(self) -> List[str]:
        """
        Retrieve the names of all collections in the ChromaDB database.
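The ChromaDB hunks above show the new type annotations for process_data_for_chroma and restore_data_from_chroma, but most of their bodies fall outside the visible diff. A hedged sketch of the kind of round-trip such helpers typically perform (UUIDs to strings, nested structures to JSON text) is shown below; the nested-value handling is an assumption for illustration, not the adapter's actual logic.

```python
import json
from typing import Any, Dict
from uuid import UUID


def process_data_for_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten values into types ChromaDB metadata can store (sketch)."""
    processed: Dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, UUID):
            processed[key] = str(value)
        elif isinstance(value, (dict, list)):
            # Assumption: complex values are stored as JSON strings.
            processed[key] = json.dumps(value, default=str)
        else:
            processed[key] = value
    return processed


def restore_data_from_chroma(data: Dict[str, Any]) -> Dict[str, Any]:
    """Best-effort inverse of process_data_for_chroma (sketch)."""
    restored: Dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, str) and value[:1] in ("{", "["):
            try:
                restored[key] = json.loads(value)
            except json.JSONDecodeError:
                restored[key] = value
        else:
            restored[key] = value
    return restored
```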
@@ -1,12 +1,26 @@
import asyncio
import json
from os import path
from uuid import UUID
import lancedb
from pydantic import BaseModel
from lancedb.pydantic import LanceModel, Vector
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
from typing import (
    Generic,
    List,
    Optional,
    TypeVar,
    Union,
    get_args,
    get_origin,
    get_type_hints,
    Dict,
    Any,
)

from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.files.storage import get_file_storage
from cognee.modules.storage.utils import copy_model, get_own_properties

@@ -30,16 +44,16 @@ class IndexSchema(DataPoint):
    to include 'text'.
    """

    id: str
    id: UUID
    text: str

    metadata: dict = {"index_fields": ["text"]}
    metadata: MetaData = {"index_fields": ["text"], "type": "IndexSchema"}


class LanceDBAdapter(VectorDBInterface):
    name = "LanceDB"
    url: str
    api_key: str
    url: Optional[str]
    api_key: Optional[str]
    connection: lancedb.AsyncConnection = None

    def __init__(

@@ -53,7 +67,7 @@ class LanceDBAdapter(VectorDBInterface):
        self.embedding_engine = embedding_engine
        self.VECTOR_DB_LOCK = asyncio.Lock()

    async def get_connection(self):
    async def get_connection(self) -> lancedb.AsyncConnection:
        """
        Establishes and returns a connection to the LanceDB.

@@ -107,12 +121,11 @@ class LanceDBAdapter(VectorDBInterface):
        collection_names = await connection.table_names()
        return collection_name in collection_names

    async def create_collection(self, collection_name: str, payload_schema: BaseModel):
    async def create_collection(
        self, collection_name: str, payload_schema: Optional[Any] = None
    ) -> None:
        vector_size = self.embedding_engine.get_vector_size()

        payload_schema = self.get_data_point_schema(payload_schema)
        data_point_types = get_type_hints(payload_schema)

        class LanceDataPoint(LanceModel):
            """
            Represents a data point in the Lance model with an ID, vector, and associated payload.

@@ -123,28 +136,28 @@ class LanceDBAdapter(VectorDBInterface):
            - payload: Additional data or metadata associated with the data point.
            """

            id: data_point_types["id"]
            vector: Vector(vector_size)
            payload: payload_schema
            id: str
            vector: Vector(vector_size)  # type: ignore
            payload: str  # JSON string for LanceDB compatibility

        if not await self.has_collection(collection_name):
            async with self.VECTOR_DB_LOCK:
                if not await self.has_collection(collection_name):
                    connection = await self.get_connection()
                    return await connection.create_table(
                    await connection.create_table(
                        name=collection_name,
                        schema=LanceDataPoint,
                        exist_ok=True,
                    )

    async def get_collection(self, collection_name: str):
    async def get_collection(self, collection_name: str) -> Any:
        if not await self.has_collection(collection_name):
            raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")

        connection = await self.get_connection()
        return await connection.open_table(collection_name)

    async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
        payload_schema = type(data_points[0])

        if not await self.has_collection(collection_name):

@@ -161,11 +174,9 @@ class LanceDBAdapter(VectorDBInterface):
            [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
        )

        IdType = TypeVar("IdType")
        PayloadSchema = TypeVar("PayloadSchema")
        vector_size = self.embedding_engine.get_vector_size()

        class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]):
        class LanceDataPoint(LanceModel):
            """
            Represents a data point in the Lance model with an ID, vector, and payload.

@@ -174,15 +185,15 @@ class LanceDBAdapter(VectorDBInterface):
            to the Lance data structure.
            """

            id: IdType
            vector: Vector(vector_size)
            payload: PayloadSchema
            id: str
            vector: Vector(vector_size)  # type: ignore
            payload: str  # JSON string for LanceDB compatibility

        def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
        def create_lance_data_point(data_point: DataPoint, vector: List[float]) -> Any:
            properties = get_own_properties(data_point)
            properties["id"] = str(properties["id"])

            return LanceDataPoint[str, self.get_data_point_schema(type(data_point))](
            return LanceDataPoint(
                id=str(data_point.id),
                vector=vector,
                payload=properties,

@@ -201,7 +212,7 @@ class LanceDBAdapter(VectorDBInterface):
            .execute(lance_data_points)
        )

    async def retrieve(self, collection_name: str, data_point_ids: list[str]):
    async def retrieve(self, collection_name: str, data_point_ids: list[str]) -> List[ScoredResult]:
        collection = await self.get_collection(collection_name)

        if len(data_point_ids) == 1:

@@ -212,7 +223,7 @@ class LanceDBAdapter(VectorDBInterface):
        return [
            ScoredResult(
                id=parse_id(result["id"]),
                payload=result["payload"],
                payload=json.loads(result["payload"]),
                score=0,
            )
            for result in results.to_dict("index").values()

@@ -221,12 +232,12 @@ class LanceDBAdapter(VectorDBInterface):
    async def search(
        self,
        collection_name: str,
        query_text: str = None,
        query_vector: List[float] = None,
        query_text: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        limit: int = 15,
        with_vector: bool = False,
        normalized: bool = True,
    ):
    ) -> List[ScoredResult]:
        if query_text is None and query_vector is None:
            raise MissingQueryParameterError()

@@ -254,7 +265,7 @@ class LanceDBAdapter(VectorDBInterface):
        return [
            ScoredResult(
                id=parse_id(result["id"]),
                payload=result["payload"],
                payload=json.loads(result["payload"]),
                score=normalized_values[value_index],
            )
            for value_index, result in enumerate(result_values)

@@ -264,9 +275,9 @@ class LanceDBAdapter(VectorDBInterface):
        self,
        collection_name: str,
        query_texts: List[str],
        limit: int = None,
        limit: Optional[int] = None,
        with_vectors: bool = False,
    ):
    ) -> List[List[ScoredResult]]:
        query_vectors = await self.embedding_engine.embed_text(query_texts)

        return await asyncio.gather(

@@ -274,40 +285,41 @@ class LanceDBAdapter(VectorDBInterface):
                self.search(
                    collection_name=collection_name,
                    query_vector=query_vector,
                    limit=limit,
                    limit=limit or 15,
                    with_vector=with_vectors,
                )
                for query_vector in query_vectors
            ]
        )

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
    async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> None:
        collection = await self.get_collection(collection_name)

        # Delete one at a time to avoid commit conflicts
        for data_point_id in data_point_ids:
            await collection.delete(f"id = '{data_point_id}'")

    async def create_vector_index(self, index_name: str, index_property_name: str):
    async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
        await self.create_collection(
            f"{index_name}_{index_property_name}", payload_schema=IndexSchema
        )

    async def index_data_points(
        self, index_name: str, index_property_name: str, data_points: list[DataPoint]
    ):
        self, index_name: str, index_property_name: str, data_points: List[DataPoint]
    ) -> None:
        await self.create_data_points(
            f"{index_name}_{index_property_name}",
            [
                IndexSchema(
                    id=str(data_point.id),
                    id=data_point.id,
                    text=getattr(data_point, data_point.metadata["index_fields"][0]),
                )
                for data_point in data_points
                if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0
            ],
        )

    async def prune(self):
    async def prune(self) -> None:
        connection = await self.get_connection()
        collection_names = await connection.table_names()

@@ -316,12 +328,15 @@ class LanceDBAdapter(VectorDBInterface):
            await collection.delete("id IS NOT NULL")
            await connection.drop_table(collection_name)

        if self.url.startswith("/"):
        if self.url and self.url.startswith("/"):
            db_dir_path = path.dirname(self.url)
            db_file_name = path.basename(self.url)
            await get_file_storage(db_dir_path).remove_all(db_file_name)

    def get_data_point_schema(self, model_type: BaseModel):
    def get_data_point_schema(self, model_type: Optional[Any]) -> Any:
        if model_type is None:
            return DataPoint

        related_models_fields = []

        for field_name, field_config in model_type.model_fields.items():
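A recurring change in the LanceDB adapter is that the payload column is now a plain JSON string (with json.loads applied when rows are read back) instead of a typed Pydantic payload. A minimal, hedged sketch of that round-trip on plain dictionaries, with illustrative helper names:

```python
import json
from typing import Any, Dict


def encode_payload(properties: Dict[str, Any]) -> str:
    # Assumption: non-JSON-native values (UUIDs, datetimes) are stringified via default=str.
    return json.dumps(properties, default=str)


def decode_payload(raw: str) -> Dict[str, Any]:
    return json.loads(raw)


# Round-trip: what would be written into the table's `payload` column and read back out.
payload = encode_payload({"id": "8d3e", "text": "hello", "index_fields": ["text"]})
assert decode_payload(payload)["text"] == "hello"
```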
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel

@@ -14,8 +14,10 @@ class ScoredResult(BaseModel):
      better outcome.
    - payload (Dict[str, Any]): Additional information related to the score, stored as
      key-value pairs in a dictionary.
    - vector (Optional[List[float]]): Optional vector embedding associated with the result.
    """

    id: UUID
    score: float  # Lower score is better
    payload: Dict[str, Any]
    vector: Optional[List[float]] = None
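Several adapter methods above now declare List[ScoredResult] return types. A short, hedged usage example of constructing the model (the field values are invented for illustration):

```python
from uuid import uuid4

from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult

# Hypothetical construction; in the adapters the id comes from parse_id()
# and the payload from the stored document metadata.
result = ScoredResult(
    id=uuid4(),
    score=0.12,  # lower is better
    payload={"text": "Neptune Analytics is a graph analytics engine."},
    vector=None,  # only populated when with_vector=True
)
print(result.model_dump())
```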
@@ -1,9 +1,9 @@
import asyncio
from typing import List, Optional, get_type_hints
from typing import List, Optional, get_type_hints, Dict, Any
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
from sqlalchemy import JSON, Table, select, delete, MetaData
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.exc import ProgrammingError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

@@ -12,6 +12,7 @@ from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationE

from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.models.DataPoint import MetaData as DataPointMetaData
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine

@@ -42,7 +43,7 @@ class IndexSchema(DataPoint):

    text: str

    metadata: dict = {"index_fields": ["text"]}
    metadata: DataPointMetaData = {"index_fields": ["text"], "type": "IndexSchema"}


class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):

@@ -122,8 +123,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=2, min=1, max=6),
    )
    async def create_collection(self, collection_name: str, payload_schema=None):
        data_point_types = get_type_hints(DataPoint)
    async def create_collection(
        self, collection_name: str, payload_schema: Optional[Any] = None
    ) -> None:
        vector_size = self.embedding_engine.get_vector_size()

        async with self.VECTOR_DB_LOCK:

@@ -147,29 +149,31 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                __tablename__ = collection_name
                __table_args__ = {"extend_existing": True}
                # PGVector requires one column to be the primary key
                id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
                payload = Column(JSON)
                vector = Column(self.Vector(vector_size))
                id: Mapped[str] = mapped_column(primary_key=True)
                payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
                vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))

                def __init__(self, id, payload, vector):
                def __init__(
                    self, id: str, payload: Dict[str, Any], vector: List[float]
                ) -> None:
                    self.id = id
                    self.payload = payload
                    self.vector = vector

            async with self.engine.begin() as connection:
                if len(Base.metadata.tables.keys()) > 0:
                    await connection.run_sync(
                        Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
                    )
                    from sqlalchemy import Table

                    table: Table = PGVectorDataPoint.__table__  # type: ignore
                    await connection.run_sync(Base.metadata.create_all, tables=[table])

    @retry(
        retry=retry_if_exception_type(DeadlockDetectedError),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=2, min=1, max=6),
    )
    @override_distributed(queued_add_data_points)
    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
        data_point_types = get_type_hints(DataPoint)
    @override_distributed(queued_add_data_points)  # type: ignore
    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]) -> None:
        if not await self.has_collection(collection_name):
            await self.create_collection(
                collection_name=collection_name,

@@ -196,11 +200,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
            __tablename__ = collection_name
            __table_args__ = {"extend_existing": True}
            # PGVector requires one column to be the primary key
            id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
            payload = Column(JSON)
            vector = Column(self.Vector(vector_size))
            id: Mapped[str] = mapped_column(primary_key=True)
            payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
            vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))

            def __init__(self, id, payload, vector):
            def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
                self.id = id
                self.payload = payload
                self.vector = vector

@@ -225,13 +229,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
            # else:
            pgvector_data_points.append(
                PGVectorDataPoint(
                    id=data_point.id,
                    id=str(data_point.id),
                    vector=data_vectors[data_index],
                    payload=serialize_data(data_point.model_dump()),
                )
            )

        def to_dict(obj):
        def to_dict(obj: Any) -> Dict[str, Any]:
            return {
                column.key: getattr(obj, column.key)
                for column in inspect(obj).mapper.column_attrs

@@ -245,12 +249,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
            await session.execute(insert_statement)
            await session.commit()

    async def create_vector_index(self, index_name: str, index_property_name: str):
    async def create_vector_index(self, index_name: str, index_property_name: str) -> None:
        await self.create_collection(f"{index_name}_{index_property_name}")

    async def index_data_points(
        self, index_name: str, index_property_name: str, data_points: list[DataPoint]
    ):
        self, index_name: str, index_property_name: str, data_points: List[DataPoint]
    ) -> None:
        await self.create_data_points(
            f"{index_name}_{index_property_name}",
            [

@@ -262,11 +266,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
            ],
        )

    async def get_table(self, collection_name: str) -> Table:
    async def get_table(self, table_name: str, schema_name: Optional[str] = None) -> Table:
        """
        Dynamically loads a table using the given collection name
        with an async engine.
        Dynamically loads a table using the given table name
        with an async engine. Schema parameter is ignored for vector collections.
        """
        collection_name = table_name
        async with self.engine.begin() as connection:
            # Create a MetaData instance to load table information
            metadata = MetaData()

@@ -279,15 +284,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                f"Collection '{collection_name}' not found!",
            )

    async def retrieve(self, collection_name: str, data_point_ids: List[str]):
    async def retrieve(self, collection_name: str, data_point_ids: List[str]) -> List[ScoredResult]:
        # Get PGVectorDataPoint Table from database
        PGVectorDataPoint = await self.get_table(collection_name)

        async with self.get_async_session() as session:
            results = await session.execute(
            query_result = await session.execute(
                select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
            )
            results = results.all()
            results = query_result.all()

            return [
                ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)

@@ -311,9 +316,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
        # Get PGVectorDataPoint Table from database
        PGVectorDataPoint = await self.get_table(collection_name)

        # NOTE: This needs to be initialized in case search doesn't return a value
        closest_items = []

        # Use async session to connect to the database
        async with self.get_async_session() as session:
            query = select(

@@ -325,12 +327,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                query = query.limit(limit)

            # Find closest vectors to query_vector
            closest_items = await session.execute(query)
            query_results = await session.execute(query)

            vector_list = []

            # Extract distances and find min/max for normalization
            for vector in closest_items.all():
            for vector in query_results.all():
                vector_list.append(
                    {
                        "id": parse_id(str(vector.id)),

@@ -349,7 +351,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):

            # Create and return ScoredResult objects
            return [
                ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
                ScoredResult(id=row["id"], payload=row["payload"] or {}, score=row["score"])
                for row in vector_list
            ]

@@ -357,9 +359,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
        self,
        collection_name: str,
        query_texts: List[str],
        limit: int = None,
        limit: Optional[int] = None,
        with_vectors: bool = False,
    ):
    ) -> List[List[ScoredResult]]:
        query_vectors = await self.embedding_engine.embed_text(query_texts)

        return await asyncio.gather(

@@ -367,14 +369,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                self.search(
                    collection_name=collection_name,
                    query_vector=query_vector,
                    limit=limit,
                    limit=limit or 15,
                    with_vector=with_vectors,
                )
                for query_vector in query_vectors
            ]
        )

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
    async def delete_data_points(self, collection_name: str, data_point_ids: List[str]) -> Any:
        async with self.get_async_session() as session:
            # Get PGVectorDataPoint Table from database
            PGVectorDataPoint = await self.get_table(collection_name)

@@ -384,6 +386,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
            await session.commit()
            return results

    async def prune(self):
    async def prune(self) -> None:
        # Clean up the database if it was set up as temporary
        await self.delete_database()
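The PGVector search hunks mention extracting distances and finding min/max "for normalization", but the normalization itself sits outside the visible hunks. A hedged sketch of the usual min-max scheme (assumed here, not copied from the adapter; the "distance" key is illustrative):

```python
from typing import Dict, List


def normalize_distances(rows: List[Dict[str, float]]) -> List[float]:
    """Map raw distances to [0, 1]; lower stays better, matching ScoredResult's convention."""
    distances = [row["distance"] for row in rows]
    min_d, max_d = min(distances), max(distances)
    if max_d == min_d:
        return [0.0 for _ in distances]  # all hits are equally close
    return [(d - min_d) / (max_d - min_d) for d in distances]


print(normalize_distances([{"distance": 0.2}, {"distance": 0.5}, {"distance": 0.8}]))
```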
@@ -1,109 +0,0 @@
import os

import pathlib
import cognee
from cognee.infrastructure.files.storage import get_storage_config
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType


logger = get_logger()


async def main():
    cognee.config.set_graph_database_provider("memgraph")
    data_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_memgraph")
        ).resolve()
    )
    cognee.config.data_root_directory(data_directory_path)
    cognee_directory_path = str(
        pathlib.Path(
            os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_memgraph")
        ).resolve()
    )
    cognee.config.system_root_directory(cognee_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    dataset_name = "cs_explanations"

    explanation_file_path = os.path.join(
        pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt"
    )
    await cognee.add([explanation_file_path], dataset_name)

    text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
    At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
    Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible.
    The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly.
    Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate.
    In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
    """

    await cognee.add([text], dataset_name)

    await cognee.cognify([dataset_name])

    from cognee.infrastructure.databases.vector import get_vector_engine

    vector_engine = get_vector_engine()
    random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
    random_node_name = random_node.payload["text"]

    search_results = await cognee.search(
        query_type=SearchType.INSIGHTS, query_text=random_node_name
    )
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nExtracted sentences are:\n")
    for result in search_results:
        print(f"{result}\n")

    search_results = await cognee.search(query_type=SearchType.CHUNKS, query_text=random_node_name)
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nExtracted chunks are:\n")
    for result in search_results:
        print(f"{result}\n")

    search_results = await cognee.search(
        query_type=SearchType.SUMMARIES, query_text=random_node_name
    )
    assert len(search_results) != 0, "Query related summaries don't exist."
    print("\nExtracted results are:\n")
    for result in search_results:
        print(f"{result}\n")

    search_results = await cognee.search(
        query_type=SearchType.NATURAL_LANGUAGE,
        query_text=f"Find nodes connected to node with name {random_node_name}",
    )
    assert len(search_results) != 0, "Query related natural language don't exist."
    print("\nExtracted results are:\n")
    for result in search_results:
        print(f"{result}\n")

    user = await get_default_user()
    history = await get_history(user.id)

    assert len(history) == 8, "Search history is not correct."

    await cognee.prune.prune_data()
    data_root_directory = get_storage_config()["data_root_directory"]
    assert not os.path.isdir(data_root_directory), "Local data files are not deleted"

    await cognee.prune.prune_system(metadata=True)
    from cognee.infrastructure.databases.graph import get_graph_engine

    graph_engine = await get_graph_engine()
    nodes, edges = await graph_engine.get_graph_data()
    assert len(nodes) == 0 and len(edges) == 0, "Memgraph graph database is not empty"


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
26  mypy.ini
@@ -1,7 +1,7 @@
[mypy]
python_version=3.8
python_version=3.10
ignore_missing_imports=false
strict_optional=false
strict_optional=true
warn_redundant_casts=true
disallow_any_generics=true
disallow_untyped_defs=true

@@ -10,6 +10,12 @@ warn_return_any=true
namespace_packages=true
warn_unused_ignores=true
show_error_codes=true
disallow_incomplete_defs=true
disallow_untyped_decorators=true
no_implicit_optional=true
warn_unreachable=true
warn_no_return=true
warn_unused_configs=true
#exclude=reflection/module_cases/*
exclude=docs/examples/archive/*|tests/reflection/module_cases/*

@@ -18,6 +24,22 @@ disallow_untyped_defs=false
warn_return_any=false


[mypy-cognee.infrastructure.databases.*]
ignore_missing_imports=true

# Third-party database libraries that lack type stubs
[mypy-chromadb.*]
ignore_missing_imports=true

[mypy-lancedb.*]
ignore_missing_imports=true

[mypy-asyncpg.*]
ignore_missing_imports=true

[mypy-pgvector.*]
ignore_missing_imports=true

[mypy-docs.*]
disallow_untyped_defs=false
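The options added to mypy.ini above (disallow_untyped_defs, disallow_incomplete_defs, no_implicit_optional, strict_optional) are what drive the signature changes across the adapters. A hedged, self-contained illustration of the signature style they require, unrelated to any real adapter:

```python
from typing import Any, Dict, List, Optional


class ExampleAdapter:
    """Toy class showing signatures that pass the stricter mypy.ini settings."""

    async def create_collection(
        self, collection_name: str, payload_schema: Optional[Any] = None
    ) -> None:
        # no_implicit_optional: a default of None requires an explicit Optional[...]
        ...

    async def retrieve(
        self, collection_name: str, data_point_ids: List[str]
    ) -> List[Dict[str, Any]]:
        # disallow_untyped_defs / disallow_incomplete_defs: every parameter and the
        # return value must be annotated.
        return []
```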
82  notebooks/neptune-analytics-example.ipynb  (vendored)
@@ -83,16 +83,16 @@
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pathlib\n",
    "from cognee import config, add, cognify, search, SearchType, prune, visualize_graph\n",
    "from dotenv import load_dotenv"
   ],
   "outputs": [],
   "execution_count": null
   ]
  },
  {
   "cell_type": "markdown",

@@ -106,7 +106,9 @@
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load environment variables from file .env\n",
    "load_dotenv()\n",

@@ -145,9 +147,7 @@
    "        \"vector_db_url\": f\"neptune-graph://{graph_identifier}\", # Neptune Analytics endpoint with the format neptune-graph://<GRAPH_ID>\n",
    "    }\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
   ]
  },
  {
   "cell_type": "markdown",

@@ -159,19 +159,19 @@
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prune data and system metadata before running, only if we want \"fresh\" state.\n",
    "await prune.prune_data()\n",
    "await prune.prune_system(metadata=True)"
   ],
   "outputs": [],
   "execution_count": null
   ]
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup data and cognify\n",
    "\n",

@@ -180,7 +180,9 @@
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add sample text to the dataset\n",
    "sample_text_1 = \"\"\"Neptune Analytics is a memory-optimized graph database engine for analytics. With Neptune\n",

@@ -205,9 +207,7 @@
    "\n",
    "# Cognify the text data.\n",
    "await cognify([dataset_name])"
   ],
   "outputs": [],
   "execution_count": null
   ]
  },
  {
   "cell_type": "markdown",

@@ -215,14 +215,16 @@
   "source": [
    "## Graph Memory visualization\n",
    "\n",
    "Initialize Memgraph as a Graph Memory store and save to .artefacts/graph_visualization.html\n",
    "Initialize a Graph Memory store and save to .artefacts/graph_visualization.html\n",
    "\n",
    ""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get a graphistry url (Register for a free account at https://www.graphistry.com)\n",
    "# url = await render_graph()\n",

@@ -235,9 +237,7 @@
    "    ).resolve()\n",
    ")\n",
    "await visualize_graph(graph_file_path)"
   ],
   "outputs": [],
   "execution_count": null
   ]
  },
  {
   "cell_type": "markdown",

@@ -250,19 +250,19 @@
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Completion query that uses graph data to form context.\n",
    "graph_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.GRAPH_COMPLETION)\n",
    "print(\"\\nGraph completion result is:\")\n",
    "print(graph_completion)"
   ],
   "outputs": [],
   "execution_count": null
   ]
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SEARCH: RAG Completion\n",
    "\n",

@@ -271,19 +271,19 @@
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Completion query that uses document chunks to form context.\n",
    "rag_completion = await search(query_text=\"What is Neptune Analytics?\", query_type=SearchType.RAG_COMPLETION)\n",
    "print(\"\\nRAG Completion result is:\")\n",
    "print(rag_completion)"
   ],
   "outputs": [],
   "execution_count": null
   ]
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SEARCH: Graph Insights\n",
    "\n",

@@ -291,8 +291,10 @@
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search graph insights\n",
    "insights_results = await search(query_text=\"Neptune Analytics\", query_type=SearchType.INSIGHTS)\n",

@@ -302,13 +304,11 @@
    "    tgt_node = result[2].get(\"name\", result[2][\"type\"])\n",
    "    relationship = result[1].get(\"relationship_name\", \"__relationship__\")\n",
    "    print(f\"- {src_node} -[{relationship}]-> {tgt_node}\")"
   ],
   "outputs": [],
   "execution_count": null
   ]
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SEARCH: Entity Summaries\n",
    "\n",

@@ -316,8 +316,10 @@
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query all summaries related to query.\n",
    "summaries = await search(query_text=\"Neptune Analytics\", query_type=SearchType.SUMMARIES)\n",

@@ -326,13 +328,11 @@
    "    type = summary[\"type\"]\n",
    "    text = summary[\"text\"]\n",
    "    print(f\"- {type}: {text}\")"
   ],
   "outputs": [],
   "execution_count": null
   ]
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SEARCH: Chunks\n",
    "\n",

@@ -340,8 +340,10 @@
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chunks = await search(query_text=\"Neptune Analytics\", query_type=SearchType.CHUNKS)\n",
    "print(\"\\nChunk results are:\")\n",

@@ -349,9 +351,7 @@
    "    type = chunk[\"type\"]\n",
    "    text = chunk[\"text\"]\n",
    "    print(f\"- {type}: {text}\")"
   ],
   "outputs": [],
   "execution_count": null
   ]
  }
 ],
 "metadata": {
41  tools/check_all_adapters.sh  (Executable file)
@@ -0,0 +1,41 @@
#!/bin/bash

# All Database Adapters MyPy Check Script

set -e  # Exit on any error

echo "🚀 Running MyPy checks on all database adapters..."
echo ""

# Ensure we're in the project root directory
cd "$(dirname "$0")/.."

# Run all three adapter checks
echo "========================================="
echo "1️⃣ VECTOR DATABASE ADAPTERS"
echo "========================================="
./tools/check_vector_adapters.sh

echo ""
echo "========================================="
echo "2️⃣ GRAPH DATABASE ADAPTERS"
echo "========================================="
./tools/check_graph_adapters.sh

echo ""
echo "========================================="
echo "3️⃣ HYBRID DATABASE ADAPTERS"
echo "========================================="
./tools/check_hybrid_adapters.sh

echo ""
echo "🎉 All Database Adapters MyPy Checks Complete!"
echo ""
echo "🔍 Auto-Discovery Approach:"
echo " • Vector Adapters: cognee/infrastructure/databases/vector/**/*Adapter.py"
echo " • Graph Adapters: cognee/infrastructure/databases/graph/**/*adapter.py"
echo " • Hybrid Adapters: cognee/infrastructure/databases/hybrid/**/*Adapter.py"
echo ""
echo "🎯 Purpose: Enforce that database adapters are properly typed"
echo "🔧 MyPy Configuration: mypy.ini (strict mode enabled)"
echo "🚀 Maintenance-Free: Automatically discovers new adapters"
41  tools/check_graph_adapters.sh  (Executable file)
@@ -0,0 +1,41 @@
#!/bin/bash

# Graph Database Adapters MyPy Check Script

set -e  # Exit on any error

echo "🔍 Discovering Graph Database Adapters..."

# Ensure we're in the project root directory
cd "$(dirname "$0")/.."

# Activate virtual environment
source .venv/bin/activate

# Find all adapter.py and *adapter.py files in graph database directories, excluding utility files
graph_adapters=$(find cognee/infrastructure/databases/graph -name "*adapter.py" -o -name "adapter.py" | grep -v "use_graph_adapter.py" | sort)

if [ -z "$graph_adapters" ]; then
    echo "No graph database adapters found"
    exit 0
else
    echo "Found graph database adapters:"
    echo "$graph_adapters" | sed 's/^/ • /'
    echo ""

    echo "Running MyPy on graph database adapters..."

    # Use while read to properly handle each file
    echo "$graph_adapters" | while read -r adapter; do
        if [ -n "$adapter" ]; then
            echo "Checking: $adapter"
            uv run mypy "$adapter" \
                --config-file mypy.ini \
                --show-error-codes \
                --no-error-summary
            echo ""
        fi
    done
fi

echo "✅ Graph Database Adapters MyPy Check Complete!"
41  tools/check_hybrid_adapters.sh  (Executable file)
@@ -0,0 +1,41 @@
#!/bin/bash

# Hybrid Database Adapters MyPy Check Script

set -e  # Exit on any error

echo "🔍 Discovering Hybrid Database Adapters..."

# Ensure we're in the project root directory
cd "$(dirname "$0")/.."

# Activate virtual environment
source .venv/bin/activate

# Find all *Adapter.py files in hybrid database directories
hybrid_adapters=$(find cognee/infrastructure/databases/hybrid -name "*Adapter.py" -type f | sort)

if [ -z "$hybrid_adapters" ]; then
    echo "No hybrid database adapters found"
    exit 0
else
    echo "Found hybrid database adapters:"
    echo "$hybrid_adapters" | sed 's/^/ • /'
    echo ""

    echo "Running MyPy on hybrid database adapters..."

    # Use while read to properly handle each file
    echo "$hybrid_adapters" | while read -r adapter; do
        if [ -n "$adapter" ]; then
            echo "Checking: $adapter"
            uv run mypy "$adapter" \
                --config-file mypy.ini \
                --show-error-codes \
                --no-error-summary
            echo ""
        fi
    done
fi

echo "✅ Hybrid Database Adapters MyPy Check Complete!"
41  tools/check_vector_adapters.sh  (Executable file)
@@ -0,0 +1,41 @@
#!/bin/bash

# Vector Database Adapters MyPy Check Script

set -e  # Exit on any error

echo "🔍 Discovering Vector Database Adapters..."

# Ensure we're in the project root directory
cd "$(dirname "$0")/.."

# Activate virtual environment
source .venv/bin/activate

# Find all *Adapter.py files in vector database directories
vector_adapters=$(find cognee/infrastructure/databases/vector -name "*Adapter.py" -type f | sort)

if [ -z "$vector_adapters" ]; then
    echo "No vector database adapters found"
    exit 0
else
    echo "Found vector database adapters:"
    echo "$vector_adapters" | sed 's/^/ • /'
    echo ""

    echo "Running MyPy on vector database adapters..."

    # Use while read to properly handle each file
    echo "$vector_adapters" | while read -r adapter; do
        if [ -n "$adapter" ]; then
            echo "Checking: $adapter"
            uv run mypy "$adapter" \
                --config-file mypy.ini \
                --show-error-codes \
                --no-error-summary
            echo ""
        fi
    done
fi

echo "✅ Vector Database Adapters MyPy Check Complete!"
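The scripts above invoke mypy through `uv run mypy`. For anyone who prefers to drive the same per-adapter check from Python (for example from another tool), mypy also exposes a programmatic entry point; the snippet below is a hedged sketch of that alternative and is not part of this pull request.

```python
import glob

from mypy import api  # mypy.api.run returns (stdout, stderr, exit_status)

# Mirror the shell discovery: every *Adapter.py under the vector databases tree.
adapters = sorted(
    glob.glob("cognee/infrastructure/databases/vector/**/*Adapter.py", recursive=True)
)

failed = False
for adapter in adapters:
    print(f"Checking: {adapter}")
    stdout, stderr, status = api.run([adapter, "--config-file", "mypy.ini", "--show-error-codes"])
    if status != 0:
        failed = True
        print(stdout)

raise SystemExit(1 if failed else 0)
```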