From 73372df31e0d4f68f0b0de267b1ae15858891df8 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Thu, 24 Oct 2024 12:37:06 +0200 Subject: [PATCH] feat: add falkordb adapter --- .../databases/graph/neo4j_driver/adapter.py | 1 - .../vector/falkordb/FalkorDBAdapter.py | 113 ++++++++++++++++++ examples/python/GraphModel.py | 62 ++++++++++ poetry.lock | 40 ++++++- pyproject.toml | 3 +- 5 files changed, 215 insertions(+), 4 deletions(-) create mode 100644 cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py create mode 100644 examples/python/GraphModel.py diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 26bbb5819..8831591d5 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -8,7 +8,6 @@ from uuid import UUID from neo4j import AsyncSession from neo4j import AsyncGraphDatabase from neo4j.exceptions import Neo4jError -from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface logger = logging.getLogger("Neo4jAdapter") diff --git a/cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py b/cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py new file mode 100644 index 000000000..744d79f53 --- /dev/null +++ b/cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py @@ -0,0 +1,113 @@ +import asyncio +from falkordb import FalkorDB +from ..models.DataPoint import DataPoint +from ..vector_db_interface import VectorDBInterface +from ..embeddings.EmbeddingEngine import EmbeddingEngine + + +class FalcorDBAdapter(VectorDBInterface): + def __init__( + self, + graph_database_url: str, + graph_database_port: int, + embedding_engine = EmbeddingEngine, + ): + self.driver = FalkorDB( + host = graph_database_url, + port = graph_database_port) + self.embedding_engine = embedding_engine + + + async def embed_data(self, data: list[str]) -> list[list[float]]: + return await self.embedding_engine.embed_text(data) + + async def has_collection(self, collection_name: str) -> bool: + collections = self.driver.list_graphs() + + return collection_name in collections + + async def create_collection(self, collection_name: str, payload_schema = None): + self.driver.select_graph(collection_name) + + async def create_data_points(self, collection_name: str, data_points: list[DataPoint]): + graph = self.driver.select_graph(collection_name) + + def stringify_properties(properties: dict) -> str: + return ",".join(f"{key}:'{value}'" for key, value in properties.items()) + + def create_data_point_query(data_point: DataPoint): + node_label = type(data_point.payload).__name__ + node_properties = stringify_properties(data_point.payload.dict()) + + return f"""CREATE (:{node_label} {{{node_properties}}})""" + + query = " ".join([create_data_point_query(data_point) for data_point in data_points]) + + graph.query(query) + + async def retrieve(self, collection_name: str, data_point_ids: list[str]): + graph = self.driver.select_graph(collection_name) + + return graph.query( + f"MATCH (node) WHERE node.id IN $node_ids RETURN node", + { + "node_ids": data_point_ids, + }, + ) + + async def search( + self, + collection_name: str, + query_text: str = None, + query_vector: list[float] = None, + limit: int = 10, + with_vector: bool = False, + ): + if query_text is None and query_vector is None: + raise ValueError("One of query_text or query_vector must be provided!") + + if query_text and not query_vector: + query_vector = (await self.embedding_engine.embed_text([query_text]))[0] + + graph = self.driver.select_graph(collection_name) + + query = f""" + CALL db.idx.vector.queryNodes( + null, + 'text', + {limit}, + {query_vector} + ) YIELD node, score + """ + + result = graph.query(query) + + return result + + async def batch_search( + self, + collection_name: str, + query_texts: list[str], + limit: int = None, + with_vectors: bool = False, + ): + query_vectors = await self.embedding_engine.embed_text(query_texts) + + return await asyncio.gather( + *[self.search( + collection_name = collection_name, + query_vector = query_vector, + limit = limit, + with_vector = with_vectors, + ) for query_vector in query_vectors] + ) + + async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): + graph = self.driver.select_graph(collection_name) + + return graph.query( + f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node", + { + "node_ids": data_point_ids, + }, + ) diff --git a/examples/python/GraphModel.py b/examples/python/GraphModel.py new file mode 100644 index 000000000..01251fc20 --- /dev/null +++ b/examples/python/GraphModel.py @@ -0,0 +1,62 @@ + +from typing import Optional +from uuid import UUID +from datetime import datetime +from pydantic import BaseModel + + +async def add_data_points(collection_name: str, data_points: list): + pass + + + +class Summary(BaseModel): + id: UUID + text: str + chunk: "Chunk" + created_at: datetime + updated_at: Optional[datetime] + + vector_index = ["text"] + +class Chunk(BaseModel): + id: UUID + text: str + summary: Summary + document: "Document" + created_at: datetime + updated_at: Optional[datetime] + word_count: int + chunk_index: int + cut_type: str + + vector_index = ["text"] + +class Document(BaseModel): + id: UUID + chunks: list[Chunk] + created_at: datetime + updated_at: Optional[datetime] + +class EntityType(BaseModel): + id: UUID + name: str + description: str + created_at: datetime + updated_at: Optional[datetime] + + vector_index = ["name"] + +class Entity(BaseModel): + id: UUID + name: str + type: EntityType + description: str + chunks: list[Chunk] + created_at: datetime + updated_at: Optional[datetime] + + vector_index = ["name"] + +class OntologyModel(BaseModel): + chunks: list[Chunk] diff --git a/poetry.lock b/poetry.lock index 270e66027..12b1e59ba 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiofiles" @@ -3729,6 +3729,24 @@ htmlmin2 = ">=0.1.13" jsmin = ">=3.0.1" mkdocs = ">=1.4.1" +[[package]] +name = "mkdocs-redirects" +version = "1.2.1" +description = "A MkDocs plugin for dynamic page redirects to prevent broken links." +optional = false +python-versions = ">=3.6" +files = [ + {file = "mkdocs-redirects-1.2.1.tar.gz", hash = "sha256:9420066d70e2a6bb357adf86e67023dcdca1857f97f07c7fe450f8f1fb42f861"}, +] + +[package.dependencies] +mkdocs = ">=1.1.1" + +[package.extras] +dev = ["autoflake", "black", "isort", "pytest", "twine (>=1.13.0)"] +release = ["twine (>=1.13.0)"] +test = ["autoflake", "black", "isort", "pytest"] + [[package]] name = "mkdocstrings" version = "0.26.2" @@ -5799,6 +5817,24 @@ async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\ hiredis = ["hiredis (>=3.0.0)"] ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"] +[[package]] +name = "redis" +version = "5.1.1" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.8" +files = [ + {file = "redis-5.1.1-py3-none-any.whl", hash = "sha256:f8ea06b7482a668c6475ae202ed8d9bcaa409f6e87fb77ed1043d912afd62e24"}, + {file = "redis-5.1.1.tar.gz", hash = "sha256:f6c997521fedbae53387307c5d0bf784d9acc28d9f1d058abeac566ec4dbed72"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} + +[package.extras] +hiredis = ["hiredis (>=3.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"] + [[package]] name = "referencing" version = "0.35.1" @@ -7746,4 +7782,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.12" -content-hash = "fb09733ff7a70fb91c5f72ff0c8a8137b857557930a7aa025aad3154de4d8ceb" +content-hash = "fef56656ead761cab7d5c3d0bf1fa5a54608db73b14616d08e5fb152dba91236" diff --git a/pyproject.toml b/pyproject.toml index 0bc3849b1..92d8f829b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,8 @@ fastapi-users = {version = "*", extras = ["sqlalchemy"]} alembic = "^1.13.3" asyncpg = "^0.29.0" pgvector = "^0.3.5" -psycopg2 = "^2.9.10" +psycopg2 = {version = "^2.9.10", optional = true} +falkordb = "^1.0.9" [tool.poetry.extras] filesystem = ["s3fs", "botocore"]