ruff format
This commit is contained in:
parent
0e0bf9a00d
commit
c7b0da7aa6
6 changed files with 66 additions and 44 deletions
|
|
@ -173,7 +173,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
|
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
|
||||||
|
|
||||||
async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple[Any, ...]]:
|
async def query(
|
||||||
|
self, query: str, params: Optional[Dict[str, Any]] = None
|
||||||
|
) -> List[Tuple[Any, ...]]:
|
||||||
"""
|
"""
|
||||||
Execute a Kuzu query asynchronously with automatic reconnection.
|
Execute a Kuzu query asynchronously with automatic reconnection.
|
||||||
|
|
||||||
|
|
@ -320,7 +322,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
result = await self.query(query_str, {"id": node_id})
|
result = await self.query(query_str, {"id": node_id})
|
||||||
return result[0][0] if result else False
|
return result[0][0] if result else False
|
||||||
|
|
||||||
async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
|
async def add_node(
|
||||||
|
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add a single node to the graph if it doesn't exist.
|
Add a single node to the graph if it doesn't exist.
|
||||||
|
|
||||||
|
|
@ -343,8 +347,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
"type": str(node_properties.get("type", "")),
|
"type": str(node_properties.get("type", "")),
|
||||||
}
|
}
|
||||||
# Use the passed properties, excluding core fields
|
# Use the passed properties, excluding core fields
|
||||||
other_properties = {k: v for k, v in node_properties.items()
|
other_properties = {
|
||||||
if k not in ["id", "name", "type"]}
|
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
# Handle DataPoint object
|
# Handle DataPoint object
|
||||||
node_properties = node.model_dump()
|
node_properties = node.model_dump()
|
||||||
|
|
@ -354,8 +359,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
"type": str(node_properties.get("type", "")),
|
"type": str(node_properties.get("type", "")),
|
||||||
}
|
}
|
||||||
# Remove core fields from other properties
|
# Remove core fields from other properties
|
||||||
other_properties = {k: v for k, v in node_properties.items()
|
other_properties = {
|
||||||
if k not in ["id", "name", "type"]}
|
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
|
||||||
|
}
|
||||||
|
|
||||||
core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)
|
core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)
|
||||||
|
|
||||||
|
|
@ -593,7 +599,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
return result[0][0] if result else False
|
return result[0][0] if result else False
|
||||||
|
|
||||||
async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
async def has_edges(
|
||||||
|
self, edges: List[Tuple[str, str, str, Dict[str, Any]]]
|
||||||
|
) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Check if multiple edges exist in a batch operation.
|
Check if multiple edges exist in a batch operation.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,9 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
||||||
logger.error(f"Request data: {data}")
|
logger.error(f"Request data: {data}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def query(self, query: str, params: Optional[dict[str, Any]] = None) -> List[Tuple[Any, ...]]:
|
async def query(
|
||||||
|
self, query: str, params: Optional[dict[str, Any]] = None
|
||||||
|
) -> List[Tuple[Any, ...]]:
|
||||||
"""Execute a Kuzu query via the REST API."""
|
"""Execute a Kuzu query via the REST API."""
|
||||||
try:
|
try:
|
||||||
# Initialize schema if needed
|
# Initialize schema if needed
|
||||||
|
|
|
||||||
|
|
@ -142,7 +142,9 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
return results[0]["node_exists"] if len(results) > 0 else False
|
return results[0]["node_exists"] if len(results) > 0 else False
|
||||||
|
|
||||||
async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
|
async def add_node(
|
||||||
|
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add a new node to the database based on the provided DataPoint object or string ID.
|
Add a new node to the database based on the provided DataPoint object or string ID.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -217,7 +217,9 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
collections = await self.get_collection_names()
|
collections = await self.get_collection_names()
|
||||||
return collection_name in collections
|
return collection_name in collections
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
|
async def create_collection(
|
||||||
|
self, collection_name: str, payload_schema: Optional[Any] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Create a new collection in ChromaDB if it does not already exist.
|
Create a new collection in ChromaDB if it does not already exist.
|
||||||
|
|
||||||
|
|
@ -313,10 +315,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
[
|
[
|
||||||
IndexSchema(
|
IndexSchema(
|
||||||
id=data_point.id,
|
id=data_point.id,
|
||||||
text=getattr(
|
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
||||||
data_point,
|
|
||||||
data_point.metadata["index_fields"][0]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
for data_point in data_points
|
for data_point in data_points
|
||||||
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
|
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,18 @@ from uuid import UUID
|
||||||
import lancedb
|
import lancedb
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from lancedb.pydantic import LanceModel, Vector
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints, Dict, Any
|
from typing import (
|
||||||
|
Generic,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
get_args,
|
||||||
|
get_origin,
|
||||||
|
get_type_hints,
|
||||||
|
Dict,
|
||||||
|
Any,
|
||||||
|
)
|
||||||
|
|
||||||
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
|
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
@ -109,7 +120,9 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
collection_names = await connection.table_names()
|
collection_names = await connection.table_names()
|
||||||
return collection_name in collection_names
|
return collection_name in collection_names
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
|
async def create_collection(
|
||||||
|
self, collection_name: str, payload_schema: Optional[Any] = None
|
||||||
|
) -> None:
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
class LanceDataPoint(LanceModel):
|
class LanceDataPoint(LanceModel):
|
||||||
|
|
@ -123,7 +136,9 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: UUID
|
id: UUID
|
||||||
vector: Vector[vector_size] # TODO: double check and consider raising this later in Pydantic
|
vector: Vector[
|
||||||
|
vector_size
|
||||||
|
] # TODO: double check and consider raising this later in Pydantic
|
||||||
payload: Dict[str, Any]
|
payload: Dict[str, Any]
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
|
|
@ -300,10 +315,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
[
|
[
|
||||||
IndexSchema(
|
IndexSchema(
|
||||||
id=data_point.id,
|
id=data_point.id,
|
||||||
text=getattr(
|
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
||||||
data_point,
|
|
||||||
data_point.metadata["index_fields"][0]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
for data_point in data_points
|
for data_point in data_points
|
||||||
if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0
|
if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0
|
||||||
|
|
|
||||||
|
|
@ -123,7 +123,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
stop=stop_after_attempt(5),
|
stop=stop_after_attempt(5),
|
||||||
wait=wait_exponential(multiplier=2, min=1, max=6),
|
wait=wait_exponential(multiplier=2, min=1, max=6),
|
||||||
)
|
)
|
||||||
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
|
async def create_collection(
|
||||||
|
self, collection_name: str, payload_schema: Optional[Any] = None
|
||||||
|
) -> None:
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
async with self.VECTOR_DB_LOCK:
|
async with self.VECTOR_DB_LOCK:
|
||||||
|
|
@ -151,7 +153,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
|
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
|
||||||
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
|
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
|
||||||
|
|
||||||
def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
|
def __init__(
|
||||||
|
self, id: str, payload: Dict[str, Any], vector: List[float]
|
||||||
|
) -> None:
|
||||||
self.id = id
|
self.id = id
|
||||||
self.payload = payload
|
self.payload = payload
|
||||||
self.vector = vector
|
self.vector = vector
|
||||||
|
|
@ -159,10 +163,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
if len(Base.metadata.tables.keys()) > 0:
|
if len(Base.metadata.tables.keys()) > 0:
|
||||||
from sqlalchemy import Table
|
from sqlalchemy import Table
|
||||||
|
|
||||||
table: Table = PGVectorDataPoint.__table__ # type: ignore
|
table: Table = PGVectorDataPoint.__table__ # type: ignore
|
||||||
await connection.run_sync(
|
await connection.run_sync(Base.metadata.create_all, tables=[table])
|
||||||
Base.metadata.create_all, tables=[table]
|
|
||||||
)
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
retry=retry_if_exception_type(DeadlockDetectedError),
|
retry=retry_if_exception_type(DeadlockDetectedError),
|
||||||
|
|
@ -351,11 +354,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
# Create and return ScoredResult objects
|
# Create and return ScoredResult objects
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(id=row["id"], payload=row["payload"] or {}, score=row["score"])
|
||||||
id=row["id"],
|
|
||||||
payload=row["payload"] or {},
|
|
||||||
score=row["score"]
|
|
||||||
)
|
|
||||||
for row in vector_list
|
for row in vector_list
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue