feat: add falkordb adapter

Boris Arzentar 2024-10-24 12:37:06 +02:00 committed by Leon Luithlen
parent 52180eb6b5
commit 73372df31e
5 changed files with 215 additions and 4 deletions


@@ -8,7 +8,6 @@ from uuid import UUID
from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError
-from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface

logger = logging.getLogger("Neo4jAdapter")


@@ -0,0 +1,113 @@
import asyncio
from falkordb import FalkorDB

from ..models.DataPoint import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..embeddings.EmbeddingEngine import EmbeddingEngine


class FalcorDBAdapter(VectorDBInterface):
    def __init__(
        self,
        graph_database_url: str,
        graph_database_port: int,
        embedding_engine = EmbeddingEngine,
    ):
        self.driver = FalkorDB(
            host = graph_database_url,
            port = graph_database_port,
        )
        self.embedding_engine = embedding_engine

    async def embed_data(self, data: list[str]) -> list[list[float]]:
        return await self.embedding_engine.embed_text(data)

    async def has_collection(self, collection_name: str) -> bool:
        collections = self.driver.list_graphs()
        return collection_name in collections

    async def create_collection(self, collection_name: str, payload_schema = None):
        self.driver.select_graph(collection_name)

    async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
        graph = self.driver.select_graph(collection_name)

        def stringify_properties(properties: dict) -> str:
            return ",".join(f"{key}:'{value}'" for key, value in properties.items())

        def create_data_point_query(data_point: DataPoint):
            node_label = type(data_point.payload).__name__
            node_properties = stringify_properties(data_point.payload.dict())

            return f"""CREATE (:{node_label} {{{node_properties}}})"""

        query = " ".join([create_data_point_query(data_point) for data_point in data_points])

        graph.query(query)

    async def retrieve(self, collection_name: str, data_point_ids: list[str]):
        graph = self.driver.select_graph(collection_name)

        return graph.query(
            f"MATCH (node) WHERE node.id IN $node_ids RETURN node",
            {
                "node_ids": data_point_ids,
            },
        )

    async def search(
        self,
        collection_name: str,
        query_text: str = None,
        query_vector: list[float] = None,
        limit: int = 10,
        with_vector: bool = False,
    ):
        if query_text is None and query_vector is None:
            raise ValueError("One of query_text or query_vector must be provided!")

        if query_text and not query_vector:
            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

        graph = self.driver.select_graph(collection_name)

        query = f"""
            CALL db.idx.vector.queryNodes(
                null,
                'text',
                {limit},
                {query_vector}
            ) YIELD node, score
        """

        result = graph.query(query)

        return result

    async def batch_search(
        self,
        collection_name: str,
        query_texts: list[str],
        limit: int = None,
        with_vectors: bool = False,
    ):
        query_vectors = await self.embedding_engine.embed_text(query_texts)

        return await asyncio.gather(
            *[self.search(
                collection_name = collection_name,
                query_vector = query_vector,
                limit = limit,
                with_vector = with_vectors,
            ) for query_vector in query_vectors]
        )

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        graph = self.driver.select_graph(collection_name)

        return graph.query(
            f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
            {
                "node_ids": data_point_ids,
            },
        )


@@ -0,0 +1,62 @@
from typing import Optional
from uuid import UUID
from datetime import datetime
from pydantic import BaseModel


async def add_data_points(collection_name: str, data_points: list):
    pass


class Summary(BaseModel):
    id: UUID
    text: str
    chunk: "Chunk"
    created_at: datetime
    updated_at: Optional[datetime]

    vector_index = ["text"]


class Chunk(BaseModel):
    id: UUID
    text: str
    summary: Summary
    document: "Document"
    created_at: datetime
    updated_at: Optional[datetime]
    word_count: int
    chunk_index: int
    cut_type: str

    vector_index = ["text"]


class Document(BaseModel):
    id: UUID
    chunks: list[Chunk]
    created_at: datetime
    updated_at: Optional[datetime]


class EntityType(BaseModel):
    id: UUID
    name: str
    description: str
    created_at: datetime
    updated_at: Optional[datetime]

    vector_index = ["name"]


class Entity(BaseModel):
    id: UUID
    name: str
    type: EntityType
    description: str
    chunks: list[Chunk]
    created_at: datetime
    updated_at: Optional[datetime]

    vector_index = ["name"]


class OntologyModel(BaseModel):
    chunks: list[Chunk]
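
Summary and Chunk reference each other through string forward references, so those annotations have to be resolved before instances of them can be created. The sketch below assumes Pydantic v1 (which the non-annotated vector_index class attributes imply), uses update_forward_refs() for that step (model_rebuild() is the v2 equivalent), and fills the entity models with placeholder values that are not part of the commit.

from datetime import datetime
from uuid import uuid4

# Summary and Chunk carry string annotations ("Chunk", "Document"); resolving
# them once up front lets those two models be instantiated later on.
Summary.update_forward_refs()
Chunk.update_forward_refs()

# EntityType and Entity have no forward references, so they can be built
# directly; all values below are placeholders for illustration.
entity_type = EntityType(
    id = uuid4(),
    name = "Technology",
    description = "A technology or tool mentioned in a document.",
    created_at = datetime.now(),
    updated_at = None,
)

entity = Entity(
    id = uuid4(),
    name = "FalkorDB",
    type = entity_type,
    description = "A graph database with native vector indexing.",
    chunks = [],
    created_at = datetime.now(),
    updated_at = None,
)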

poetry.lock (generated)

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.

[[package]]
name = "aiofiles"
@@ -3729,6 +3729,24 @@ htmlmin2 = ">=0.1.13"
jsmin = ">=3.0.1"
mkdocs = ">=1.4.1"

+[[package]]
+name = "mkdocs-redirects"
+version = "1.2.1"
+description = "A MkDocs plugin for dynamic page redirects to prevent broken links."
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "mkdocs-redirects-1.2.1.tar.gz", hash = "sha256:9420066d70e2a6bb357adf86e67023dcdca1857f97f07c7fe450f8f1fb42f861"},
+]
+
+[package.dependencies]
+mkdocs = ">=1.1.1"
+
+[package.extras]
+dev = ["autoflake", "black", "isort", "pytest", "twine (>=1.13.0)"]
+release = ["twine (>=1.13.0)"]
+test = ["autoflake", "black", "isort", "pytest"]
+
[[package]]
name = "mkdocstrings"
version = "0.26.2"
@@ -5799,6 +5817,24 @@ async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
hiredis = ["hiredis (>=3.0.0)"]
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"]

+[[package]]
+name = "redis"
+version = "5.1.1"
+description = "Python client for Redis database and key-value store"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "redis-5.1.1-py3-none-any.whl", hash = "sha256:f8ea06b7482a668c6475ae202ed8d9bcaa409f6e87fb77ed1043d912afd62e24"},
+    {file = "redis-5.1.1.tar.gz", hash = "sha256:f6c997521fedbae53387307c5d0bf784d9acc28d9f1d058abeac566ec4dbed72"},
+]
+
+[package.dependencies]
+async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
+
+[package.extras]
+hiredis = ["hiredis (>=3.0.0)"]
+ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"]
+
[[package]]
name = "referencing"
version = "0.35.1"
@@ -7746,4 +7782,4 @@ weaviate = ["weaviate-client"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9.0,<3.12"
-content-hash = "fb09733ff7a70fb91c5f72ff0c8a8137b857557930a7aa025aad3154de4d8ceb"
+content-hash = "fef56656ead761cab7d5c3d0bf1fa5a54608db73b14616d08e5fb152dba91236"

pyproject.toml

@@ -69,7 +69,8 @@ fastapi-users = {version = "*", extras = ["sqlalchemy"]}
alembic = "^1.13.3"
asyncpg = "^0.29.0"
pgvector = "^0.3.5"
-psycopg2 = "^2.9.10"
+psycopg2 = {version = "^2.9.10", optional = true}
+falkordb = "^1.0.9"

[tool.poetry.extras]
filesystem = ["s3fs", "botocore"]