From c901fa8b8acc23197aa46e885ce34a77c156b6e7 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Thu, 24 Oct 2024 12:37:06 +0200 Subject: [PATCH] feat: add falkordb adapter --- .../databases/graph/neo4j_driver/adapter.py | 1 - .../vector/falkordb/FalkorDBAdapter.py | 96 +++++++++++++++---- .../processing/document_types/__init__.py | 1 + examples/python/GraphModel.py | 62 ++++++++++++ poetry.lock | 41 ++++++-- pyproject.toml | 1 + 6 files changed, 173 insertions(+), 29 deletions(-) create mode 100644 examples/python/GraphModel.py diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index f072d60fe..0b8925cee 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -7,7 +7,6 @@ from contextlib import asynccontextmanager from neo4j import AsyncSession from neo4j import AsyncGraphDatabase from neo4j.exceptions import Neo4jError -from networkx import predecessor from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface logger = logging.getLogger("Neo4jAdapter") diff --git a/cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py b/cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py index 563219fec..744d79f53 100644 --- a/cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py +++ b/cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py @@ -1,57 +1,113 @@ - -from typing import List, Dict, Optional, Any - +import asyncio from falkordb import FalkorDB -from qdrant_client import AsyncQdrantClient, models -from ..vector_db_interface import VectorDBInterface from ..models.DataPoint import DataPoint +from ..vector_db_interface import VectorDBInterface from ..embeddings.EmbeddingEngine import EmbeddingEngine - - class FalcorDBAdapter(VectorDBInterface): def __init__( self, graph_database_url: str, - graph_database_username: str, - graph_database_password: str, graph_database_port: int, - driver: Optional[Any] = None, embedding_engine = EmbeddingEngine, - graph_name: str = "DefaultGraph", ): self.driver = FalkorDB( host = graph_database_url, port = graph_database_port) - self.graph_name = graph_name self.embedding_engine = embedding_engine - async def embed_data(self, data: list[str]) -> list[list[float]]: return await self.embedding_engine.embed_text(data) + async def has_collection(self, collection_name: str) -> bool: + collections = self.driver.list_graphs() + + return collection_name in collections async def create_collection(self, collection_name: str, payload_schema = None): - pass + self.driver.select_graph(collection_name) + async def create_data_points(self, collection_name: str, data_points: list[DataPoint]): + graph = self.driver.select_graph(collection_name) - async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): - pass + def stringify_properties(properties: dict) -> str: + return ",".join(f"{key}:'{value}'" for key, value in properties.items()) + + def create_data_point_query(data_point: DataPoint): + node_label = type(data_point.payload).__name__ + node_properties = stringify_properties(data_point.payload.dict()) + + return f"""CREATE (:{node_label} {{{node_properties}}})""" + + query = " ".join([create_data_point_query(data_point) for data_point in data_points]) + + graph.query(query) async def retrieve(self, collection_name: str, data_point_ids: list[str]): - pass + graph = self.driver.select_graph(collection_name) + + return graph.query( + f"MATCH (node) WHERE node.id IN $node_ids RETURN node", + { + "node_ids": data_point_ids, + }, + ) async def search( self, collection_name: str, query_text: str = None, - query_vector: List[float] = None, + query_vector: list[float] = None, limit: int = 10, with_vector: bool = False, ): - pass + if query_text is None and query_vector is None: + raise ValueError("One of query_text or query_vector must be provided!") + + if query_text and not query_vector: + query_vector = (await self.embedding_engine.embed_text([query_text]))[0] + + graph = self.driver.select_graph(collection_name) + + query = f""" + CALL db.idx.vector.queryNodes( + null, + 'text', + {limit}, + {query_vector} + ) YIELD node, score + """ + + result = graph.query(query) + + return result + + async def batch_search( + self, + collection_name: str, + query_texts: list[str], + limit: int = None, + with_vectors: bool = False, + ): + query_vectors = await self.embedding_engine.embed_text(query_texts) + + return await asyncio.gather( + *[self.search( + collection_name = collection_name, + query_vector = query_vector, + limit = limit, + with_vector = with_vectors, + ) for query_vector in query_vectors] + ) async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): - pass + graph = self.driver.select_graph(collection_name) + + return graph.query( + f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node", + { + "node_ids": data_point_ids, + }, + ) diff --git a/cognee/modules/data/processing/document_types/__init__.py b/cognee/modules/data/processing/document_types/__init__.py index d751366b7..9682cc101 100644 --- a/cognee/modules/data/processing/document_types/__init__.py +++ b/cognee/modules/data/processing/document_types/__init__.py @@ -1,3 +1,4 @@ +from .Document import Document from .PdfDocument import PdfDocument from .TextDocument import TextDocument from .ImageDocument import ImageDocument diff --git a/examples/python/GraphModel.py b/examples/python/GraphModel.py new file mode 100644 index 000000000..01251fc20 --- /dev/null +++ b/examples/python/GraphModel.py @@ -0,0 +1,62 @@ + +from typing import Optional +from uuid import UUID +from datetime import datetime +from pydantic import BaseModel + + +async def add_data_points(collection_name: str, data_points: list): + pass + + + +class Summary(BaseModel): + id: UUID + text: str + chunk: "Chunk" + created_at: datetime + updated_at: Optional[datetime] + + vector_index = ["text"] + +class Chunk(BaseModel): + id: UUID + text: str + summary: Summary + document: "Document" + created_at: datetime + updated_at: Optional[datetime] + word_count: int + chunk_index: int + cut_type: str + + vector_index = ["text"] + +class Document(BaseModel): + id: UUID + chunks: list[Chunk] + created_at: datetime + updated_at: Optional[datetime] + +class EntityType(BaseModel): + id: UUID + name: str + description: str + created_at: datetime + updated_at: Optional[datetime] + + vector_index = ["name"] + +class Entity(BaseModel): + id: UUID + name: str + type: EntityType + description: str + chunks: list[Chunk] + created_at: datetime + updated_at: Optional[datetime] + + vector_index = ["name"] + +class OntologyModel(BaseModel): + chunks: list[Chunk] diff --git a/poetry.lock b/poetry.lock index acd56e02f..b8ff95c15 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiofiles" @@ -1490,6 +1490,19 @@ files = [ [package.extras] tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] +[[package]] +name = "falkordb" +version = "1.0.9" +description = "Python client for interacting with FalkorDB database" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "falkordb-1.0.9.tar.gz", hash = "sha256:177008e63c7e4d9ebbdfeb8cad24b0e49175bb0f6e96cac9b4ffb641c0eff0f1"}, +] + +[package.dependencies] +redis = ">=5.0.1,<6.0.0" + [[package]] name = "fastapi" version = "0.109.2" @@ -3685,7 +3698,6 @@ optional = false python-versions = ">=3.6" files = [ {file = "mkdocs-redirects-1.2.1.tar.gz", hash = "sha256:9420066d70e2a6bb357adf86e67023dcdca1857f97f07c7fe450f8f1fb42f861"}, - {file = "mkdocs_redirects-1.2.1-py3-none-any.whl", hash = "sha256:497089f9e0219e7389304cffefccdfa1cac5ff9509f2cb706f4c9b221726dffb"}, ] [package.dependencies] @@ -5771,6 +5783,24 @@ files = [ [package.extras] test = ["pytest (>=3.0)", "pytest-asyncio"] +[[package]] +name = "redis" +version = "5.1.1" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.8" +files = [ + {file = "redis-5.1.1-py3-none-any.whl", hash = "sha256:f8ea06b7482a668c6475ae202ed8d9bcaa409f6e87fb77ed1043d912afd62e24"}, + {file = "redis-5.1.1.tar.gz", hash = "sha256:f6c997521fedbae53387307c5d0bf784d9acc28d9f1d058abeac566ec4dbed72"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} + +[package.extras] +hiredis = ["hiredis (>=3.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"] + [[package]] name = "referencing" version = "0.35.1" @@ -6292,11 +6322,6 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, - {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, - {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, - {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -7766,4 +7791,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.12" -content-hash = "70a0072dce8de95d64b862f9a9df48aaec84c8d8515ae018fce4426a0dcacf88" +content-hash = "fef56656ead761cab7d5c3d0bf1fa5a54608db73b14616d08e5fb152dba91236" diff --git a/pyproject.toml b/pyproject.toml index 220749590..65d54978b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ asyncpg = "^0.29.0" alembic = "^1.13.3" pgvector = "^0.3.5" psycopg2 = {version = "^2.9.10", optional = true} +falkordb = "^1.0.9" [tool.poetry.extras] filesystem = ["s3fs", "botocore"]