ruff format
This commit is contained in:
parent
0e0bf9a00d
commit
c7b0da7aa6
6 changed files with 66 additions and 44 deletions
|
|
@ -145,7 +145,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize Kuzu database: {e}")
|
logger.error(f"Failed to initialize Kuzu database: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _get_connection(self) -> Connection:
|
def _get_connection(self) -> Connection:
|
||||||
"""Get the connection to the Kuzu database."""
|
"""Get the connection to the Kuzu database."""
|
||||||
if not self.connection:
|
if not self.connection:
|
||||||
|
|
@ -173,7 +173,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
|
logger.warning(f"Kuzu S3 storage file not found: {self.db_path}")
|
||||||
|
|
||||||
async def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Tuple[Any, ...]]:
|
async def query(
|
||||||
|
self, query: str, params: Optional[Dict[str, Any]] = None
|
||||||
|
) -> List[Tuple[Any, ...]]:
|
||||||
"""
|
"""
|
||||||
Execute a Kuzu query asynchronously with automatic reconnection.
|
Execute a Kuzu query asynchronously with automatic reconnection.
|
||||||
|
|
||||||
|
|
@ -221,7 +223,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
val = val.as_py()
|
val = val.as_py()
|
||||||
processed_rows.append(val)
|
processed_rows.append(val)
|
||||||
rows.append(tuple(processed_rows))
|
rows.append(tuple(processed_rows))
|
||||||
|
|
||||||
return rows
|
return rows
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Query execution failed: {str(e)}")
|
logger.error(f"Query execution failed: {str(e)}")
|
||||||
|
|
@ -320,7 +322,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
result = await self.query(query_str, {"id": node_id})
|
result = await self.query(query_str, {"id": node_id})
|
||||||
return result[0][0] if result else False
|
return result[0][0] if result else False
|
||||||
|
|
||||||
async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
|
async def add_node(
|
||||||
|
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add a single node to the graph if it doesn't exist.
|
Add a single node to the graph if it doesn't exist.
|
||||||
|
|
||||||
|
|
@ -343,8 +347,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
"type": str(node_properties.get("type", "")),
|
"type": str(node_properties.get("type", "")),
|
||||||
}
|
}
|
||||||
# Use the passed properties, excluding core fields
|
# Use the passed properties, excluding core fields
|
||||||
other_properties = {k: v for k, v in node_properties.items()
|
other_properties = {
|
||||||
if k not in ["id", "name", "type"]}
|
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
# Handle DataPoint object
|
# Handle DataPoint object
|
||||||
node_properties = node.model_dump()
|
node_properties = node.model_dump()
|
||||||
|
|
@ -354,8 +359,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
"type": str(node_properties.get("type", "")),
|
"type": str(node_properties.get("type", "")),
|
||||||
}
|
}
|
||||||
# Remove core fields from other properties
|
# Remove core fields from other properties
|
||||||
other_properties = {k: v for k, v in node_properties.items()
|
other_properties = {
|
||||||
if k not in ["id", "name", "type"]}
|
k: v for k, v in node_properties.items() if k not in ["id", "name", "type"]
|
||||||
|
}
|
||||||
|
|
||||||
core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)
|
core_properties["properties"] = json.dumps(other_properties, cls=JSONEncoder)
|
||||||
|
|
||||||
|
|
@ -593,7 +599,9 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
return result[0][0] if result else False
|
return result[0][0] if result else False
|
||||||
|
|
||||||
async def has_edges(self, edges: List[Tuple[str, str, str, Dict[str, Any]]]) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
async def has_edges(
|
||||||
|
self, edges: List[Tuple[str, str, str, Dict[str, Any]]]
|
||||||
|
) -> List[Tuple[str, str, str, Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Check if multiple edges exist in a batch operation.
|
Check if multiple edges exist in a batch operation.
|
||||||
|
|
||||||
|
|
@ -793,7 +801,7 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
relationship_name = row[1]
|
relationship_name = row[1]
|
||||||
target_node = self._parse_node_properties(row[2])
|
target_node = self._parse_node_properties(row[2])
|
||||||
# TODO: any edge properties we can add? Adding empty to avoid modifying query without reason
|
# TODO: any edge properties we can add? Adding empty to avoid modifying query without reason
|
||||||
edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, while typing expects nodes to be strings
|
edges.append((source_node, relationship_name, target_node, {})) # type: ignore # currently each node is a dict, while typing expects nodes to be strings
|
||||||
return edges
|
return edges
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get edges for node {node_id}: {e}")
|
logger.error(f"Failed to get edges for node {node_id}: {e}")
|
||||||
|
|
@ -1362,8 +1370,8 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
try:
|
try:
|
||||||
# Get basic graph data
|
# Get basic graph data
|
||||||
nodes, edges = await self.get_model_independent_graph_data()
|
nodes, edges = await self.get_model_independent_graph_data()
|
||||||
num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string?
|
num_nodes = len(nodes[0]["nodes"]) if nodes else 0 # type: ignore # nodes is type string?
|
||||||
num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string?
|
num_edges = len(edges[0]["elements"]) if edges else 0 # type: ignore # edges is type string?
|
||||||
|
|
||||||
# Calculate mandatory metrics
|
# Calculate mandatory metrics
|
||||||
mandatory_metrics = {
|
mandatory_metrics = {
|
||||||
|
|
@ -1571,10 +1579,10 @@ class KuzuAdapter(GraphDBInterface):
|
||||||
|
|
||||||
# Reinitialize the database
|
# Reinitialize the database
|
||||||
self._initialize_connection()
|
self._initialize_connection()
|
||||||
|
|
||||||
if not self._get_connection():
|
if not self._get_connection():
|
||||||
raise RuntimeError("Failed to establish database connection")
|
raise RuntimeError("Failed to establish database connection")
|
||||||
|
|
||||||
# Verify the database is empty
|
# Verify the database is empty
|
||||||
result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)")
|
result = self._get_connection().execute("MATCH (n:Node) RETURN COUNT(n)")
|
||||||
if not isinstance(result, list):
|
if not isinstance(result, list):
|
||||||
|
|
|
||||||
|
|
@ -73,13 +73,15 @@ class RemoteKuzuAdapter(KuzuAdapter):
|
||||||
status=response.status,
|
status=response.status,
|
||||||
message=error_detail,
|
message=error_detail,
|
||||||
)
|
)
|
||||||
return await response.json() # type: ignore
|
return await response.json() # type: ignore
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
logger.error(f"API request failed: {str(e)}")
|
logger.error(f"API request failed: {str(e)}")
|
||||||
logger.error(f"Request data: {data}")
|
logger.error(f"Request data: {data}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def query(self, query: str, params: Optional[dict[str, Any]] = None) -> List[Tuple[Any, ...]]:
|
async def query(
|
||||||
|
self, query: str, params: Optional[dict[str, Any]] = None
|
||||||
|
) -> List[Tuple[Any, ...]]:
|
||||||
"""Execute a Kuzu query via the REST API."""
|
"""Execute a Kuzu query via the REST API."""
|
||||||
try:
|
try:
|
||||||
# Initialize schema if needed
|
# Initialize schema if needed
|
||||||
|
|
|
||||||
|
|
@ -142,7 +142,9 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
return results[0]["node_exists"] if len(results) > 0 else False
|
return results[0]["node_exists"] if len(results) > 0 else False
|
||||||
|
|
||||||
async def add_node(self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None) -> None:
|
async def add_node(
|
||||||
|
self, node: Union[DataPoint, str], properties: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Add a new node to the database based on the provided DataPoint object or string ID.
|
Add a new node to the database based on the provided DataPoint object or string ID.
|
||||||
|
|
||||||
|
|
@ -407,7 +409,7 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"from_node": str(source_id), # Adding str as callsites may still be passing UUID
|
"from_node": str(source_id), # Adding str as callsites may still be passing UUID
|
||||||
"to_node": str(target_id),
|
"to_node": str(target_id),
|
||||||
"relationship_name": relationship_name,
|
"relationship_name": relationship_name,
|
||||||
"properties": serialized_properties,
|
"properties": serialized_properties,
|
||||||
|
|
@ -482,8 +484,8 @@ class Neo4jAdapter(GraphDBInterface):
|
||||||
|
|
||||||
edge_params = [
|
edge_params = [
|
||||||
{
|
{
|
||||||
"from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
|
"from_node": str(edge[0]), # Adding str as callsites may still be passing UUID
|
||||||
"to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
|
"to_node": str(edge[1]), # Adding str as callsites may still be passing UUID
|
||||||
"relationship_name": edge[2],
|
"relationship_name": edge[2],
|
||||||
"properties": self._flatten_edge_properties(
|
"properties": self._flatten_edge_properties(
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -217,7 +217,9 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
collections = await self.get_collection_names()
|
collections = await self.get_collection_names()
|
||||||
return collection_name in collections
|
return collection_name in collections
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
|
async def create_collection(
|
||||||
|
self, collection_name: str, payload_schema: Optional[Any] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Create a new collection in ChromaDB if it does not already exist.
|
Create a new collection in ChromaDB if it does not already exist.
|
||||||
|
|
||||||
|
|
@ -313,10 +315,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
||||||
[
|
[
|
||||||
IndexSchema(
|
IndexSchema(
|
||||||
id=data_point.id,
|
id=data_point.id,
|
||||||
text=getattr(
|
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
||||||
data_point,
|
|
||||||
data_point.metadata["index_fields"][0]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
for data_point in data_points
|
for data_point in data_points
|
||||||
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
|
if data_point.metadata and len(data_point.metadata["index_fields"]) > 0
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,18 @@ from uuid import UUID
|
||||||
import lancedb
|
import lancedb
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from lancedb.pydantic import LanceModel, Vector
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints, Dict, Any
|
from typing import (
|
||||||
|
Generic,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
get_args,
|
||||||
|
get_origin,
|
||||||
|
get_type_hints,
|
||||||
|
Dict,
|
||||||
|
Any,
|
||||||
|
)
|
||||||
|
|
||||||
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
|
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
|
||||||
from cognee.infrastructure.engine import DataPoint
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
@ -46,7 +57,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
url: Optional[str], # TODO: consider if we want to make this required and/or api_key
|
url: Optional[str], # TODO: consider if we want to make this required and/or api_key
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
embedding_engine: EmbeddingEngine,
|
embedding_engine: EmbeddingEngine,
|
||||||
):
|
):
|
||||||
|
|
@ -109,7 +120,9 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
collection_names = await connection.table_names()
|
collection_names = await connection.table_names()
|
||||||
return collection_name in collection_names
|
return collection_name in collection_names
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
|
async def create_collection(
|
||||||
|
self, collection_name: str, payload_schema: Optional[Any] = None
|
||||||
|
) -> None:
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
class LanceDataPoint(LanceModel):
|
class LanceDataPoint(LanceModel):
|
||||||
|
|
@ -123,7 +136,9 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: UUID
|
id: UUID
|
||||||
vector: Vector[vector_size] # TODO: double check and consider raising this later in Pydantic
|
vector: Vector[
|
||||||
|
vector_size
|
||||||
|
] # TODO: double check and consider raising this later in Pydantic
|
||||||
payload: Dict[str, Any]
|
payload: Dict[str, Any]
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
|
|
@ -300,10 +315,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
[
|
[
|
||||||
IndexSchema(
|
IndexSchema(
|
||||||
id=data_point.id,
|
id=data_point.id,
|
||||||
text=getattr(
|
text=getattr(data_point, data_point.metadata["index_fields"][0]),
|
||||||
data_point,
|
|
||||||
data_point.metadata["index_fields"][0]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
for data_point in data_points
|
for data_point in data_points
|
||||||
if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0
|
if data_point.metadata and len(data_point.metadata.get("index_fields", [])) > 0
|
||||||
|
|
@ -327,7 +339,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
def get_data_point_schema(self, model_type: Optional[Any]) -> Any:
|
def get_data_point_schema(self, model_type: Optional[Any]) -> Any:
|
||||||
if model_type is None:
|
if model_type is None:
|
||||||
return DataPoint
|
return DataPoint
|
||||||
|
|
||||||
related_models_fields = []
|
related_models_fields = []
|
||||||
|
|
||||||
for field_name, field_config in model_type.model_fields.items():
|
for field_name, field_config in model_type.model_fields.items():
|
||||||
|
|
|
||||||
|
|
@ -123,7 +123,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
stop=stop_after_attempt(5),
|
stop=stop_after_attempt(5),
|
||||||
wait=wait_exponential(multiplier=2, min=1, max=6),
|
wait=wait_exponential(multiplier=2, min=1, max=6),
|
||||||
)
|
)
|
||||||
async def create_collection(self, collection_name: str, payload_schema: Optional[Any] = None) -> None:
|
async def create_collection(
|
||||||
|
self, collection_name: str, payload_schema: Optional[Any] = None
|
||||||
|
) -> None:
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
async with self.VECTOR_DB_LOCK:
|
async with self.VECTOR_DB_LOCK:
|
||||||
|
|
@ -151,7 +153,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
|
payload: Mapped[Dict[str, Any]] = mapped_column(JSON)
|
||||||
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
|
vector: Mapped[List[float]] = mapped_column(self.Vector(vector_size))
|
||||||
|
|
||||||
def __init__(self, id: str, payload: Dict[str, Any], vector: List[float]) -> None:
|
def __init__(
|
||||||
|
self, id: str, payload: Dict[str, Any], vector: List[float]
|
||||||
|
) -> None:
|
||||||
self.id = id
|
self.id = id
|
||||||
self.payload = payload
|
self.payload = payload
|
||||||
self.vector = vector
|
self.vector = vector
|
||||||
|
|
@ -159,10 +163,9 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
if len(Base.metadata.tables.keys()) > 0:
|
if len(Base.metadata.tables.keys()) > 0:
|
||||||
from sqlalchemy import Table
|
from sqlalchemy import Table
|
||||||
|
|
||||||
table: Table = PGVectorDataPoint.__table__ # type: ignore
|
table: Table = PGVectorDataPoint.__table__ # type: ignore
|
||||||
await connection.run_sync(
|
await connection.run_sync(Base.metadata.create_all, tables=[table])
|
||||||
Base.metadata.create_all, tables=[table]
|
|
||||||
)
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
retry=retry_if_exception_type(DeadlockDetectedError),
|
retry=retry_if_exception_type(DeadlockDetectedError),
|
||||||
|
|
@ -351,11 +354,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
# Create and return ScoredResult objects
|
# Create and return ScoredResult objects
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(id=row["id"], payload=row["payload"] or {}, score=row["score"])
|
||||||
id=row["id"],
|
|
||||||
payload=row["payload"] or {},
|
|
||||||
score=row["score"]
|
|
||||||
)
|
|
||||||
for row in vector_list
|
for row in vector_list
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue