fix: enable falkordb and add test for it (#31)
parent d885a047ac
commit 6403d15a76

11 changed files with 191 additions and 41 deletions
@@ -21,7 +21,7 @@ async def get_graph_engine() -> GraphDBInterface:
         )

     elif config.graph_database_provider == "falkordb":
-        if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
+        if not (config.graph_database_url and config.graph_database_port):
             raise EnvironmentError("Missing required FalkorDB credentials.")

         from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
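The check above means a FalkorDB setup now only needs a host and a port, not a username/password pair. A minimal sketch of supplying those settings before requesting the graph engine; the environment variable names are an assumption derived from the config field names in this hunk, not something the diff shows:

```python
import os

# Assumed mapping from config.graph_database_* fields to environment variables;
# verify the exact names against cognee's config module.
os.environ["GRAPH_DATABASE_PROVIDER"] = "falkordb"
os.environ["GRAPH_DATABASE_URL"] = "localhost"
os.environ["GRAPH_DATABASE_PORT"] = "6379"  # FalkorDB's default (Redis) port
```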
@@ -1,7 +1,7 @@
 import asyncio
-from textwrap import dedent
-from typing import Any
+# from datetime import datetime
 from uuid import UUID
+from textwrap import dedent
 from falkordb import FalkorDB

 from cognee.infrastructure.engine import DataPoint
@@ -43,23 +43,31 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
     async def embed_data(self, data: list[str]) -> list[list[float]]:
         return await self.embedding_engine.embed_text(data)

-    async def stringify_properties(self, properties: dict, vectorize_fields = []) -> str:
-        async def get_value(key, value):
-            return f"'{value}'" if key not in vectorize_fields else await self.get_vectorized_value(value)
+    async def stringify_properties(self, properties: dict) -> str:
+        def parse_value(value):
+            if type(value) is UUID:
+                return f"'{str(value)}'"
+            if type(value) is int or type(value) is float:
+                return value
+            if type(value) is list and type(value[0]) is float and len(value) == self.embedding_engine.get_vector_size():
+                return f"'vecf32({value})'"
+            # if type(value) is datetime:
+            #     return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f%z")
+            return f"'{value}'"

-        return ",".join([f"{key}:{await get_value(key, value)}" for key, value in properties.items()])
+        return ",".join([f"{key}:{parse_value(value)}" for key, value in properties.items()])

-    async def get_vectorized_value(self, value: Any) -> str:
-        vector = (await self.embed_data([value]))[0]
-        return f"vecf32({vector})"
+    async def create_data_point_query(self, data_point: DataPoint, vectorized_values: list = None):
+        node_label = type(data_point).__tablename__
+        embeddable_fields = data_point._metadata.get("index_fields", [])

-    async def create_data_point_query(self, data_point: DataPoint):
-        node_label = type(data_point).__name__
-        node_properties = await self.stringify_properties(
-            data_point.model_dump(),
-            data_point._metadata["index_fields"],
-            # data_point._metadata["index_fields"] if hasattr(data_point, "_metadata") else [],
-        )
+        node_properties = await self.stringify_properties({
+            **data_point.model_dump(),
+            **({
+                embeddable_fields[index]: vectorized_values[index] \
+                for index in range(len(embeddable_fields)) \
+            } if vectorized_values is not None else {}),
+        })

         return dedent(f"""
             MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
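To see what the new `parse_value` / `stringify_properties` pair produces, here is a standalone sketch that mirrors the added logic outside the adapter. The `vector_size` argument stands in for `self.embedding_engine.get_vector_size()`; the sample property dict is illustrative only:

```python
from uuid import UUID, uuid4

def stringify_properties(properties: dict, vector_size: int = 3) -> str:
    # Mirrors the added parse_value: quote UUIDs and strings, pass numbers
    # through, and wrap full-length float lists in vecf32(...) for FalkorDB.
    def parse_value(value):
        if isinstance(value, UUID):
            return f"'{str(value)}'"
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return value
        if isinstance(value, list) and value and isinstance(value[0], float) and len(value) == vector_size:
            return f"'vecf32({value})'"
        return f"'{value}'"

    return ",".join([f"{key}:{parse_value(value)}" for key, value in properties.items()])

print(stringify_properties({
    "id": uuid4(),
    "name": "artificial intelligence",
    "name_vector": [0.1, 0.2, 0.3],  # length matches vector_size, so it is rendered as vecf32(...)
}))
```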
@@ -90,7 +98,33 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
         return collection_name in collections

     async def create_data_points(self, data_points: list[DataPoint]):
-        queries = [await self.create_data_point_query(data_point) for data_point in data_points]
+        embeddable_values = [DataPoint.get_embeddable_properties(data_point) for data_point in data_points]
+
+        vectorized_values = await self.embed_data(
+            sum(embeddable_values, [])
+        )
+
+        index = 0
+        positioned_vectorized_values = []
+
+        for values in embeddable_values:
+            if len(values) > 0:
+                values_list = []
+                for i in range(len(values)):
+                    values_list.append(vectorized_values[index + i])
+
+                positioned_vectorized_values.append(values_list)
+                index += len(values)
+            else:
+                positioned_vectorized_values.append(None)
+
+        queries = [
+            await self.create_data_point_query(
+                data_point,
+                positioned_vectorized_values[index],
+            ) for index, data_point in enumerate(data_points)
+        ]

         for query in queries:
             self.query(query)
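The new `create_data_points` embeds every embeddable value in one call and then slices the flat result back into per-data-point lists. A minimal sketch of that repositioning step with a fake embedder; the names and the fake vectors are illustrative, not part of the adapter:

```python
def fake_embed(texts: list[str]) -> list[list[float]]:
    # Stand-in for embed_data: one tiny "vector" per input string.
    return [[float(len(text))] for text in texts]

embeddable_values = [["alpha"], [], ["beta", "gamma"]]  # one list per data point

# Flatten, embed once, then slice the flat result back into position.
vectorized_values = fake_embed(sum(embeddable_values, []))

index = 0
positioned_vectorized_values = []
for values in embeddable_values:
    if len(values) > 0:
        positioned_vectorized_values.append(vectorized_values[index:index + len(values)])
        index += len(values)
    else:
        positioned_vectorized_values.append(None)

print(positioned_vectorized_values)  # [[[5.0]], None, [[4.0], [5.0]]]
```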
@@ -205,10 +239,12 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
         if query_text and not query_vector:
            query_vector = (await self.embed_data([query_text]))[0]

+        [label, attribute_name] = collection_name.split(".")
+
         query = dedent(f"""
            CALL db.idx.vector.queryNodes(
-                {collection_name},
-                'text',
+                '{label}',
+                '{attribute_name}',
                {limit},
                vecf32({query_vector})
            ) YIELD node, score
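The search hunk above now derives the node label and attribute from the collection name, so a collection such as `Entity.name` is queried as label `Entity`, attribute `name`. A small sketch of the query string the adapter would build; the collection name, limit, and vector below are made up for illustration:

```python
from textwrap import dedent

collection_name = "Entity.name"   # hypothetical collection
limit = 5
query_vector = [0.1, 0.2, 0.3]    # normally produced by embed_data

[label, attribute_name] = collection_name.split(".")

query = dedent(f"""
    CALL db.idx.vector.queryNodes(
        '{label}',
        '{attribute_name}',
        {limit},
        vecf32({query_vector})
    ) YIELD node, score
""")

print(query)
```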
@@ -216,7 +252,7 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):

         result = self.query(query)

-        return result
+        return result.result_set

     async def batch_search(
         self,
@@ -236,6 +272,30 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
            ) for query_vector in query_vectors]
        )

+    async def get_graph_data(self):
+        query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
+
+        result = self.query(query)
+
+        nodes = [(
+            record[2]["id"],
+            record[2],
+        ) for record in result.result_set]
+
+        query = """
+            MATCH (n)-[r]->(m)
+            RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
+        """
+        result = self.query(query)
+        edges = [(
+            record[3]["source_node_id"],
+            record[3]["target_node_id"],
+            record[2],
+            record[3],
+        ) for record in result.result_set]
+
+        return (nodes, edges)
+
     async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
         return self.query(
             f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
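Per the added `get_graph_data`, callers receive a `(nodes, edges)` pair where each node is an `(id, properties)` tuple and each edge is `(source_id, target_id, relationship_name, properties)`. A sketch of consuming that shape; the sample records are invented:

```python
# Hypothetical data in the shape get_graph_data returns.
nodes = [
    ("node-1", {"id": "node-1", "name": "artificial intelligence"}),
]
edges = [
    ("node-1", "node-2", "is_a", {"source_node_id": "node-1", "target_node_id": "node-2"}),
]

for node_id, properties in nodes:
    print(node_id, properties.get("name"))

for source_id, target_id, relationship_name, _properties in edges:
    print(f"{source_id} -[{relationship_name}]-> {target_id}")
```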
@@ -58,7 +58,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
         )

     elif config["vector_db_provider"] == "falkordb":
-        if not (config["vector_db_url"] and config["vector_db_key"]):
+        if not (config["vector_db_url"] and config["vector_db_port"]):
             raise EnvironmentError("Missing required FalkorDB credentials!")

         from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
@@ -36,10 +36,10 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
                api_version = self.api_version
            )

-            return response.data[0]["embedding"]
+            return [data["embedding"] for data in response.data]

-        tasks = [get_embedding(text_) for text_ in text]
-        result = await asyncio.gather(*tasks)
+        # tasks = [get_embedding(text_) for text_ in text]
+        result = await get_embedding(text)
        return result

    def get_vector_size(self) -> int:
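The embedding engine change means one provider call now returns vectors for the whole batch instead of one call per string gathered with asyncio. A rough sketch of that pattern, with a stub in place of the real litellm call; the stub and its shapes are assumptions for illustration only:

```python
import asyncio

class FakeResponse:
    # Mimics the `response.data` shape the hunk above indexes into.
    def __init__(self, texts):
        self.data = [{"embedding": [float(len(text))]} for text in texts]

async def aembedding(texts: list[str]) -> FakeResponse:
    return FakeResponse(texts)  # stand-in for the real provider call

async def embed_text(text: list[str]) -> list[list[float]]:
    async def get_embedding(text_batch):
        response = await aembedding(text_batch)
        # One call, one embedding per input, in order.
        return [data["embedding"] for data in response.data]

    return await get_embedding(text)

print(asyncio.run(embed_text(["hello", "falkordb"])))  # [[5.0], [8.0]]
```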
@@ -1,4 +1,3 @@
import inspect
from typing import List, Optional, get_type_hints, Generic, TypeVar
import asyncio
from uuid import UUID
@@ -23,7 +23,15 @@ class DataPoint(BaseModel):
         if self._metadata and len(self._metadata["index_fields"]) > 0 \
             and hasattr(self, self._metadata["index_fields"][0]):
             attribute = getattr(self, self._metadata["index_fields"][0])
+
             if isinstance(attribute, str):
-                return(attribute.strip())
+                return attribute.strip()
             else:
-                return (attribute)
+                return attribute
+
+    @classmethod
+    def get_embeddable_properties(self, data_point):
+        if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
+            return [getattr(data_point, field, None) for field in data_point._metadata["index_fields"]]
+
+        return []
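A small standalone sketch of how `_metadata["index_fields"]` drives the new `get_embeddable_properties`; it uses a plain class instead of the real pydantic `DataPoint`, purely for illustration:

```python
class FakeDataPoint:
    # Simplified stand-in for cognee's DataPoint.
    _metadata = {"index_fields": ["name"]}

    def __init__(self, name):
        self.name = name

def get_embeddable_properties(data_point):
    # Same lookup the added classmethod performs.
    if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
        return [getattr(data_point, field, None) for field in data_point._metadata["index_fields"]]
    return []

print(get_embeddable_properties(FakeDataPoint("artificial intelligence")))  # ['artificial intelligence']
```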
@@ -1,9 +1,6 @@
-from typing import Union
-
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
 from cognee.modules.engine.models.EntityType import EntityType
-from cognee.shared.CodeGraphEntities import Repository


 class Entity(DataPoint):
@@ -11,7 +8,7 @@ class Entity(DataPoint):
     name: str
     is_a: EntityType
     description: str
-    mentioned_in: Union[DocumentChunk, Repository]
+    mentioned_in: DocumentChunk
     _metadata: dict = {
         "index_fields": ["name"],
     }
@@ -1,8 +1,5 @@
-from typing import Union
-
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
-from cognee.shared.CodeGraphEntities import Repository


 class EntityType(DataPoint):
@@ -10,7 +7,7 @@ class EntityType(DataPoint):
     name: str
     type: str
     description: str
-    exists_in: Union[DocumentChunk, Repository]
+    exists_in: DocumentChunk
     _metadata: dict = {
         "index_fields": ["name"],
     }
@@ -122,6 +122,7 @@ async def get_graph_from_model(
        type(data_point),
        include_fields = {
            "_metadata": (dict, data_point._metadata),
+            "__tablename__": data_point.__tablename__,
        },
        exclude_fields = excluded_properties,
    )
@@ -1,3 +1,4 @@
+import asyncio
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.engine import DataPoint

@@ -9,8 +10,12 @@ async def index_data_points(data_points: list[DataPoint]):

     flat_data_points: list[DataPoint] = []

-    for data_point in data_points:
-        flat_data_points.extend(get_data_points_from_model(data_point))
+    results = await asyncio.gather(*[
+        get_data_points_from_model(data_point) for data_point in data_points
+    ])
+
+    for result in results:
+        flat_data_points.extend(result)

     for data_point in flat_data_points:
         data_point_type = type(data_point)
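Since `get_data_points_from_model` is now a coroutine (see the hunks that follow), `index_data_points` collects all per-model results concurrently with `asyncio.gather` and then flattens them. A toy version of that pattern; the extraction coroutine and its items are fabricated:

```python
import asyncio

async def get_items_from_model(model: str) -> list[str]:
    # Pretend traversal that yields a few items per model.
    return [f"{model}-a", f"{model}-b"]

async def main():
    models = ["chunk", "entity"]

    results = await asyncio.gather(*[
        get_items_from_model(model) for model in models
    ])

    flat_items = []
    for result in results:
        flat_items.extend(result)

    print(flat_items)  # ['chunk-a', 'chunk-b', 'entity-a', 'entity-b']

asyncio.run(main())
```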
@@ -38,7 +43,7 @@ async def index_data_points(data_points: list[DataPoint]):

    return data_points

-def get_data_points_from_model(data_point: DataPoint, added_data_points = None, visited_properties = None) -> list[DataPoint]:
+async def get_data_points_from_model(data_point: DataPoint, added_data_points = None, visited_properties = None) -> list[DataPoint]:
    data_points = []
    added_data_points = added_data_points or {}
    visited_properties = visited_properties or {}
@@ -52,7 +57,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = None,

            visited_properties[property_key] = True

-            new_data_points = get_data_points_from_model(field_value, added_data_points, visited_properties)
+            new_data_points = await get_data_points_from_model(field_value, added_data_points, visited_properties)

            for new_point in new_data_points:
                if str(new_point.id) not in added_data_points:
@@ -68,7 +73,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = None,

                visited_properties[property_key] = True

-                new_data_points = get_data_points_from_model(field_value_item, added_data_points, visited_properties)
+                new_data_points = await get_data_points_from_model(field_value_item, added_data_points, visited_properties)

                for new_point in new_data_points:
                    if str(new_point.id) not in added_data_points:
cognee/tests/test_falkordb.py (new executable file, 83 lines)
@@ -0,0 +1,83 @@
import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.shared.utils import render_graph

logging.basicConfig(level = logging.DEBUG)

async def main():
    data_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_library")).resolve())
    cognee.config.data_root_directory(data_directory_path)
    cognee_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_library")).resolve())
    cognee.config.system_root_directory(cognee_directory_path)

    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata = True)

    dataset_name = "artificial_intelligence"

    ai_text_file_path = os.path.join(pathlib.Path(__file__).parent, "test_data/artificial-intelligence.pdf")
    await cognee.add([ai_text_file_path], dataset_name)

    text = """A large language model (LLM) is a language model notable for its ability to achieve general-purpose language generation and other natural language processing tasks such as classification. LLMs acquire these abilities by learning statistical relationships from text documents during a computationally intensive self-supervised and semi-supervised training process. LLMs can be used for text generation, a form of generative AI, by taking an input text and repeatedly predicting the next token or word.
    LLMs are artificial neural networks. The largest and most capable, as of March 2024, are built with a decoder-only transformer-based architecture while some recent implementations are based on other architectures, such as recurrent neural network variants and Mamba (a state space model).
    Up to 2020, fine tuning was the only way a model could be adapted to be able to accomplish specific tasks. Larger sized models, such as GPT-3, however, can be prompt-engineered to achieve similar results.[6] They are thought to acquire knowledge about syntax, semantics and "ontology" inherent in human language corpora, but also inaccuracies and biases present in the corpora.
    Some notable LLMs are OpenAI's GPT series of models (e.g., GPT-3.5 and GPT-4, used in ChatGPT and Microsoft Copilot), Google's PaLM and Gemini (the latter of which is currently used in the chatbot of the same name), xAI's Grok, Meta's LLaMA family of open-source models, Anthropic's Claude models, Mistral AI's open source models, and Databricks' open source DBRX.
    """

    await cognee.add([text], dataset_name)

    await cognee.cognify([dataset_name])

    # await render_graph(None, include_labels = True, include_nodes = True)

    from cognee.infrastructure.databases.vector import get_vector_engine
    vector_engine = get_vector_engine()
    random_node = (await vector_engine.search("entity.name", "AI"))[0]
    random_node_name = random_node.payload["text"]

    search_results = await cognee.search(SearchType.INSIGHTS, query_text = random_node_name)
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nExtracted sentences are:\n")
    for result in search_results:
        print(f"{result}\n")

    search_results = await cognee.search(SearchType.CHUNKS, query_text = random_node_name)
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nExtracted chunks are:\n")
    for result in search_results:
        print(f"{result}\n")

    search_results = await cognee.search(SearchType.SUMMARIES, query_text = random_node_name)
    assert len(search_results) != 0, "Query related summaries don't exist."
    print("\nExtracted summaries are:\n")
    for result in search_results:
        print(f"{result}\n")

    history = await cognee.get_search_history()

    assert len(history) == 6, "Search history is not correct."

    # Assert local data files are cleaned properly
    await cognee.prune.prune_data()
    assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

    # Assert relational, vector and graph databases have been cleaned properly
    await cognee.prune.prune_system(metadata=True)

    connection = await vector_engine.get_connection()
    collection_names = await connection.table_names()
    assert len(collection_names) == 0, "LanceDB vector database is not empty"

    from cognee.infrastructure.databases.relational import get_relational_engine
    assert not os.path.exists(get_relational_engine().db_path), "SQLite relational database is not empty"

    from cognee.infrastructure.databases.graph import get_graph_config
    graph_config = get_graph_config()
    assert not os.path.exists(graph_config.graph_file_path), "Networkx graph database is not empty"

if __name__ == "__main__":
    import asyncio
    asyncio.run(main(), debug=True)
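Running this test presumably requires a FalkorDB instance reachable at the configured graph/vector database url and port (FalkorDB's default port is 6379), plus whatever LLM and embedding credentials cognee normally expects. The script then exercises add, cognify, all three search types, search history, and full data/system cleanup.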