feat: add falkordb adapter

This commit is contained in:
Boris Arzentar 2024-10-24 12:37:06 +02:00
parent 1088b58f11
commit c901fa8b8a
6 changed files with 173 additions and 29 deletions

View file

@ -7,7 +7,6 @@ from contextlib import asynccontextmanager
from neo4j import AsyncSession from neo4j import AsyncSession
from neo4j import AsyncGraphDatabase from neo4j import AsyncGraphDatabase
from neo4j.exceptions import Neo4jError from neo4j.exceptions import Neo4jError
from networkx import predecessor
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
logger = logging.getLogger("Neo4jAdapter") logger = logging.getLogger("Neo4jAdapter")

View file

@ -1,57 +1,113 @@
import asyncio
from typing import List, Dict, Optional, Any
from falkordb import FalkorDB from falkordb import FalkorDB
from qdrant_client import AsyncQdrantClient, models
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint from ..models.DataPoint import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
class FalcorDBAdapter(VectorDBInterface): class FalcorDBAdapter(VectorDBInterface):
def __init__( def __init__(
self, self,
graph_database_url: str, graph_database_url: str,
graph_database_username: str,
graph_database_password: str,
graph_database_port: int, graph_database_port: int,
driver: Optional[Any] = None,
embedding_engine = EmbeddingEngine, embedding_engine = EmbeddingEngine,
graph_name: str = "DefaultGraph",
): ):
self.driver = FalkorDB( self.driver = FalkorDB(
host = graph_database_url, host = graph_database_url,
port = graph_database_port) port = graph_database_port)
self.graph_name = graph_name
self.embedding_engine = embedding_engine self.embedding_engine = embedding_engine
async def embed_data(self, data: list[str]) -> list[list[float]]: async def embed_data(self, data: list[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data) return await self.embedding_engine.embed_text(data)
async def has_collection(self, collection_name: str) -> bool:
collections = self.driver.list_graphs()
return collection_name in collections
async def create_collection(self, collection_name: str, payload_schema = None): async def create_collection(self, collection_name: str, payload_schema = None):
pass self.driver.select_graph(collection_name)
async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
graph = self.driver.select_graph(collection_name)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): def stringify_properties(properties: dict) -> str:
pass return ",".join(f"{key}:'{value}'" for key, value in properties.items())
def create_data_point_query(data_point: DataPoint):
node_label = type(data_point.payload).__name__
node_properties = stringify_properties(data_point.payload.dict())
return f"""CREATE (:{node_label} {{{node_properties}}})"""
query = " ".join([create_data_point_query(data_point) for data_point in data_points])
graph.query(query)
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: list[str]):
pass graph = self.driver.select_graph(collection_name)
return graph.query(
f"MATCH (node) WHERE node.id IN $node_ids RETURN node",
{
"node_ids": data_point_ids,
},
)
async def search( async def search(
self, self,
collection_name: str, collection_name: str,
query_text: str = None, query_text: str = None,
query_vector: List[float] = None, query_vector: list[float] = None,
limit: int = 10, limit: int = 10,
with_vector: bool = False, with_vector: bool = False,
): ):
pass if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!")
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
graph = self.driver.select_graph(collection_name)
query = f"""
CALL db.idx.vector.queryNodes(
null,
'text',
{limit},
{query_vector}
) YIELD node, score
"""
result = graph.query(query)
return result
async def batch_search(
self,
collection_name: str,
query_texts: list[str],
limit: int = None,
with_vectors: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
*[self.search(
collection_name = collection_name,
query_vector = query_vector,
limit = limit,
with_vector = with_vectors,
) for query_vector in query_vectors]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
pass graph = self.driver.select_graph(collection_name)
return graph.query(
f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
{
"node_ids": data_point_ids,
},
)

View file

@ -1,3 +1,4 @@
from .Document import Document
from .PdfDocument import PdfDocument from .PdfDocument import PdfDocument
from .TextDocument import TextDocument from .TextDocument import TextDocument
from .ImageDocument import ImageDocument from .ImageDocument import ImageDocument

View file

@ -0,0 +1,62 @@
from typing import Optional
from uuid import UUID
from datetime import datetime
from pydantic import BaseModel
async def add_data_points(collection_name: str, data_points: list):
pass
class Summary(BaseModel):
id: UUID
text: str
chunk: "Chunk"
created_at: datetime
updated_at: Optional[datetime]
vector_index = ["text"]
class Chunk(BaseModel):
id: UUID
text: str
summary: Summary
document: "Document"
created_at: datetime
updated_at: Optional[datetime]
word_count: int
chunk_index: int
cut_type: str
vector_index = ["text"]
class Document(BaseModel):
id: UUID
chunks: list[Chunk]
created_at: datetime
updated_at: Optional[datetime]
class EntityType(BaseModel):
id: UUID
name: str
description: str
created_at: datetime
updated_at: Optional[datetime]
vector_index = ["name"]
class Entity(BaseModel):
id: UUID
name: str
type: EntityType
description: str
chunks: list[Chunk]
created_at: datetime
updated_at: Optional[datetime]
vector_index = ["name"]
class OntologyModel(BaseModel):
chunks: list[Chunk]

41
poetry.lock generated
View file

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]] [[package]]
name = "aiofiles" name = "aiofiles"
@ -1490,6 +1490,19 @@ files = [
[package.extras] [package.extras]
tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"]
[[package]]
name = "falkordb"
version = "1.0.9"
description = "Python client for interacting with FalkorDB database"
optional = false
python-versions = "<4.0,>=3.8"
files = [
{file = "falkordb-1.0.9.tar.gz", hash = "sha256:177008e63c7e4d9ebbdfeb8cad24b0e49175bb0f6e96cac9b4ffb641c0eff0f1"},
]
[package.dependencies]
redis = ">=5.0.1,<6.0.0"
[[package]] [[package]]
name = "fastapi" name = "fastapi"
version = "0.109.2" version = "0.109.2"
@ -3685,7 +3698,6 @@ optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
files = [ files = [
{file = "mkdocs-redirects-1.2.1.tar.gz", hash = "sha256:9420066d70e2a6bb357adf86e67023dcdca1857f97f07c7fe450f8f1fb42f861"}, {file = "mkdocs-redirects-1.2.1.tar.gz", hash = "sha256:9420066d70e2a6bb357adf86e67023dcdca1857f97f07c7fe450f8f1fb42f861"},
{file = "mkdocs_redirects-1.2.1-py3-none-any.whl", hash = "sha256:497089f9e0219e7389304cffefccdfa1cac5ff9509f2cb706f4c9b221726dffb"},
] ]
[package.dependencies] [package.dependencies]
@ -5771,6 +5783,24 @@ files = [
[package.extras] [package.extras]
test = ["pytest (>=3.0)", "pytest-asyncio"] test = ["pytest (>=3.0)", "pytest-asyncio"]
[[package]]
name = "redis"
version = "5.1.1"
description = "Python client for Redis database and key-value store"
optional = false
python-versions = ">=3.8"
files = [
{file = "redis-5.1.1-py3-none-any.whl", hash = "sha256:f8ea06b7482a668c6475ae202ed8d9bcaa409f6e87fb77ed1043d912afd62e24"},
{file = "redis-5.1.1.tar.gz", hash = "sha256:f6c997521fedbae53387307c5d0bf784d9acc28d9f1d058abeac566ec4dbed72"},
]
[package.dependencies]
async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""}
[package.extras]
hiredis = ["hiredis (>=3.0.0)"]
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"]
[[package]] [[package]]
name = "referencing" name = "referencing"
version = "0.35.1" version = "0.35.1"
@ -6292,11 +6322,6 @@ files = [
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"},
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"},
{file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"},
{file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"},
{file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"},
{file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"},
{file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"},
{file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"},
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"},
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"},
{file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"},
@ -7766,4 +7791,4 @@ weaviate = ["weaviate-client"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.9.0,<3.12" python-versions = ">=3.9.0,<3.12"
content-hash = "70a0072dce8de95d64b862f9a9df48aaec84c8d8515ae018fce4426a0dcacf88" content-hash = "fef56656ead761cab7d5c3d0bf1fa5a54608db73b14616d08e5fb152dba91236"

View file

@ -72,6 +72,7 @@ asyncpg = "^0.29.0"
alembic = "^1.13.3" alembic = "^1.13.3"
pgvector = "^0.3.5" pgvector = "^0.3.5"
psycopg2 = {version = "^2.9.10", optional = true} psycopg2 = {version = "^2.9.10", optional = true}
falkordb = "^1.0.9"
[tool.poetry.extras] [tool.poetry.extras]
filesystem = ["s3fs", "botocore"] filesystem = ["s3fs", "botocore"]