From 73372df31e0d4f68f0b0de267b1ae15858891df8 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Thu, 24 Oct 2024 12:37:06 +0200 Subject: [PATCH 01/20] feat: add falkordb adapter --- .../databases/graph/neo4j_driver/adapter.py | 1 - .../vector/falkordb/FalkorDBAdapter.py | 113 ++++++++++++++++++ examples/python/GraphModel.py | 62 ++++++++++ poetry.lock | 40 ++++++- pyproject.toml | 3 +- 5 files changed, 215 insertions(+), 4 deletions(-) create mode 100644 cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py create mode 100644 examples/python/GraphModel.py diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 26bbb5819..8831591d5 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -8,7 +8,6 @@ from uuid import UUID from neo4j import AsyncSession from neo4j import AsyncGraphDatabase from neo4j.exceptions import Neo4jError -from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface logger = logging.getLogger("Neo4jAdapter") diff --git a/cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py b/cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py new file mode 100644 index 000000000..744d79f53 --- /dev/null +++ b/cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py @@ -0,0 +1,113 @@ +import asyncio +from falkordb import FalkorDB +from ..models.DataPoint import DataPoint +from ..vector_db_interface import VectorDBInterface +from ..embeddings.EmbeddingEngine import EmbeddingEngine + + +class FalcorDBAdapter(VectorDBInterface): + def __init__( + self, + graph_database_url: str, + graph_database_port: int, + embedding_engine = EmbeddingEngine, + ): + self.driver = FalkorDB( + host = graph_database_url, + port = graph_database_port) + self.embedding_engine = embedding_engine + + + async def embed_data(self, data: list[str]) -> list[list[float]]: + return await self.embedding_engine.embed_text(data) + + async def has_collection(self, collection_name: str) -> bool: + collections = self.driver.list_graphs() + + return collection_name in collections + + async def create_collection(self, collection_name: str, payload_schema = None): + self.driver.select_graph(collection_name) + + async def create_data_points(self, collection_name: str, data_points: list[DataPoint]): + graph = self.driver.select_graph(collection_name) + + def stringify_properties(properties: dict) -> str: + return ",".join(f"{key}:'{value}'" for key, value in properties.items()) + + def create_data_point_query(data_point: DataPoint): + node_label = type(data_point.payload).__name__ + node_properties = stringify_properties(data_point.payload.dict()) + + return f"""CREATE (:{node_label} {{{node_properties}}})""" + + query = " ".join([create_data_point_query(data_point) for data_point in data_points]) + + graph.query(query) + + async def retrieve(self, collection_name: str, data_point_ids: list[str]): + graph = self.driver.select_graph(collection_name) + + return graph.query( + f"MATCH (node) WHERE node.id IN $node_ids RETURN node", + { + "node_ids": data_point_ids, + }, + ) + + async def search( + self, + collection_name: str, + query_text: str = None, + query_vector: list[float] = None, + limit: int = 10, + with_vector: bool = False, + ): + if query_text is None and query_vector is None: + raise ValueError("One of query_text or 
query_vector must be provided!") + + if query_text and not query_vector: + query_vector = (await self.embedding_engine.embed_text([query_text]))[0] + + graph = self.driver.select_graph(collection_name) + + query = f""" + CALL db.idx.vector.queryNodes( + null, + 'text', + {limit}, + {query_vector} + ) YIELD node, score + """ + + result = graph.query(query) + + return result + + async def batch_search( + self, + collection_name: str, + query_texts: list[str], + limit: int = None, + with_vectors: bool = False, + ): + query_vectors = await self.embedding_engine.embed_text(query_texts) + + return await asyncio.gather( + *[self.search( + collection_name = collection_name, + query_vector = query_vector, + limit = limit, + with_vector = with_vectors, + ) for query_vector in query_vectors] + ) + + async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): + graph = self.driver.select_graph(collection_name) + + return graph.query( + f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node", + { + "node_ids": data_point_ids, + }, + ) diff --git a/examples/python/GraphModel.py b/examples/python/GraphModel.py new file mode 100644 index 000000000..01251fc20 --- /dev/null +++ b/examples/python/GraphModel.py @@ -0,0 +1,62 @@ + +from typing import Optional +from uuid import UUID +from datetime import datetime +from pydantic import BaseModel + + +async def add_data_points(collection_name: str, data_points: list): + pass + + + +class Summary(BaseModel): + id: UUID + text: str + chunk: "Chunk" + created_at: datetime + updated_at: Optional[datetime] + + vector_index = ["text"] + +class Chunk(BaseModel): + id: UUID + text: str + summary: Summary + document: "Document" + created_at: datetime + updated_at: Optional[datetime] + word_count: int + chunk_index: int + cut_type: str + + vector_index = ["text"] + +class Document(BaseModel): + id: UUID + chunks: list[Chunk] + created_at: datetime + updated_at: Optional[datetime] + +class EntityType(BaseModel): + id: UUID + name: str + description: str + created_at: datetime + updated_at: Optional[datetime] + + vector_index = ["name"] + +class Entity(BaseModel): + id: UUID + name: str + type: EntityType + description: str + chunks: list[Chunk] + created_at: datetime + updated_at: Optional[datetime] + + vector_index = ["name"] + +class OntologyModel(BaseModel): + chunks: list[Chunk] diff --git a/poetry.lock b/poetry.lock index 270e66027..12b1e59ba 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiofiles" @@ -3729,6 +3729,24 @@ htmlmin2 = ">=0.1.13" jsmin = ">=3.0.1" mkdocs = ">=1.4.1" +[[package]] +name = "mkdocs-redirects" +version = "1.2.1" +description = "A MkDocs plugin for dynamic page redirects to prevent broken links." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "mkdocs-redirects-1.2.1.tar.gz", hash = "sha256:9420066d70e2a6bb357adf86e67023dcdca1857f97f07c7fe450f8f1fb42f861"}, +] + +[package.dependencies] +mkdocs = ">=1.1.1" + +[package.extras] +dev = ["autoflake", "black", "isort", "pytest", "twine (>=1.13.0)"] +release = ["twine (>=1.13.0)"] +test = ["autoflake", "black", "isort", "pytest"] + [[package]] name = "mkdocstrings" version = "0.26.2" @@ -5799,6 +5817,24 @@ async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\ hiredis = ["hiredis (>=3.0.0)"] ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"] +[[package]] +name = "redis" +version = "5.1.1" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.8" +files = [ + {file = "redis-5.1.1-py3-none-any.whl", hash = "sha256:f8ea06b7482a668c6475ae202ed8d9bcaa409f6e87fb77ed1043d912afd62e24"}, + {file = "redis-5.1.1.tar.gz", hash = "sha256:f6c997521fedbae53387307c5d0bf784d9acc28d9f1d058abeac566ec4dbed72"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} + +[package.extras] +hiredis = ["hiredis (>=3.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"] + [[package]] name = "referencing" version = "0.35.1" @@ -7746,4 +7782,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.12" -content-hash = "fb09733ff7a70fb91c5f72ff0c8a8137b857557930a7aa025aad3154de4d8ceb" +content-hash = "fef56656ead761cab7d5c3d0bf1fa5a54608db73b14616d08e5fb152dba91236" diff --git a/pyproject.toml b/pyproject.toml index 0bc3849b1..92d8f829b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,8 @@ fastapi-users = {version = "*", extras = ["sqlalchemy"]} alembic = "^1.13.3" asyncpg = "^0.29.0" pgvector = "^0.3.5" -psycopg2 = "^2.9.10" +psycopg2 = {version = "^2.9.10", optional = true} +falkordb = "^1.0.9" [tool.poetry.extras] filesystem = ["s3fs", "botocore"] From a2b1087c84c0ac54f8a98295a78f54f643cde50f Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Thu, 7 Nov 2024 11:17:01 +0100 Subject: [PATCH 02/20] feat: add FalkorDB integration --- .../databases/graph/neo4j_driver/adapter.py | 35 +++--- .../databases/graph/networkx/adapter.py | 4 - .../hybrid/falkordb/FalkorDBAdapter.py | 30 ----- .../vector/falkordb/FalkorDBAdapter.py | 113 ------------------ .../vector/pgvector/PGVectorAdapter.py | 50 ++++---- .../databases/vector/qdrant/QDrantAdapter.py | 1 - .../vector/weaviate_db/WeaviateAdapter.py | 15 ++- cognee/modules/chunking/TextChunker.py | 6 +- cognee/modules/engine/utils/__init__.py | 1 - .../engine/utils/generate_node_name.py | 2 +- .../graph/utils/get_graph_from_model.py | 66 ++++------ cognee/shared/utils.py | 2 +- cognee/tasks/graph/__init__.py | 1 - cognee/tasks/graph/extract_graph_from_data.py | 6 +- cognee/tasks/graph/query_graph_connections.py | 4 +- .../infer_data_ontology/models/models.py | 31 ----- cognee/tasks/storage/index_data_points.py | 5 +- .../tasks/summarization/models/TextSummary.py | 3 +- cognee/tasks/summarization/summarize_text.py | 6 +- cognee/tests/test_library.py | 4 +- cognee/tests/test_neo4j.py | 4 +- cognee/tests/test_pgvector.py | 4 +- cognee/tests/test_qdrant.py | 4 +- cognee/tests/test_weaviate.py | 4 +- examples/python/GraphModel.py | 62 ---------- notebooks/cognee_demo.ipynb | 22 ++-- poetry.lock | 54 +-------- pyproject.toml | 2 +- tools/daily_twitter_stats.py 
| 66 ++++++++++ 29 files changed, 180 insertions(+), 427 deletions(-) delete mode 100644 cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py delete mode 100644 cognee/tasks/infer_data_ontology/models/models.py delete mode 100644 examples/python/GraphModel.py create mode 100644 tools/daily_twitter_stats.py diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 8831591d5..7165aa29b 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -8,6 +8,7 @@ from uuid import UUID from neo4j import AsyncSession from neo4j import AsyncGraphDatabase from neo4j.exceptions import Neo4jError +from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface logger = logging.getLogger("Neo4jAdapter") @@ -62,10 +63,11 @@ class Neo4jAdapter(GraphDBInterface): async def add_node(self, node: DataPoint): serialized_properties = self.serialize_properties(node.model_dump()) - query = dedent("""MERGE (node {id: $node_id}) - ON CREATE SET node += $properties, node.updated_at = timestamp() - ON MATCH SET node += $properties, node.updated_at = timestamp() - RETURN ID(node) AS internal_id, node.id AS nodeId""") + query = """MERGE (node {id: $node_id}) + ON CREATE SET node += $properties + ON MATCH SET node += $properties + ON MATCH SET node.updated_at = timestamp() + RETURN ID(node) AS internal_id, node.id AS nodeId""" params = { "node_id": str(node.id), @@ -78,8 +80,9 @@ class Neo4jAdapter(GraphDBInterface): query = """ UNWIND $nodes AS node MERGE (n {id: node.node_id}) - ON CREATE SET n += node.properties, n.updated_at = timestamp() - ON MATCH SET n += node.properties, n.updated_at = timestamp() + ON CREATE SET n += node.properties + ON MATCH SET n += node.properties + ON MATCH SET n.updated_at = timestamp() WITH n, node.node_id AS label CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId @@ -134,9 +137,8 @@ class Neo4jAdapter(GraphDBInterface): return await self.query(query, params) async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool: - query = """ - MATCH (from_node)-[relationship]->(to_node) - WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label + query = f""" + MATCH (from_node:`{str(from_node)}`)-[relationship:`{edge_label}`]->(to_node:`{str(to_node)}`) RETURN COUNT(relationship) > 0 AS edge_exists """ @@ -176,18 +178,17 @@ class Neo4jAdapter(GraphDBInterface): async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}): serialized_properties = self.serialize_properties(edge_properties) - query = dedent("""MATCH (from_node {id: $from_node}), - (to_node {id: $to_node}) - MERGE (from_node)-[r]->(to_node) - ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name - ON MATCH SET r += $properties, r.updated_at = timestamp() - RETURN r - """) + query = f"""MATCH (from_node:`{str(from_node)}` + {{id: $from_node}}), + (to_node:`{str(to_node)}` {{id: $to_node}}) + MERGE (from_node)-[r:`{relationship_name}`]->(to_node) + ON CREATE SET r += $properties, r.updated_at = timestamp() + ON MATCH SET r += $properties, r.updated_at = timestamp() + RETURN r""" params = { "from_node": str(from_node), "to_node": 
str(to_node), - "relationship_name": relationship_name, "properties": serialized_properties } diff --git a/cognee/infrastructure/databases/graph/networkx/adapter.py b/cognee/infrastructure/databases/graph/networkx/adapter.py index 65aeea289..b0a9e7a13 100644 --- a/cognee/infrastructure/databases/graph/networkx/adapter.py +++ b/cognee/infrastructure/databases/graph/networkx/adapter.py @@ -30,10 +30,6 @@ class NetworkXAdapter(GraphDBInterface): def __init__(self, filename = "cognee_graph.pkl"): self.filename = filename - async def get_graph_data(self): - await self.load_graph_from_file() - return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True))) - async def query(self, query: str, params: dict): pass diff --git a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py index ea5a75088..effe9e682 100644 --- a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py +++ b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py @@ -1,7 +1,6 @@ import asyncio from textwrap import dedent from typing import Any -from uuid import UUID from falkordb import FalkorDB from cognee.infrastructure.engine import DataPoint @@ -162,35 +161,6 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface): async def extract_nodes(self, data_point_ids: list[str]): return await self.retrieve(data_point_ids) - async def get_connections(self, node_id: UUID) -> list: - predecessors_query = """ - MATCH (node)<-[relation]-(neighbour) - WHERE node.id = $node_id - RETURN neighbour, relation, node - """ - successors_query = """ - MATCH (node)-[relation]->(neighbour) - WHERE node.id = $node_id - RETURN node, relation, neighbour - """ - - predecessors, successors = await asyncio.gather( - self.query(predecessors_query, dict(node_id = node_id)), - self.query(successors_query, dict(node_id = node_id)), - ) - - connections = [] - - for neighbour in predecessors: - neighbour = neighbour["relation"] - connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2])) - - for neighbour in successors: - neighbour = neighbour["relation"] - connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2])) - - return connections - async def search( self, collection_name: str, diff --git a/cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py b/cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py deleted file mode 100644 index 744d79f53..000000000 --- a/cognee/infrastructure/databases/vector/falkordb/FalkorDBAdapter.py +++ /dev/null @@ -1,113 +0,0 @@ -import asyncio -from falkordb import FalkorDB -from ..models.DataPoint import DataPoint -from ..vector_db_interface import VectorDBInterface -from ..embeddings.EmbeddingEngine import EmbeddingEngine - - -class FalcorDBAdapter(VectorDBInterface): - def __init__( - self, - graph_database_url: str, - graph_database_port: int, - embedding_engine = EmbeddingEngine, - ): - self.driver = FalkorDB( - host = graph_database_url, - port = graph_database_port) - self.embedding_engine = embedding_engine - - - async def embed_data(self, data: list[str]) -> list[list[float]]: - return await self.embedding_engine.embed_text(data) - - async def has_collection(self, collection_name: str) -> bool: - collections = self.driver.list_graphs() - - return collection_name in collections - - async def create_collection(self, collection_name: str, payload_schema = None): - 
self.driver.select_graph(collection_name) - - async def create_data_points(self, collection_name: str, data_points: list[DataPoint]): - graph = self.driver.select_graph(collection_name) - - def stringify_properties(properties: dict) -> str: - return ",".join(f"{key}:'{value}'" for key, value in properties.items()) - - def create_data_point_query(data_point: DataPoint): - node_label = type(data_point.payload).__name__ - node_properties = stringify_properties(data_point.payload.dict()) - - return f"""CREATE (:{node_label} {{{node_properties}}})""" - - query = " ".join([create_data_point_query(data_point) for data_point in data_points]) - - graph.query(query) - - async def retrieve(self, collection_name: str, data_point_ids: list[str]): - graph = self.driver.select_graph(collection_name) - - return graph.query( - f"MATCH (node) WHERE node.id IN $node_ids RETURN node", - { - "node_ids": data_point_ids, - }, - ) - - async def search( - self, - collection_name: str, - query_text: str = None, - query_vector: list[float] = None, - limit: int = 10, - with_vector: bool = False, - ): - if query_text is None and query_vector is None: - raise ValueError("One of query_text or query_vector must be provided!") - - if query_text and not query_vector: - query_vector = (await self.embedding_engine.embed_text([query_text]))[0] - - graph = self.driver.select_graph(collection_name) - - query = f""" - CALL db.idx.vector.queryNodes( - null, - 'text', - {limit}, - {query_vector} - ) YIELD node, score - """ - - result = graph.query(query) - - return result - - async def batch_search( - self, - collection_name: str, - query_texts: list[str], - limit: int = None, - with_vectors: bool = False, - ): - query_vectors = await self.embedding_engine.embed_text(query_texts) - - return await asyncio.gather( - *[self.search( - collection_name = collection_name, - query_vector = query_vector, - limit = limit, - with_vector = with_vectors, - ) for query_vector in query_vectors] - ) - - async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): - graph = self.driver.select_graph(collection_name) - - return graph.query( - f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node", - { - "node_ids": data_point_ids, - }, - ) diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 01691714b..d9aecec90 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from cognee.infrastructure.engine import DataPoint -from .serialize_data import serialize_data +from .serialize_datetime import serialize_datetime from ..models.ScoredResult import ScoredResult from ..vector_db_interface import VectorDBInterface from ..embeddings.EmbeddingEngine import EmbeddingEngine @@ -79,10 +79,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): async def create_data_points( self, collection_name: str, data_points: List[DataPoint] ): - if not await self.has_collection(collection_name): - await self.create_collection( - collection_name = collection_name, - payload_schema = type(data_points[0]), + async with self.get_async_session() as session: + if not await self.has_collection(collection_name): + await self.create_collection( + collection_name=collection_name, + payload_schema=type(data_points[0]), + ) + + data_vectors = await 
self.embed_data( + [data_point.get_embeddable_data() for data_point in data_points] ) data_vectors = await self.embed_data( @@ -102,10 +107,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): payload = Column(JSON) vector = Column(Vector(vector_size)) - def __init__(self, id, payload, vector): - self.id = id - self.payload = payload - self.vector = vector + pgvector_data_points = [ + PGVectorDataPoint( + id=data_point.id, + vector=data_vectors[data_index], + payload=serialize_datetime(data_point.model_dump()), + ) + for (data_index, data_point) in enumerate(data_points) + ] pgvector_data_points = [ PGVectorDataPoint( @@ -127,7 +136,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): await self.create_data_points(f"{index_name}_{index_property_name}", [ IndexSchema( id = data_point.id, - text = data_point.get_embeddable_data(), + text = getattr(data_point, data_point._metadata["index_fields"][0]), ) for data_point in data_points ]) @@ -197,19 +206,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): vector_list = [] - # Extract distances and find min/max for normalization - for vector in closest_items: - # TODO: Add normalization of similarity score - vector_list.append(vector) - - # Create and return ScoredResult objects - return [ - ScoredResult( - id = UUID(str(row.id)), - payload = row.payload, - score = row.similarity - ) for row in vector_list - ] + # Create and return ScoredResult objects + return [ + ScoredResult( + id = UUID(row.id), + payload = row.payload, + score = row.similarity + ) for row in vector_list + ] async def batch_search( self, diff --git a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py index 1efcd47b3..436861a45 100644 --- a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py +++ b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py @@ -3,7 +3,6 @@ from uuid import UUID from typing import List, Dict, Optional from qdrant_client import AsyncQdrantClient, models -from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult from cognee.infrastructure.engine import DataPoint from ..vector_db_interface import VectorDBInterface from ..embeddings.EmbeddingEngine import EmbeddingEngine diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py index be356740f..2e4e88323 100644 --- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py @@ -11,6 +11,7 @@ from ..embeddings.EmbeddingEngine import EmbeddingEngine logger = logging.getLogger("WeaviateAdapter") class IndexSchema(DataPoint): + uuid: str text: str _metadata: dict = { @@ -88,10 +89,8 @@ class WeaviateAdapter(VectorDBInterface): def convert_to_weaviate_data_points(data_point: DataPoint): vector = data_vectors[data_points.index(data_point)] properties = data_point.model_dump() - - if "id" in properties: - properties["uuid"] = str(data_point.id) - del properties["id"] + properties["uuid"] = properties["id"] + del properties["id"] return DataObject( uuid = data_point.id, @@ -131,8 +130,8 @@ class WeaviateAdapter(VectorDBInterface): async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]): await self.create_data_points(f"{index_name}_{index_property_name}", [ IndexSchema( - id = data_point.id, - text = 
data_point.get_embeddable_data(), + uuid = str(data_point.id), + text = getattr(data_point, data_point._metadata["index_fields"][0]), ) for data_point in data_points ]) @@ -179,9 +178,9 @@ class WeaviateAdapter(VectorDBInterface): return [ ScoredResult( - id = UUID(str(result.uuid)), + id = UUID(result.id), payload = result.properties, - score = 1 - float(result.metadata.score) + score = float(result.metadata.score) ) for result in search_result.objects ] diff --git a/cognee/modules/chunking/TextChunker.py b/cognee/modules/chunking/TextChunker.py index 714383804..4717d108d 100644 --- a/cognee/modules/chunking/TextChunker.py +++ b/cognee/modules/chunking/TextChunker.py @@ -29,7 +29,7 @@ class TextChunker(): else: if len(self.paragraph_chunks) == 0: yield DocumentChunk( - id = chunk_data["chunk_id"], + id = str(chunk_data["chunk_id"]), text = chunk_data["text"], word_count = chunk_data["word_count"], is_part_of = self.document, @@ -42,7 +42,7 @@ class TextChunker(): chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks) try: yield DocumentChunk( - id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"), + id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")), text = chunk_text, word_count = self.chunk_size, is_part_of = self.document, @@ -59,7 +59,7 @@ class TextChunker(): if len(self.paragraph_chunks) > 0: try: yield DocumentChunk( - id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"), + id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")), text = " ".join(chunk["text"] for chunk in self.paragraph_chunks), word_count = self.chunk_size, is_part_of = self.document, diff --git a/cognee/modules/engine/utils/__init__.py b/cognee/modules/engine/utils/__init__.py index 4d4ab02e7..9cc2bc573 100644 --- a/cognee/modules/engine/utils/__init__.py +++ b/cognee/modules/engine/utils/__init__.py @@ -1,3 +1,2 @@ from .generate_node_id import generate_node_id from .generate_node_name import generate_node_name -from .generate_edge_name import generate_edge_name diff --git a/cognee/modules/engine/utils/generate_node_name.py b/cognee/modules/engine/utils/generate_node_name.py index a2871875b..84b266198 100644 --- a/cognee/modules/engine/utils/generate_node_name.py +++ b/cognee/modules/engine/utils/generate_node_name.py @@ -1,2 +1,2 @@ def generate_node_name(name: str) -> str: - return name.lower().replace("'", "") + return name.lower().replace(" ", "_").replace("'", "") diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index 29137ddc7..ef402e4d6 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -1,8 +1,9 @@ from datetime import datetime, timezone from cognee.infrastructure.engine import DataPoint +from cognee.modules import data from cognee.modules.storage.utils import copy_model -def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}): +def get_graph_from_model(data_point: DataPoint, include_root = True): nodes = [] edges = [] @@ -16,55 +17,29 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes if isinstance(field_value, DataPoint): excluded_properties.add(field_name) - property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges) - - for node in property_nodes: - if str(node.id) not in added_nodes: - nodes.append(node) - added_nodes[str(node.id)] = True 
- - for edge in property_edges: - edge_key = str(edge[0]) + str(edge[1]) + edge[2] - - if str(edge_key) not in added_edges: - edges.append(edge) - added_edges[str(edge_key)] = True + property_nodes, property_edges = get_graph_from_model(field_value, True) + nodes[:0] = property_nodes + edges[:0] = property_edges for property_node in get_own_properties(property_nodes, property_edges): - edge_key = str(data_point.id) + str(property_node.id) + field_name - - if str(edge_key) not in added_edges: - edges.append((data_point.id, property_node.id, field_name, { - "source_node_id": data_point.id, - "target_node_id": property_node.id, - "relationship_name": field_name, - "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), - })) - added_edges[str(edge_key)] = True + edges.append((data_point.id, property_node.id, field_name, { + "source_node_id": data_point.id, + "target_node_id": property_node.id, + "relationship_name": field_name, + "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + })) continue - if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): - excluded_properties.add(field_name) + if isinstance(field_value, list): + if isinstance(field_value[0], DataPoint): + excluded_properties.add(field_name) - for item in field_value: - property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges) + for item in field_value: + property_nodes, property_edges = get_graph_from_model(item, True) + nodes[:0] = property_nodes + edges[:0] = property_edges - for node in property_nodes: - if str(node.id) not in added_nodes: - nodes.append(node) - added_nodes[str(node.id)] = True - - for edge in property_edges: - edge_key = str(edge[0]) + str(edge[1]) + edge[2] - - if str(edge_key) not in added_edges: - edges.append(edge) - added_edges[edge_key] = True - - for property_node in get_own_properties(property_nodes, property_edges): - edge_key = str(data_point.id) + str(property_node.id) + field_name - - if str(edge_key) not in added_edges: + for property_node in get_own_properties(property_nodes, property_edges): edges.append((data_point.id, property_node.id, field_name, { "source_node_id": data_point.id, "target_node_id": property_node.id, @@ -74,8 +49,7 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes "type": "list" }, })) - added_edges[edge_key] = True - continue + continue data_point_properties[field_name] = field_value diff --git a/cognee/shared/utils.py b/cognee/shared/utils.py index 42a95b88b..e32fad15e 100644 --- a/cognee/shared/utils.py +++ b/cognee/shared/utils.py @@ -115,7 +115,7 @@ def prepare_edges(graph, source, target, edge_key): source: str(edge[0]), target: str(edge[1]), edge_key: str(edge[2]), - } for edge in graph.edges(keys = True, data = True)] + } for edge in graph.edges] return pd.DataFrame(edge_list) diff --git a/cognee/tasks/graph/__init__.py b/cognee/tasks/graph/__init__.py index eafc12921..94dc82f20 100644 --- a/cognee/tasks/graph/__init__.py +++ b/cognee/tasks/graph/__init__.py @@ -1,3 +1,2 @@ from .extract_graph_from_data import extract_graph_from_data -from .extract_graph_from_code import extract_graph_from_code from .query_graph_connections import query_graph_connections diff --git a/cognee/tasks/graph/extract_graph_from_data.py b/cognee/tasks/graph/extract_graph_from_data.py index 9e6edcabd..36cc3e2fc 100644 --- a/cognee/tasks/graph/extract_graph_from_data.py +++ b/cognee/tasks/graph/extract_graph_from_data.py @@ -5,7 +5,7 @@ from 
cognee.infrastructure.databases.graph import get_graph_engine from cognee.modules.data.extraction.knowledge_graph import extract_content_graph from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.modules.engine.models import EntityType, Entity -from cognee.modules.engine.utils import generate_edge_name, generate_node_id, generate_node_name +from cognee.modules.engine.utils import generate_node_id, generate_node_name from cognee.tasks.storage import add_data_points async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]): @@ -95,7 +95,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: for edge in graph.edges: source_node_id = generate_node_id(edge.source_node_id) target_node_id = generate_node_id(edge.target_node_id) - relationship_name = generate_edge_name(edge.relationship_name) + relationship_name = generate_node_name(edge.relationship_name) edge_key = str(source_node_id) + str(target_node_id) + relationship_name @@ -105,7 +105,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: target_node_id, edge.relationship_name, dict( - relationship_name = generate_edge_name(edge.relationship_name), + relationship_name = generate_node_name(edge.relationship_name), source_node_id = source_node_id, target_node_id = target_node_id, ), diff --git a/cognee/tasks/graph/query_graph_connections.py b/cognee/tasks/graph/query_graph_connections.py index cd4d76a5e..5c538a994 100644 --- a/cognee/tasks/graph/query_graph_connections.py +++ b/cognee/tasks/graph/query_graph_connections.py @@ -27,8 +27,8 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s else: vector_engine = get_vector_engine() results = await asyncio.gather( - vector_engine.search("Entity_name", query_text = query, limit = 5), - vector_engine.search("EntityType_name", query_text = query, limit = 5), + vector_engine.search("Entity_text", query_text = query, limit = 5), + vector_engine.search("EntityType_text", query_text = query, limit = 5), ) results = [*results[0], *results[1]] relevant_results = [result for result in results if result.score < 0.5][:5] diff --git a/cognee/tasks/infer_data_ontology/models/models.py b/cognee/tasks/infer_data_ontology/models/models.py deleted file mode 100644 index 9c086b5c7..000000000 --- a/cognee/tasks/infer_data_ontology/models/models.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, Field - -class RelationshipModel(BaseModel): - type: str - source: str - target: str - -class NodeModel(BaseModel): - node_id: str - name: str - default_relationship: Optional[RelationshipModel] = None - children: List[Union[Dict[str, Any], "NodeModel"]] = Field(default_factory=list) - -NodeModel.model_rebuild() - - -class OntologyNode(BaseModel): - id: str = Field(..., description = "Unique identifier made from node name.") - name: str - description: str - -class OntologyEdge(BaseModel): - id: str - source_id: str - target_id: str - relationship_type: str - -class GraphOntology(BaseModel): - nodes: list[OntologyNode] - edges: list[OntologyEdge] diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index dc74d705d..a28335e24 100644 --- a/cognee/tasks/storage/index_data_points.py +++ b/cognee/tasks/storage/index_data_points.py @@ -47,7 +47,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) -> 
added_data_points[str(new_point.id)] = True data_points.append(new_point) - if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): + if isinstance(field_value, list) and isinstance(field_value[0], DataPoint): for field_value_item in field_value: new_data_points = get_data_points_from_model(field_value_item, added_data_points) @@ -56,8 +56,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) -> added_data_points[str(new_point.id)] = True data_points.append(new_point) - if (str(data_point.id) not in added_data_points): - data_points.append(data_point) + data_points.append(data_point) return data_points diff --git a/cognee/tasks/summarization/models/TextSummary.py b/cognee/tasks/summarization/models/TextSummary.py index c6a932b37..5e724cd63 100644 --- a/cognee/tasks/summarization/models/TextSummary.py +++ b/cognee/tasks/summarization/models/TextSummary.py @@ -4,8 +4,9 @@ from cognee.modules.data.processing.document_types import Document class TextSummary(DataPoint): text: str - made_from: DocumentChunk + chunk: DocumentChunk _metadata: dict = { "index_fields": ["text"], } + diff --git a/cognee/tasks/summarization/summarize_text.py b/cognee/tasks/summarization/summarize_text.py index 47d6946bb..a1abacccf 100644 --- a/cognee/tasks/summarization/summarize_text.py +++ b/cognee/tasks/summarization/summarize_text.py @@ -17,12 +17,12 @@ async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: summaries = [ TextSummary( - id = uuid5(chunk.id, "TextSummary"), - made_from = chunk, + id = uuid5(chunk.id, "summary"), + chunk = chunk, text = chunk_summaries[chunk_index].summary, ) for (chunk_index, chunk) in enumerate(data_chunks) ] - await add_data_points(summaries) + add_data_points(summaries) return data_chunks diff --git a/cognee/tests/test_library.py b/cognee/tests/test_library.py index 2e707b64c..d7e7e5fe8 100755 --- a/cognee/tests/test_library.py +++ b/cognee/tests/test_library.py @@ -32,8 +32,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "AI"))[0] - random_node_name = random_node.payload["text"] + random_node = (await vector_engine.search("Entity", "AI"))[0] + random_node_name = random_node.payload["name"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) assert len(search_results) != 0, "The search results list is empty." diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index 9cf1c53dd..2f9abf124 100644 --- a/cognee/tests/test_neo4j.py +++ b/cognee/tests/test_neo4j.py @@ -36,8 +36,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] - random_node_name = random_node.payload["text"] + random_node = (await vector_engine.search("Entity", "AI"))[0] + random_node_name = random_node.payload["name"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) assert len(search_results) != 0, "The search results list is empty." 
diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index ac4d08fbb..4adc6accc 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -65,8 +65,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] - random_node_name = random_node.payload["text"] + random_node = (await vector_engine.search("Entity", "AI"))[0] + random_node_name = random_node.payload["name"] search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name) assert len(search_results) != 0, "The search results list is empty." diff --git a/cognee/tests/test_qdrant.py b/cognee/tests/test_qdrant.py index 784b3f27a..84fac6a2e 100644 --- a/cognee/tests/test_qdrant.py +++ b/cognee/tests/test_qdrant.py @@ -37,8 +37,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] - random_node_name = random_node.payload["text"] + random_node = (await vector_engine.search("Entity", "AI"))[0] + random_node_name = random_node.payload["name"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) assert len(search_results) != 0, "The search results list is empty." diff --git a/cognee/tests/test_weaviate.py b/cognee/tests/test_weaviate.py index f788f9973..e943e1ec9 100644 --- a/cognee/tests/test_weaviate.py +++ b/cognee/tests/test_weaviate.py @@ -35,8 +35,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] - random_node_name = random_node.payload["text"] + random_node = (await vector_engine.search("Entity", "AI"))[0] + random_node_name = random_node.payload["name"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) assert len(search_results) != 0, "The search results list is empty." 
diff --git a/examples/python/GraphModel.py b/examples/python/GraphModel.py deleted file mode 100644 index 01251fc20..000000000 --- a/examples/python/GraphModel.py +++ /dev/null @@ -1,62 +0,0 @@ - -from typing import Optional -from uuid import UUID -from datetime import datetime -from pydantic import BaseModel - - -async def add_data_points(collection_name: str, data_points: list): - pass - - - -class Summary(BaseModel): - id: UUID - text: str - chunk: "Chunk" - created_at: datetime - updated_at: Optional[datetime] - - vector_index = ["text"] - -class Chunk(BaseModel): - id: UUID - text: str - summary: Summary - document: "Document" - created_at: datetime - updated_at: Optional[datetime] - word_count: int - chunk_index: int - cut_type: str - - vector_index = ["text"] - -class Document(BaseModel): - id: UUID - chunks: list[Chunk] - created_at: datetime - updated_at: Optional[datetime] - -class EntityType(BaseModel): - id: UUID - name: str - description: str - created_at: datetime - updated_at: Optional[datetime] - - vector_index = ["name"] - -class Entity(BaseModel): - id: UUID - name: str - type: EntityType - description: str - chunks: list[Chunk] - created_at: datetime - updated_at: Optional[datetime] - - vector_index = ["name"] - -class OntologyModel(BaseModel): - chunks: list[Chunk] diff --git a/notebooks/cognee_demo.ipynb b/notebooks/cognee_demo.ipynb index 06cd2a86a..396d7b980 100644 --- a/notebooks/cognee_demo.ipynb +++ b/notebooks/cognee_demo.ipynb @@ -265,7 +265,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "df16431d0f48b006", "metadata": { "ExecuteTime": { @@ -304,7 +304,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "9086abf3af077ab4", "metadata": { "ExecuteTime": { @@ -349,7 +349,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "id": "a9de0cc07f798b7f", "metadata": { "ExecuteTime": { @@ -393,7 +393,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "id": "185ff1c102d06111", "metadata": { "ExecuteTime": { @@ -437,7 +437,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "d55ce4c58f8efb67", "metadata": { "ExecuteTime": { @@ -479,7 +479,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "id": "ca4ecc32721ad332", "metadata": { "ExecuteTime": { @@ -572,7 +572,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "id": "9f1a1dbd", "metadata": {}, "outputs": [], @@ -758,7 +758,7 @@ "from cognee.infrastructure.databases.vector import get_vector_engine\n", "\n", "vector_engine = get_vector_engine()\n", - "results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n", + "results = await search(vector_engine, \"entities\", \"sarah.nguyen@example.com\")\n", "for result in results:\n", " print(result)" ] @@ -788,8 +788,8 @@ "source": [ "from cognee.api.v1.search import SearchType\n", "\n", - "node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n", - "node_name = node.payload[\"text\"]\n", + "node = (await vector_engine.search(\"entities\", \"sarah.nguyen@example.com\"))[0]\n", + "node_name = node.payload[\"name\"]\n", "\n", "search_results = await cognee.search(SearchType.SUMMARIES, query = node_name)\n", "print(\"\\n\\Extracted summaries are:\\n\")\n", @@ -881,7 +881,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": 
"3.9.6" } }, "nbformat": 4, diff --git a/poetry.lock b/poetry.lock index 12b1e59ba..e361398d5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3215,54 +3215,6 @@ docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "ibm-watsonx-ai (>=1.1.2)", "instructorembedding", "ollama", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"] tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19,<=1.3.0)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"] -[[package]] -name = "langchain-core" -version = "0.3.15" -description = "Building applications with LLMs through composability" -optional = false -python-versions = "<4.0,>=3.9" -files = [ - {file = "langchain_core-0.3.15-py3-none-any.whl", hash = "sha256:3d4ca6dbb8ed396a6ee061063832a2451b0ce8c345570f7b086ffa7288e4fa29"}, - {file = "langchain_core-0.3.15.tar.gz", hash = "sha256:b1a29787a4ffb7ec2103b4e97d435287201da7809b369740dd1e32f176325aba"}, -] - -[package.dependencies] -jsonpatch = ">=1.33,<2.0" -langsmith = ">=0.1.125,<0.2.0" -packaging = ">=23.2,<25" -pydantic = {version = ">=2.5.2,<3.0.0", markers = "python_full_version < \"3.12.4\""} -PyYAML = ">=5.3" -tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0" -typing-extensions = ">=4.7" - -[[package]] -name = "langchain-text-splitters" -version = "0.3.2" -description = "LangChain text splitting utilities" -optional = false -python-versions = "<4.0,>=3.9" -files = [ - {file = "langchain_text_splitters-0.3.2-py3-none-any.whl", hash = "sha256:0db28c53f41d1bc024cdb3b1646741f6d46d5371e90f31e7e7c9fbe75d01c726"}, - {file = "langchain_text_splitters-0.3.2.tar.gz", hash = "sha256:81e6515d9901d6dd8e35fb31ccd4f30f76d44b771890c789dc835ef9f16204df"}, -] - -[package.dependencies] -langchain-core = ">=0.3.15,<0.4.0" - -[[package]] -name = "langdetect" -version = "1.0.9" -description = "Language detection library ported from Google's language-detection." 
-optional = false -python-versions = "*" -files = [ - {file = "langdetect-1.0.9-py2-none-any.whl", hash = "sha256:7cbc0746252f19e76f77c0b1690aadf01963be835ef0cd4b56dddf2a8f1dfc2a"}, - {file = "langdetect-1.0.9.tar.gz", hash = "sha256:cbc1fef89f8d062739774bd51eda3da3274006b3661d199c2655f6b3f6d605a0"}, -] - -[package.dependencies] -six = "*" - [[package]] name = "langfuse" version = "2.53.9" @@ -5083,8 +5035,8 @@ argon2-cffi = {version = ">=23.1.0,<24", optional = true, markers = "extra == \" bcrypt = {version = ">=4.1.2,<5", optional = true, markers = "extra == \"bcrypt\""} [package.extras] -argon2 = ["argon2-cffi (>=23.1.0,<24)"] -bcrypt = ["bcrypt (>=4.1.2,<5)"] +argon2 = ["argon2-cffi (==23.1.0)"] +bcrypt = ["bcrypt (==4.1.2)"] [[package]] name = "pyarrow" @@ -7782,4 +7734,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.12" -content-hash = "fef56656ead761cab7d5c3d0bf1fa5a54608db73b14616d08e5fb152dba91236" +content-hash = "bb70798562fee44c6daa2f5c7fa4d17165fb76016618c1fc8fd0782c5aa4a6de" diff --git a/pyproject.toml b/pyproject.toml index 92d8f829b..28529b446 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ langsmith = "0.1.139" langdetect = "1.0.9" posthog = "^3.5.0" lancedb = "0.15.0" -litellm = "1.49.1" +litellm = "1.38.10" groq = "0.8.0" langfuse = "^2.32.0" pydantic-settings = "^2.2.1" diff --git a/tools/daily_twitter_stats.py b/tools/daily_twitter_stats.py new file mode 100644 index 000000000..d66f052d9 --- /dev/null +++ b/tools/daily_twitter_stats.py @@ -0,0 +1,66 @@ +import tweepy +import requests +import json +from datetime import datetime, timezone + +# Twitter API credentials from GitHub Secrets +API_KEY = '${{ secrets.TWITTER_API_KEY }}' +API_SECRET = '${{ secrets.TWITTER_API_SECRET }}' +ACCESS_TOKEN = '${{ secrets.TWITTER_ACCESS_TOKEN }}' +ACCESS_SECRET = '${{ secrets.TWITTER_ACCESS_SECRET }}' +USERNAME = '${{ secrets.TWITTER_USERNAME }}' +SEGMENT_WRITE_KEY = '${{ secrets.SEGMENT_WRITE_KEY }}' + +# Initialize Tweepy API +auth = tweepy.OAuthHandler(API_KEY, API_SECRET) +auth.set_access_token(ACCESS_TOKEN, ACCESS_SECRET) +twitter_api = tweepy.API(auth) + +# Segment endpoint +SEGMENT_ENDPOINT = 'https://api.segment.io/v1/track' + + +def get_follower_count(username): + try: + user = twitter_api.get_user(screen_name=username) + return user.followers_count + except tweepy.TweepError as e: + print(f'Error fetching follower count: {e}') + return None + + +def send_data_to_segment(username, follower_count): + current_time = datetime.now(timezone.utc).isoformat() + + data = { + 'userId': username, + 'event': 'Follower Count Update', + 'properties': { + 'username': username, + 'follower_count': follower_count, + 'timestamp': current_time + }, + 'timestamp': current_time + } + + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Basic {SEGMENT_WRITE_KEY.encode("utf-8").decode("utf-8")}' + } + + try: + response = requests.post(SEGMENT_ENDPOINT, headers=headers, data=json.dumps(data)) + + if response.status_code == 200: + print(f'Successfully sent data to Segment for {username}') + else: + print(f'Failed to send data to Segment. 
Status code: {response.status_code}, Response: {response.text}') + except requests.exceptions.RequestException as e: + print(f'Error sending data to Segment: {e}') + + +follower_count = get_follower_count(USERNAME) +if follower_count is not None: + send_data_to_segment(USERNAME, follower_count) +else: + print('Failed to retrieve follower count.') From 63900f6b0a687461e56a84daf84d868990646165 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Thu, 7 Nov 2024 11:36:31 +0100 Subject: [PATCH 03/20] fix: serialize UUID in pgvector data point payload --- .../databases/vector/pgvector/PGVectorAdapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index d9aecec90..a911d4785 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from cognee.infrastructure.engine import DataPoint -from .serialize_datetime import serialize_datetime +from .serialize_data import serialize_data from ..models.ScoredResult import ScoredResult from ..vector_db_interface import VectorDBInterface from ..embeddings.EmbeddingEngine import EmbeddingEngine @@ -111,7 +111,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): PGVectorDataPoint( id=data_point.id, vector=data_vectors[data_index], - payload=serialize_datetime(data_point.model_dump()), + payload=serialize_data(data_point.model_dump()), ) for (data_index, data_point) in enumerate(data_points) ] From 7ea5f638fe0478bf34f528481aec23f033ecc978 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Thu, 7 Nov 2024 15:38:03 +0100 Subject: [PATCH 04/20] fix: add summaries to the graph --- .../databases/graph/networkx/adapter.py | 26 ++------ .../hybrid/falkordb/FalkorDBAdapter.py | 30 +++++++++ .../vector/weaviate_db/WeaviateAdapter.py | 2 +- cognee/modules/chunking/TextChunker.py | 6 +- .../graph/utils/get_graph_from_model.py | 66 +++++++++++++------ cognee/tasks/graph/query_graph_connections.py | 4 +- cognee/tasks/storage/index_data_points.py | 3 +- .../tasks/summarization/models/TextSummary.py | 3 +- cognee/tasks/summarization/summarize_text.py | 7 +- cognee/tests/test_library.py | 4 +- cognee/tests/test_neo4j.py | 4 +- cognee/tests/test_pgvector.py | 4 +- cognee/tests/test_qdrant.py | 4 +- cognee/tests/test_weaviate.py | 4 +- 14 files changed, 106 insertions(+), 61 deletions(-) diff --git a/cognee/infrastructure/databases/graph/networkx/adapter.py b/cognee/infrastructure/databases/graph/networkx/adapter.py index b0a9e7a13..dcb05c2ed 100644 --- a/cognee/infrastructure/databases/graph/networkx/adapter.py +++ b/cognee/infrastructure/databases/graph/networkx/adapter.py @@ -247,27 +247,15 @@ class NetworkXAdapter(GraphDBInterface): async with aiofiles.open(file_path, "r") as file: graph_data = json.loads(await file.read()) for node in graph_data["nodes"]: - try: - node["id"] = UUID(node["id"]) - except: - pass - if "updated_at" in node: - node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") + node["id"] = UUID(node["id"]) + node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") for edge in graph_data["links"]: - try: - source_id = UUID(edge["source"]) - target_id = UUID(edge["target"]) - - edge["source"] = source_id - edge["target"] = target_id - 
edge["source_node_id"] = source_id - edge["target_node_id"] = target_id - except: - pass - - if "updated_at" in edge: - edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") + edge["source"] = UUID(edge["source"]) + edge["target"] = UUID(edge["target"]) + edge["source_node_id"] = UUID(edge["source_node_id"]) + edge["target_node_id"] = UUID(edge["target_node_id"]) + edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") self.graph = nx.readwrite.json_graph.node_link_graph(graph_data) else: diff --git a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py index effe9e682..ea5a75088 100644 --- a/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py +++ b/cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py @@ -1,6 +1,7 @@ import asyncio from textwrap import dedent from typing import Any +from uuid import UUID from falkordb import FalkorDB from cognee.infrastructure.engine import DataPoint @@ -161,6 +162,35 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface): async def extract_nodes(self, data_point_ids: list[str]): return await self.retrieve(data_point_ids) + async def get_connections(self, node_id: UUID) -> list: + predecessors_query = """ + MATCH (node)<-[relation]-(neighbour) + WHERE node.id = $node_id + RETURN neighbour, relation, node + """ + successors_query = """ + MATCH (node)-[relation]->(neighbour) + WHERE node.id = $node_id + RETURN node, relation, neighbour + """ + + predecessors, successors = await asyncio.gather( + self.query(predecessors_query, dict(node_id = node_id)), + self.query(successors_query, dict(node_id = node_id)), + ) + + connections = [] + + for neighbour in predecessors: + neighbour = neighbour["relation"] + connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2])) + + for neighbour in successors: + neighbour = neighbour["relation"] + connections.append((neighbour[0], { "relationship_name": neighbour[1] }, neighbour[2])) + + return connections + async def search( self, collection_name: str, diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py index 2e4e88323..dd7539118 100644 --- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py @@ -178,7 +178,7 @@ class WeaviateAdapter(VectorDBInterface): return [ ScoredResult( - id = UUID(result.id), + id = UUID(result.uuid), payload = result.properties, score = float(result.metadata.score) ) for result in search_result.objects diff --git a/cognee/modules/chunking/TextChunker.py b/cognee/modules/chunking/TextChunker.py index 4717d108d..714383804 100644 --- a/cognee/modules/chunking/TextChunker.py +++ b/cognee/modules/chunking/TextChunker.py @@ -29,7 +29,7 @@ class TextChunker(): else: if len(self.paragraph_chunks) == 0: yield DocumentChunk( - id = str(chunk_data["chunk_id"]), + id = chunk_data["chunk_id"], text = chunk_data["text"], word_count = chunk_data["word_count"], is_part_of = self.document, @@ -42,7 +42,7 @@ class TextChunker(): chunk_text = " ".join(chunk["text"] for chunk in self.paragraph_chunks) try: yield DocumentChunk( - id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")), + id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"), text = chunk_text, word_count = 
self.chunk_size, is_part_of = self.document, @@ -59,7 +59,7 @@ class TextChunker(): if len(self.paragraph_chunks) > 0: try: yield DocumentChunk( - id = str(uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}")), + id = uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"), text = " ".join(chunk["text"] for chunk in self.paragraph_chunks), word_count = self.chunk_size, is_part_of = self.document, diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index ef402e4d6..35e00fb5d 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ b/cognee/modules/graph/utils/get_graph_from_model.py @@ -1,9 +1,8 @@ from datetime import datetime, timezone from cognee.infrastructure.engine import DataPoint -from cognee.modules import data from cognee.modules.storage.utils import copy_model -def get_graph_from_model(data_point: DataPoint, include_root = True): +def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}): nodes = [] edges = [] @@ -17,29 +16,55 @@ def get_graph_from_model(data_point: DataPoint, include_root = True): if isinstance(field_value, DataPoint): excluded_properties.add(field_name) - property_nodes, property_edges = get_graph_from_model(field_value, True) - nodes[:0] = property_nodes - edges[:0] = property_edges + property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges) + + for node in property_nodes: + if str(node.id) not in added_nodes: + nodes.append(node) + added_nodes[str(node.id)] = True + + for edge in property_edges: + edge_key = str(edge[0]) + str(edge[1]) + edge[2] + + if str(edge_key) not in added_edges: + edges.append(edge) + added_edges[str(edge_key)] = True for property_node in get_own_properties(property_nodes, property_edges): - edges.append((data_point.id, property_node.id, field_name, { - "source_node_id": data_point.id, - "target_node_id": property_node.id, - "relationship_name": field_name, - "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), - })) + edge_key = str(data_point.id) + str(property_node.id) + field_name + + if str(edge_key) not in added_edges: + edges.append((data_point.id, property_node.id, field_name, { + "source_node_id": data_point.id, + "target_node_id": property_node.id, + "relationship_name": field_name, + "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), + })) + added_edges[str(edge_key)] = True continue - if isinstance(field_value, list): - if isinstance(field_value[0], DataPoint): - excluded_properties.add(field_name) + if isinstance(field_value, list) and isinstance(field_value[0], DataPoint): + excluded_properties.add(field_name) - for item in field_value: - property_nodes, property_edges = get_graph_from_model(item, True) - nodes[:0] = property_nodes - edges[:0] = property_edges + for item in field_value: + property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges) - for property_node in get_own_properties(property_nodes, property_edges): + for node in property_nodes: + if str(node.id) not in added_nodes: + nodes.append(node) + added_nodes[str(node.id)] = True + + for edge in property_edges: + edge_key = str(edge[0]) + str(edge[1]) + edge[2] + + if str(edge_key) not in added_edges: + edges.append(edge) + added_edges[edge_key] = True + + for property_node in get_own_properties(property_nodes, property_edges): + edge_key = str(data_point.id) + str(property_node.id) + field_name + + 
if str(edge_key) not in added_edges: edges.append((data_point.id, property_node.id, field_name, { "source_node_id": data_point.id, "target_node_id": property_node.id, @@ -49,7 +74,8 @@ def get_graph_from_model(data_point: DataPoint, include_root = True): "type": "list" }, })) - continue + added_edges[edge_key] = True + continue data_point_properties[field_name] = field_value diff --git a/cognee/tasks/graph/query_graph_connections.py b/cognee/tasks/graph/query_graph_connections.py index 5c538a994..cd4d76a5e 100644 --- a/cognee/tasks/graph/query_graph_connections.py +++ b/cognee/tasks/graph/query_graph_connections.py @@ -27,8 +27,8 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s else: vector_engine = get_vector_engine() results = await asyncio.gather( - vector_engine.search("Entity_text", query_text = query, limit = 5), - vector_engine.search("EntityType_text", query_text = query, limit = 5), + vector_engine.search("Entity_name", query_text = query, limit = 5), + vector_engine.search("EntityType_name", query_text = query, limit = 5), ) results = [*results[0], *results[1]] relevant_results = [result for result in results if result.score < 0.5][:5] diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index a28335e24..681fbaa1f 100644 --- a/cognee/tasks/storage/index_data_points.py +++ b/cognee/tasks/storage/index_data_points.py @@ -56,7 +56,8 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) -> added_data_points[str(new_point.id)] = True data_points.append(new_point) - data_points.append(data_point) + if (str(data_point.id) not in added_data_points): + data_points.append(data_point) return data_points diff --git a/cognee/tasks/summarization/models/TextSummary.py b/cognee/tasks/summarization/models/TextSummary.py index 5e724cd63..c6a932b37 100644 --- a/cognee/tasks/summarization/models/TextSummary.py +++ b/cognee/tasks/summarization/models/TextSummary.py @@ -4,9 +4,8 @@ from cognee.modules.data.processing.document_types import Document class TextSummary(DataPoint): text: str - chunk: DocumentChunk + made_from: DocumentChunk _metadata: dict = { "index_fields": ["text"], } - diff --git a/cognee/tasks/summarization/summarize_text.py b/cognee/tasks/summarization/summarize_text.py index a1abacccf..756f65e39 100644 --- a/cognee/tasks/summarization/summarize_text.py +++ b/cognee/tasks/summarization/summarize_text.py @@ -5,6 +5,7 @@ from pydantic import BaseModel from cognee.modules.data.extraction.extract_summary import extract_summary from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.tasks.storage import add_data_points +from cognee.tasks.storage.index_data_points import get_data_points_from_model from .models.TextSummary import TextSummary async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]): @@ -17,12 +18,12 @@ async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: summaries = [ TextSummary( - id = uuid5(chunk.id, "summary"), - chunk = chunk, + id = uuid5(chunk.id, "TextSummary"), + made_from = chunk, text = chunk_summaries[chunk_index].summary, ) for (chunk_index, chunk) in enumerate(data_chunks) ] - add_data_points(summaries) + await add_data_points(summaries) return data_chunks diff --git a/cognee/tests/test_library.py b/cognee/tests/test_library.py index d7e7e5fe8..2e707b64c 100755 --- a/cognee/tests/test_library.py +++ b/cognee/tests/test_library.py @@ -32,8 +32,8 @@ async def 
main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity", "AI"))[0] - random_node_name = random_node.payload["name"] + random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) assert len(search_results) != 0, "The search results list is empty." diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index 2f9abf124..0783e973a 100644 --- a/cognee/tests/test_neo4j.py +++ b/cognee/tests/test_neo4j.py @@ -36,8 +36,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity", "AI"))[0] - random_node_name = random_node.payload["name"] + random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) assert len(search_results) != 0, "The search results list is empty." diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index 4adc6accc..cea7c8f72 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -65,8 +65,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity", "AI"))[0] - random_node_name = random_node.payload["name"] + random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name) assert len(search_results) != 0, "The search results list is empty." diff --git a/cognee/tests/test_qdrant.py b/cognee/tests/test_qdrant.py index 84fac6a2e..faa2cbcf4 100644 --- a/cognee/tests/test_qdrant.py +++ b/cognee/tests/test_qdrant.py @@ -37,8 +37,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity", "AI"))[0] - random_node_name = random_node.payload["name"] + random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) assert len(search_results) != 0, "The search results list is empty." diff --git a/cognee/tests/test_weaviate.py b/cognee/tests/test_weaviate.py index e943e1ec9..121c1749e 100644 --- a/cognee/tests/test_weaviate.py +++ b/cognee/tests/test_weaviate.py @@ -35,8 +35,8 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity", "AI"))[0] - random_node_name = random_node.payload["name"] + random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) assert len(search_results) != 0, "The search results list is empty." 
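The node and edge deduplication introduced in get_graph_from_model above keys nodes by str(node.id) and edges by the concatenation of source id, target id, and relationship name. A minimal standalone sketch of that idea follows; the helper names are illustrative and not part of the patch:

    # Nodes are deduplicated by their stringified id, edges by a composite key
    # of source id + target id + relationship name, mirroring the patch above.
    added_nodes: dict[str, bool] = {}
    added_edges: dict[str, bool] = {}

    def add_node_once(nodes: list, node) -> None:
        if str(node.id) not in added_nodes:
            nodes.append(node)
            added_nodes[str(node.id)] = True

    def add_edge_once(edges: list, source_id, target_id, relationship_name: str, properties: dict) -> None:
        edge_key = str(source_id) + str(target_id) + relationship_name
        if edge_key not in added_edges:
            edges.append((source_id, target_id, relationship_name, properties))
            added_edges[edge_key] = True
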
From e1e5e7336a7f81ca20002dc20f73659670bd0c62 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Thu, 7 Nov 2024 15:41:11 +0100 Subject: [PATCH 05/20] fix: remove unused import --- cognee/tasks/summarization/summarize_text.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cognee/tasks/summarization/summarize_text.py b/cognee/tasks/summarization/summarize_text.py index 756f65e39..47d6946bb 100644 --- a/cognee/tasks/summarization/summarize_text.py +++ b/cognee/tasks/summarization/summarize_text.py @@ -5,7 +5,6 @@ from pydantic import BaseModel from cognee.modules.data.extraction.extract_summary import extract_summary from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.tasks.storage import add_data_points -from cognee.tasks.storage.index_data_points import get_data_points_from_model from .models.TextSummary import TextSummary async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]): From 68700f32c78c99610f6fa8a2a4bbe983f6cb905d Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Fri, 8 Nov 2024 15:31:02 +0100 Subject: [PATCH 06/20] fix: add code graph generation pipeline --- .../databases/graph/networkx/adapter.py | 30 ++++++++++++++----- .../graph/utils/get_graph_from_model.py | 2 +- cognee/shared/utils.py | 2 +- cognee/tasks/graph/__init__.py | 1 + cognee/tasks/storage/index_data_points.py | 2 +- 5 files changed, 27 insertions(+), 10 deletions(-) diff --git a/cognee/infrastructure/databases/graph/networkx/adapter.py b/cognee/infrastructure/databases/graph/networkx/adapter.py index dcb05c2ed..6c7abd498 100644 --- a/cognee/infrastructure/databases/graph/networkx/adapter.py +++ b/cognee/infrastructure/databases/graph/networkx/adapter.py @@ -30,6 +30,10 @@ class NetworkXAdapter(GraphDBInterface): def __init__(self, filename = "cognee_graph.pkl"): self.filename = filename + async def get_graph_data(self): + await self.load_graph_from_file() + return (list(self.graph.nodes(data = True)), list(self.graph.edges(data = True, keys = True))) + async def query(self, query: str, params: dict): pass @@ -247,15 +251,27 @@ class NetworkXAdapter(GraphDBInterface): async with aiofiles.open(file_path, "r") as file: graph_data = json.loads(await file.read()) for node in graph_data["nodes"]: - node["id"] = UUID(node["id"]) - node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") + try: + node["id"] = UUID(node["id"]) + except: + pass + if "updated_at" in node: + node["updated_at"] = datetime.strptime(node["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") for edge in graph_data["links"]: - edge["source"] = UUID(edge["source"]) - edge["target"] = UUID(edge["target"]) - edge["source_node_id"] = UUID(edge["source_node_id"]) - edge["target_node_id"] = UUID(edge["target_node_id"]) - edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") + try: + source_id = UUID(edge["source"]) + target_id = UUID(edge["target"]) + + edge["source"] = source_id + edge["target"] = target_id + edge["source_node_id"] = source_id + edge["target_node_id"] = target_id + except: + pass + + if "updated_at" in node: + edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") self.graph = nx.readwrite.json_graph.node_link_graph(graph_data) else: diff --git a/cognee/modules/graph/utils/get_graph_from_model.py b/cognee/modules/graph/utils/get_graph_from_model.py index 35e00fb5d..29137ddc7 100644 --- a/cognee/modules/graph/utils/get_graph_from_model.py +++ 
b/cognee/modules/graph/utils/get_graph_from_model.py @@ -43,7 +43,7 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes added_edges[str(edge_key)] = True continue - if isinstance(field_value, list) and isinstance(field_value[0], DataPoint): + if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): excluded_properties.add(field_name) for item in field_value: diff --git a/cognee/shared/utils.py b/cognee/shared/utils.py index e32fad15e..42a95b88b 100644 --- a/cognee/shared/utils.py +++ b/cognee/shared/utils.py @@ -115,7 +115,7 @@ def prepare_edges(graph, source, target, edge_key): source: str(edge[0]), target: str(edge[1]), edge_key: str(edge[2]), - } for edge in graph.edges] + } for edge in graph.edges(keys = True, data = True)] return pd.DataFrame(edge_list) diff --git a/cognee/tasks/graph/__init__.py b/cognee/tasks/graph/__init__.py index 94dc82f20..eafc12921 100644 --- a/cognee/tasks/graph/__init__.py +++ b/cognee/tasks/graph/__init__.py @@ -1,2 +1,3 @@ from .extract_graph_from_data import extract_graph_from_data +from .extract_graph_from_code import extract_graph_from_code from .query_graph_connections import query_graph_connections diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index 681fbaa1f..dc74d705d 100644 --- a/cognee/tasks/storage/index_data_points.py +++ b/cognee/tasks/storage/index_data_points.py @@ -47,7 +47,7 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) -> added_data_points[str(new_point.id)] = True data_points.append(new_point) - if isinstance(field_value, list) and isinstance(field_value[0], DataPoint): + if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): for field_value_item in field_value: new_data_points = get_data_points_from_model(field_value_item, added_data_points) From 0b3a94e90bff3b5b279885ae8578b8b9976b95f4 Mon Sep 17 00:00:00 2001 From: hajdul88 <52442977+hajdul88@users.noreply.github.com> Date: Thu, 7 Nov 2024 16:19:38 +0100 Subject: [PATCH 07/20] fix: resolves pg asyncpg UUID to UUID --- .../infrastructure/databases/vector/pgvector/PGVectorAdapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index a911d4785..025d361bd 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -209,7 +209,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): # Create and return ScoredResult objects return [ ScoredResult( - id = UUID(row.id), + id = UUID(str(row.id)), payload = row.payload, score = row.similarity ) for row in vector_list From 9fe1b6c5faafc58163e8aaf91a0f2a5b53ef57fc Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Mon, 11 Nov 2024 13:03:50 +0100 Subject: [PATCH 08/20] Add code_graph_demo notebook --- .gitignore | 1 + .../databases/graph/networkx/adapter.py | 2 +- notebooks/cognee_code_graph_demo.ipynb | 138 ++++++++++++++++++ 3 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 notebooks/cognee_code_graph_demo.ipynb diff --git a/.gitignore b/.gitignore index f447655cf..1c75b636c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ .prod.env cognee/.data/ +*.lance/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git 
a/cognee/infrastructure/databases/graph/networkx/adapter.py b/cognee/infrastructure/databases/graph/networkx/adapter.py index 6c7abd498..65aeea289 100644 --- a/cognee/infrastructure/databases/graph/networkx/adapter.py +++ b/cognee/infrastructure/databases/graph/networkx/adapter.py @@ -270,7 +270,7 @@ class NetworkXAdapter(GraphDBInterface): except: pass - if "updated_at" in node: + if "updated_at" in edge: edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") self.graph = nx.readwrite.json_graph.node_link_graph(graph_data) diff --git a/notebooks/cognee_code_graph_demo.ipynb b/notebooks/cognee_code_graph_demo.ipynb new file mode 100644 index 000000000..5e21e9dad --- /dev/null +++ b/notebooks/cognee_code_graph_demo.ipynb @@ -0,0 +1,138 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ['GRAPHISTRY_USERNAME'] = input(\"Please enter your graphistry username\")\n", + "os.environ['GRAPHISTRY_PASSWORD'] = input(\"Please enter your graphistry password\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from cognee.modules.users.methods import get_default_user\n", + "\n", + "from cognee.modules.data.methods import get_datasets\n", + "from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n", + "from cognee.modules.data.models import Data\n", + "\n", + "from cognee.modules.pipelines.tasks.Task import Task\n", + "from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents\n", + "from cognee.tasks.graph import extract_graph_from_code\n", + "from cognee.tasks.storage import add_data_points\n", + "from cognee.shared.SourceCodeGraph import SourceCodeGraph\n", + "\n", + "from cognee.modules.pipelines import run_tasks\n", + "\n", + "from cognee.shared.utils import render_graph\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "user = await get_default_user()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "existing_datasets = await get_datasets(user.id)\n", + "\n", + "datasets = {}\n", + "for dataset in existing_datasets:\n", + " dataset_name = dataset.name.replace(\".\", \"_\").replace(\" \", \"_\")\n", + " data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)\n", + " datasets[dataset_name] = data_documents\n", + "print(datasets.keys())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tasks = [\n", + " Task(classify_documents),\n", + " Task(check_permissions_on_documents, user = user, permissions = [\"write\"]),\n", + " Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n", + " Task(add_data_points, task_config = { \"batch_size\": 10 }),\n", + " Task(extract_graph_from_code, graph_model = SourceCodeGraph, task_config = { \"batch_size\": 10 }), # Generate knowledge graphs from the document chunks.\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def run_codegraph_pipeline(tasks, data_documents):\n", + " pipeline = run_tasks(tasks, data_documents, \"code_graph_pipeline\")\n", + " results = []\n", + " async for result in pipeline:\n", + " results.append(result)\n", + " return(results)" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results = await run_codegraph_pipeline(tasks, datasets[\"main_dataset\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "await render_graph(None, include_nodes = True, include_labels = True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cognee", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From b1b6b79ca4d5d77408b3f62587bf17d2ce3ee5c9 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 11 Nov 2024 14:38:59 +0100 Subject: [PATCH 09/20] fix: convert qdrant search results to ScoredPoint --- .../infrastructure/databases/vector/qdrant/QDrantAdapter.py | 1 + cognee/modules/engine/utils/__init__.py | 1 + cognee/modules/engine/utils/generate_node_name.py | 2 +- cognee/tasks/graph/extract_graph_from_data.py | 6 +++--- cognee/tests/test_qdrant.py | 2 +- cognee/tests/test_weaviate.py | 2 +- 6 files changed, 8 insertions(+), 6 deletions(-) diff --git a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py index 436861a45..1efcd47b3 100644 --- a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py +++ b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py @@ -3,6 +3,7 @@ from uuid import UUID from typing import List, Dict, Optional from qdrant_client import AsyncQdrantClient, models +from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult from cognee.infrastructure.engine import DataPoint from ..vector_db_interface import VectorDBInterface from ..embeddings.EmbeddingEngine import EmbeddingEngine diff --git a/cognee/modules/engine/utils/__init__.py b/cognee/modules/engine/utils/__init__.py index 9cc2bc573..4d4ab02e7 100644 --- a/cognee/modules/engine/utils/__init__.py +++ b/cognee/modules/engine/utils/__init__.py @@ -1,2 +1,3 @@ from .generate_node_id import generate_node_id from .generate_node_name import generate_node_name +from .generate_edge_name import generate_edge_name diff --git a/cognee/modules/engine/utils/generate_node_name.py b/cognee/modules/engine/utils/generate_node_name.py index 84b266198..a2871875b 100644 --- a/cognee/modules/engine/utils/generate_node_name.py +++ b/cognee/modules/engine/utils/generate_node_name.py @@ -1,2 +1,2 @@ def generate_node_name(name: str) -> str: - return name.lower().replace(" ", "_").replace("'", "") + return name.lower().replace("'", "") diff --git a/cognee/tasks/graph/extract_graph_from_data.py b/cognee/tasks/graph/extract_graph_from_data.py index 36cc3e2fc..9e6edcabd 100644 --- a/cognee/tasks/graph/extract_graph_from_data.py +++ b/cognee/tasks/graph/extract_graph_from_data.py @@ -5,7 +5,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine from cognee.modules.data.extraction.knowledge_graph import extract_content_graph from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.modules.engine.models import EntityType, Entity -from cognee.modules.engine.utils import generate_node_id, 
generate_node_name +from cognee.modules.engine.utils import generate_edge_name, generate_node_id, generate_node_name from cognee.tasks.storage import add_data_points async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]): @@ -95,7 +95,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: for edge in graph.edges: source_node_id = generate_node_id(edge.source_node_id) target_node_id = generate_node_id(edge.target_node_id) - relationship_name = generate_node_name(edge.relationship_name) + relationship_name = generate_edge_name(edge.relationship_name) edge_key = str(source_node_id) + str(target_node_id) + relationship_name @@ -105,7 +105,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: target_node_id, edge.relationship_name, dict( - relationship_name = generate_node_name(edge.relationship_name), + relationship_name = generate_edge_name(edge.relationship_name), source_node_id = source_node_id, target_node_id = target_node_id, ), diff --git a/cognee/tests/test_qdrant.py b/cognee/tests/test_qdrant.py index faa2cbcf4..784b3f27a 100644 --- a/cognee/tests/test_qdrant.py +++ b/cognee/tests/test_qdrant.py @@ -37,7 +37,7 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) diff --git a/cognee/tests/test_weaviate.py b/cognee/tests/test_weaviate.py index 121c1749e..3f853f63e 100644 --- a/cognee/tests/test_weaviate.py +++ b/cognee/tests/test_weaviate.py @@ -35,7 +35,7 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node = (await vector_engine.search("Entity_name", "quantum computer"))[0] random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) From eb5f30fcd1adad2b35a5ac433bdc205f03f56a75 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 11 Nov 2024 15:56:09 +0100 Subject: [PATCH 10/20] fix: fix single data point addition to weaiate --- .../vector/weaviate_db/WeaviateAdapter.py | 17 +++++++++-------- cognee/tests/test_weaviate.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py index dd7539118..4ebae3b29 100644 --- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py @@ -11,7 +11,6 @@ from ..embeddings.EmbeddingEngine import EmbeddingEngine logger = logging.getLogger("WeaviateAdapter") class IndexSchema(DataPoint): - uuid: str text: str _metadata: dict = { @@ -89,8 +88,10 @@ class WeaviateAdapter(VectorDBInterface): def convert_to_weaviate_data_points(data_point: DataPoint): vector = data_vectors[data_points.index(data_point)] properties = data_point.model_dump() - properties["uuid"] = properties["id"] - del properties["id"] + + if "id" in properties: + properties["uuid"] = str(data_point.id) + del properties["id"] return DataObject( uuid = data_point.id, @@ 
-114,7 +115,7 @@ class WeaviateAdapter(VectorDBInterface): ) else: data_point: DataObject = data_points[0] - return collection.data.update( + return collection.data.insert( uuid = data_point.uuid, vector = data_point.vector, properties = data_point.properties, @@ -130,8 +131,8 @@ class WeaviateAdapter(VectorDBInterface): async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]): await self.create_data_points(f"{index_name}_{index_property_name}", [ IndexSchema( - uuid = str(data_point.id), - text = getattr(data_point, data_point._metadata["index_fields"][0]), + id = data_point.id, + text = data_point.get_embeddable_data(), ) for data_point in data_points ]) @@ -178,9 +179,9 @@ class WeaviateAdapter(VectorDBInterface): return [ ScoredResult( - id = UUID(result.uuid), + id = UUID(str(result.uuid)), payload = result.properties, - score = float(result.metadata.score) + score = 1 - float(result.metadata.score) ) for result in search_result.objects ] diff --git a/cognee/tests/test_weaviate.py b/cognee/tests/test_weaviate.py index 3f853f63e..f788f9973 100644 --- a/cognee/tests/test_weaviate.py +++ b/cognee/tests/test_weaviate.py @@ -35,7 +35,7 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "quantum computer"))[0] + random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) From c8c2d45cb1312ad8037e1194b298ecd3a793e8c2 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 11 Nov 2024 15:56:30 +0100 Subject: [PATCH 11/20] fix: convert UUID to str for neo4j query --- cognee/tests/test_neo4j.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py index 0783e973a..9cf1c53dd 100644 --- a/cognee/tests/test_neo4j.py +++ b/cognee/tests/test_neo4j.py @@ -36,7 +36,7 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name) From 40bb4bc37fab3c121ee85ea1e2d4b4c4079450c7 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 11 Nov 2024 16:23:53 +0100 Subject: [PATCH 12/20] fix: change weaviate batch update to use dynamic batch --- .../databases/graph/neo4j_driver/adapter.py | 34 +++++++++---------- .../vector/weaviate_db/WeaviateAdapter.py | 2 +- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 7165aa29b..26bbb5819 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -63,11 +63,10 @@ class Neo4jAdapter(GraphDBInterface): async def add_node(self, node: DataPoint): serialized_properties = self.serialize_properties(node.model_dump()) - query = """MERGE (node {id: $node_id}) - ON CREATE SET node += $properties - ON MATCH SET node += $properties - ON MATCH SET node.updated_at = timestamp() - RETURN ID(node) AS internal_id, 
node.id AS nodeId""" + query = dedent("""MERGE (node {id: $node_id}) + ON CREATE SET node += $properties, node.updated_at = timestamp() + ON MATCH SET node += $properties, node.updated_at = timestamp() + RETURN ID(node) AS internal_id, node.id AS nodeId""") params = { "node_id": str(node.id), @@ -80,9 +79,8 @@ class Neo4jAdapter(GraphDBInterface): query = """ UNWIND $nodes AS node MERGE (n {id: node.node_id}) - ON CREATE SET n += node.properties - ON MATCH SET n += node.properties - ON MATCH SET n.updated_at = timestamp() + ON CREATE SET n += node.properties, n.updated_at = timestamp() + ON MATCH SET n += node.properties, n.updated_at = timestamp() WITH n, node.node_id AS label CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId @@ -137,8 +135,9 @@ class Neo4jAdapter(GraphDBInterface): return await self.query(query, params) async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool: - query = f""" - MATCH (from_node:`{str(from_node)}`)-[relationship:`{edge_label}`]->(to_node:`{str(to_node)}`) + query = """ + MATCH (from_node)-[relationship]->(to_node) + WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label RETURN COUNT(relationship) > 0 AS edge_exists """ @@ -178,17 +177,18 @@ class Neo4jAdapter(GraphDBInterface): async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}): serialized_properties = self.serialize_properties(edge_properties) - query = f"""MATCH (from_node:`{str(from_node)}` - {{id: $from_node}}), - (to_node:`{str(to_node)}` {{id: $to_node}}) - MERGE (from_node)-[r:`{relationship_name}`]->(to_node) - ON CREATE SET r += $properties, r.updated_at = timestamp() - ON MATCH SET r += $properties, r.updated_at = timestamp() - RETURN r""" + query = dedent("""MATCH (from_node {id: $from_node}), + (to_node {id: $to_node}) + MERGE (from_node)-[r]->(to_node) + ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name + ON MATCH SET r += $properties, r.updated_at = timestamp() + RETURN r + """) params = { "from_node": str(from_node), "to_node": str(to_node), + "relationship_name": relationship_name, "properties": serialized_properties } diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py index 4ebae3b29..be356740f 100644 --- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py @@ -115,7 +115,7 @@ class WeaviateAdapter(VectorDBInterface): ) else: data_point: DataObject = data_points[0] - return collection.data.insert( + return collection.data.update( uuid = data_point.uuid, vector = data_point.vector, properties = data_point.properties, From bc17759c04b66ab8fded9f6c75881b5976a49690 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 11 Nov 2024 17:29:13 +0100 Subject: [PATCH 13/20] fix: unwrap connections in PGVectorAdapter --- .../vector/pgvector/PGVectorAdapter.py | 29 ++++++------------- cognee/tests/test_pgvector.py | 2 +- pyproject.toml | 8 ++--- 3 files changed, 14 insertions(+), 25 deletions(-) diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 025d361bd..84a32e3e2 100644 --- 
a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -79,15 +79,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): async def create_data_points( self, collection_name: str, data_points: List[DataPoint] ): - async with self.get_async_session() as session: - if not await self.has_collection(collection_name): - await self.create_collection( - collection_name=collection_name, - payload_schema=type(data_points[0]), - ) - - data_vectors = await self.embed_data( - [data_point.get_embeddable_data() for data_point in data_points] + if not await self.has_collection(collection_name): + await self.create_collection( + collection_name = collection_name, + payload_schema = type(data_points[0]), ) data_vectors = await self.embed_data( @@ -107,14 +102,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): payload = Column(JSON) vector = Column(Vector(vector_size)) - pgvector_data_points = [ - PGVectorDataPoint( - id=data_point.id, - vector=data_vectors[data_index], - payload=serialize_data(data_point.model_dump()), - ) - for (data_index, data_point) in enumerate(data_points) - ] + def __init__(self, id, payload, vector): + self.id = id + self.payload = payload + self.vector = vector pgvector_data_points = [ PGVectorDataPoint( @@ -136,7 +127,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): await self.create_data_points(f"{index_name}_{index_property_name}", [ IndexSchema( id = data_point.id, - text = getattr(data_point, data_point._metadata["index_fields"][0]), + text = data_point.get_embeddable_data(), ) for data_point in data_points ]) @@ -188,8 +179,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): # Get PGVectorDataPoint Table from database PGVectorDataPoint = await self.get_table(collection_name) - closest_items = [] - # Use async session to connect to the database async with self.get_async_session() as session: # Find closest vectors to query_vector diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py index cea7c8f72..ac4d08fbb 100644 --- a/cognee/tests/test_pgvector.py +++ b/cognee/tests/test_pgvector.py @@ -65,7 +65,7 @@ async def main(): from cognee.infrastructure.databases.vector import get_vector_engine vector_engine = get_vector_engine() - random_node = (await vector_engine.search("Entity_name", "AI"))[0] + random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0] random_node_name = random_node.payload["text"] search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name) diff --git a/pyproject.toml b/pyproject.toml index 28529b446..c7363d4a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,10 +67,6 @@ anthropic = "^0.26.1" sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"} fastapi-users = {version = "*", extras = ["sqlalchemy"]} alembic = "^1.13.3" -asyncpg = "^0.29.0" -pgvector = "^0.3.5" -psycopg2 = {version = "^2.9.10", optional = true} -falkordb = "^1.0.9" [tool.poetry.extras] filesystem = ["s3fs", "botocore"] @@ -81,6 +77,10 @@ neo4j = ["neo4j"] postgres = ["psycopg2", "pgvector", "asyncpg"] notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"] +[tool.poetry.group.postgres.dependencies] +asyncpg = "^0.29.0" +pgvector = "^0.3.5" +psycopg2 = "^2.9.10" [tool.poetry.group.dev.dependencies] pytest = "^7.4.0" From 88e226d8c76c262526b5636d5d812cc95a345f90 Mon Sep 17 00:00:00 2001 From: Boris 
Arzentar Date: Mon, 11 Nov 2024 17:32:06 +0100 Subject: [PATCH 14/20] fix: update poetry.lock --- poetry.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index e361398d5..9200748b8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7727,11 +7727,11 @@ cli = [] filesystem = ["botocore"] neo4j = ["neo4j"] notebook = [] -postgres = ["asyncpg", "pgvector", "psycopg2"] +postgres = [] qdrant = ["qdrant-client"] weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.12" -content-hash = "bb70798562fee44c6daa2f5c7fa4d17165fb76016618c1fc8fd0782c5aa4a6de" +content-hash = "7c305c381d9327bd55e658cc955a6335411d85fc3e11f2f3dcebfdc5e3b70da0" From 27057d3a293a5d114b91c4a8073941779196e7fb Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 11 Nov 2024 17:38:33 +0100 Subject: [PATCH 15/20] fix: add postgres extras to dependencies --- poetry.lock | 4 ++-- pyproject.toml | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/poetry.lock b/poetry.lock index 9200748b8..59b6caecc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7727,11 +7727,11 @@ cli = [] filesystem = ["botocore"] neo4j = ["neo4j"] notebook = [] -postgres = [] +postgres = ["asyncpg", "pgvector", "psycopg2"] qdrant = ["qdrant-client"] weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.12" -content-hash = "7c305c381d9327bd55e658cc955a6335411d85fc3e11f2f3dcebfdc5e3b70da0" +content-hash = "fb09733ff7a70fb91c5f72ff0c8a8137b857557930a7aa025aad3154de4d8ceb" diff --git a/pyproject.toml b/pyproject.toml index c7363d4a1..ccc75cead 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,9 @@ anthropic = "^0.26.1" sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"} fastapi-users = {version = "*", extras = ["sqlalchemy"]} alembic = "^1.13.3" +asyncpg = "^0.29.0" +pgvector = "^0.3.5" +psycopg2 = "^2.9.10" [tool.poetry.extras] filesystem = ["s3fs", "botocore"] @@ -77,10 +80,6 @@ neo4j = ["neo4j"] postgres = ["psycopg2", "pgvector", "asyncpg"] notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"] -[tool.poetry.group.postgres.dependencies] -asyncpg = "^0.29.0" -pgvector = "^0.3.5" -psycopg2 = "^2.9.10" [tool.poetry.group.dev.dependencies] pytest = "^7.4.0" From c0d1aa12160ed4f6402d8c29766c22135c30c832 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 11 Nov 2024 17:54:00 +0100 Subject: [PATCH 16/20] fix: update entities collection name in cognee_demo notebook --- .../vector/pgvector/PGVectorAdapter.py | 23 ++++++++++++------- notebooks/cognee_demo.ipynb | 2 +- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 84a32e3e2..01691714b 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -179,6 +179,8 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): # Get PGVectorDataPoint Table from database PGVectorDataPoint = await self.get_table(collection_name) + closest_items = [] + # Use async session to connect to the database async with self.get_async_session() as session: # Find closest vectors to query_vector @@ -195,14 +197,19 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): vector_list = [] - # Create and return ScoredResult objects - return [ - 
ScoredResult( - id = UUID(str(row.id)), - payload = row.payload, - score = row.similarity - ) for row in vector_list - ] + # Extract distances and find min/max for normalization + for vector in closest_items: + # TODO: Add normalization of similarity score + vector_list.append(vector) + + # Create and return ScoredResult objects + return [ + ScoredResult( + id = UUID(str(row.id)), + payload = row.payload, + score = row.similarity + ) for row in vector_list + ] async def batch_search( self, diff --git a/notebooks/cognee_demo.ipynb b/notebooks/cognee_demo.ipynb index 396d7b980..5f4dfa227 100644 --- a/notebooks/cognee_demo.ipynb +++ b/notebooks/cognee_demo.ipynb @@ -758,7 +758,7 @@ "from cognee.infrastructure.databases.vector import get_vector_engine\n", "\n", "vector_engine = get_vector_engine()\n", - "results = await search(vector_engine, \"entities\", \"sarah.nguyen@example.com\")\n", + "results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n", "for result in results:\n", " print(result)" ] From da4d9c2c3b0d7d19d6944830d73c1ecd859f6636 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 11 Nov 2024 17:59:14 +0100 Subject: [PATCH 17/20] fix: change entity collection name --- notebooks/cognee_demo.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/cognee_demo.ipynb b/notebooks/cognee_demo.ipynb index 5f4dfa227..d26476a4a 100644 --- a/notebooks/cognee_demo.ipynb +++ b/notebooks/cognee_demo.ipynb @@ -788,7 +788,7 @@ "source": [ "from cognee.api.v1.search import SearchType\n", "\n", - "node = (await vector_engine.search(\"entities\", \"sarah.nguyen@example.com\"))[0]\n", + "node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n", "node_name = node.payload[\"name\"]\n", "\n", "search_results = await cognee.search(SearchType.SUMMARIES, query = node_name)\n", From 7c015e525d6266d247b8e6d6778c65aebef65bee Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Mon, 11 Nov 2024 18:07:53 +0100 Subject: [PATCH 18/20] fix: cognee_demo notebook search --- notebooks/cognee_demo.ipynb | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/notebooks/cognee_demo.ipynb b/notebooks/cognee_demo.ipynb index d26476a4a..06cd2a86a 100644 --- a/notebooks/cognee_demo.ipynb +++ b/notebooks/cognee_demo.ipynb @@ -265,7 +265,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "df16431d0f48b006", "metadata": { "ExecuteTime": { @@ -304,7 +304,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "9086abf3af077ab4", "metadata": { "ExecuteTime": { @@ -349,7 +349,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "a9de0cc07f798b7f", "metadata": { "ExecuteTime": { @@ -393,7 +393,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "185ff1c102d06111", "metadata": { "ExecuteTime": { @@ -437,7 +437,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "d55ce4c58f8efb67", "metadata": { "ExecuteTime": { @@ -479,7 +479,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "id": "ca4ecc32721ad332", "metadata": { "ExecuteTime": { @@ -572,7 +572,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "id": "9f1a1dbd", "metadata": {}, "outputs": [], @@ -789,7 +789,7 @@ "from cognee.api.v1.search import SearchType\n", "\n", "node = (await vector_engine.search(\"Entity_name\", 
\"sarah.nguyen@example.com\"))[0]\n", - "node_name = node.payload[\"name\"]\n", + "node_name = node.payload[\"text\"]\n", "\n", "search_results = await cognee.search(SearchType.SUMMARIES, query = node_name)\n", "print(\"\\n\\Extracted summaries are:\\n\")\n", @@ -881,7 +881,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.6" + "version": "3.11.8" } }, "nbformat": 4, From e0e93ae37955b10e4321a38c4c10be8450723ae3 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Tue, 12 Nov 2024 09:04:43 +0100 Subject: [PATCH 19/20] Clean up notebook merge request --- .gitignore | 2 +- poetry.lock | 90 ++++++++++++++++++++---------------- pyproject.toml | 2 +- tools/daily_twitter_stats.py | 66 -------------------------- 4 files changed, 53 insertions(+), 107 deletions(-) delete mode 100644 tools/daily_twitter_stats.py diff --git a/.gitignore b/.gitignore index 1c75b636c..d256013d2 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,7 @@ cognee/.data/ *.lance/ - +.DS_Store # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/poetry.lock b/poetry.lock index 59b6caecc..270e66027 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aiofiles" @@ -3215,6 +3215,54 @@ docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] embeddings = ["awscli (>=1.29.57)", "boto3 (>=1.28.57)", "botocore (>=1.31.57)", "cohere", "google-generativeai", "huggingface-hub", "ibm-watsonx-ai (>=1.1.2)", "instructorembedding", "ollama", "open-clip-torch", "openai (>=1.6.1)", "pillow", "sentence-transformers", "torch"] tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19,<=1.3.0)", "pytest", "pytest-asyncio", "pytest-mock", "pytz", "tantivy"] +[[package]] +name = "langchain-core" +version = "0.3.15" +description = "Building applications with LLMs through composability" +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "langchain_core-0.3.15-py3-none-any.whl", hash = "sha256:3d4ca6dbb8ed396a6ee061063832a2451b0ce8c345570f7b086ffa7288e4fa29"}, + {file = "langchain_core-0.3.15.tar.gz", hash = "sha256:b1a29787a4ffb7ec2103b4e97d435287201da7809b369740dd1e32f176325aba"}, +] + +[package.dependencies] +jsonpatch = ">=1.33,<2.0" +langsmith = ">=0.1.125,<0.2.0" +packaging = ">=23.2,<25" +pydantic = {version = ">=2.5.2,<3.0.0", markers = "python_full_version < \"3.12.4\""} +PyYAML = ">=5.3" +tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0" +typing-extensions = ">=4.7" + +[[package]] +name = "langchain-text-splitters" +version = "0.3.2" +description = "LangChain text splitting utilities" +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "langchain_text_splitters-0.3.2-py3-none-any.whl", hash = "sha256:0db28c53f41d1bc024cdb3b1646741f6d46d5371e90f31e7e7c9fbe75d01c726"}, + {file = "langchain_text_splitters-0.3.2.tar.gz", hash = "sha256:81e6515d9901d6dd8e35fb31ccd4f30f76d44b771890c789dc835ef9f16204df"}, +] + +[package.dependencies] +langchain-core = ">=0.3.15,<0.4.0" + +[[package]] +name = "langdetect" +version = "1.0.9" +description = "Language detection library ported from Google's language-detection." 
+optional = false +python-versions = "*" +files = [ + {file = "langdetect-1.0.9-py2-none-any.whl", hash = "sha256:7cbc0746252f19e76f77c0b1690aadf01963be835ef0cd4b56dddf2a8f1dfc2a"}, + {file = "langdetect-1.0.9.tar.gz", hash = "sha256:cbc1fef89f8d062739774bd51eda3da3274006b3661d199c2655f6b3f6d605a0"}, +] + +[package.dependencies] +six = "*" + [[package]] name = "langfuse" version = "2.53.9" @@ -3681,24 +3729,6 @@ htmlmin2 = ">=0.1.13" jsmin = ">=3.0.1" mkdocs = ">=1.4.1" -[[package]] -name = "mkdocs-redirects" -version = "1.2.1" -description = "A MkDocs plugin for dynamic page redirects to prevent broken links." -optional = false -python-versions = ">=3.6" -files = [ - {file = "mkdocs-redirects-1.2.1.tar.gz", hash = "sha256:9420066d70e2a6bb357adf86e67023dcdca1857f97f07c7fe450f8f1fb42f861"}, -] - -[package.dependencies] -mkdocs = ">=1.1.1" - -[package.extras] -dev = ["autoflake", "black", "isort", "pytest", "twine (>=1.13.0)"] -release = ["twine (>=1.13.0)"] -test = ["autoflake", "black", "isort", "pytest"] - [[package]] name = "mkdocstrings" version = "0.26.2" @@ -5035,8 +5065,8 @@ argon2-cffi = {version = ">=23.1.0,<24", optional = true, markers = "extra == \" bcrypt = {version = ">=4.1.2,<5", optional = true, markers = "extra == \"bcrypt\""} [package.extras] -argon2 = ["argon2-cffi (==23.1.0)"] -bcrypt = ["bcrypt (==4.1.2)"] +argon2 = ["argon2-cffi (>=23.1.0,<24)"] +bcrypt = ["bcrypt (>=4.1.2,<5)"] [[package]] name = "pyarrow" @@ -5769,24 +5799,6 @@ async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\ hiredis = ["hiredis (>=3.0.0)"] ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"] -[[package]] -name = "redis" -version = "5.1.1" -description = "Python client for Redis database and key-value store" -optional = false -python-versions = ">=3.8" -files = [ - {file = "redis-5.1.1-py3-none-any.whl", hash = "sha256:f8ea06b7482a668c6475ae202ed8d9bcaa409f6e87fb77ed1043d912afd62e24"}, - {file = "redis-5.1.1.tar.gz", hash = "sha256:f6c997521fedbae53387307c5d0bf784d9acc28d9f1d058abeac566ec4dbed72"}, -] - -[package.dependencies] -async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} - -[package.extras] -hiredis = ["hiredis (>=3.0.0)"] -ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"] - [[package]] name = "referencing" version = "0.35.1" diff --git a/pyproject.toml b/pyproject.toml index ccc75cead..0bc3849b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ langsmith = "0.1.139" langdetect = "1.0.9" posthog = "^3.5.0" lancedb = "0.15.0" -litellm = "1.38.10" +litellm = "1.49.1" groq = "0.8.0" langfuse = "^2.32.0" pydantic-settings = "^2.2.1" diff --git a/tools/daily_twitter_stats.py b/tools/daily_twitter_stats.py deleted file mode 100644 index d66f052d9..000000000 --- a/tools/daily_twitter_stats.py +++ /dev/null @@ -1,66 +0,0 @@ -import tweepy -import requests -import json -from datetime import datetime, timezone - -# Twitter API credentials from GitHub Secrets -API_KEY = '${{ secrets.TWITTER_API_KEY }}' -API_SECRET = '${{ secrets.TWITTER_API_SECRET }}' -ACCESS_TOKEN = '${{ secrets.TWITTER_ACCESS_TOKEN }}' -ACCESS_SECRET = '${{ secrets.TWITTER_ACCESS_SECRET }}' -USERNAME = '${{ secrets.TWITTER_USERNAME }}' -SEGMENT_WRITE_KEY = '${{ secrets.SEGMENT_WRITE_KEY }}' - -# Initialize Tweepy API -auth = tweepy.OAuthHandler(API_KEY, API_SECRET) -auth.set_access_token(ACCESS_TOKEN, ACCESS_SECRET) -twitter_api = tweepy.API(auth) - -# Segment endpoint -SEGMENT_ENDPOINT = 
'https://api.segment.io/v1/track' - - -def get_follower_count(username): - try: - user = twitter_api.get_user(screen_name=username) - return user.followers_count - except tweepy.TweepError as e: - print(f'Error fetching follower count: {e}') - return None - - -def send_data_to_segment(username, follower_count): - current_time = datetime.now(timezone.utc).isoformat() - - data = { - 'userId': username, - 'event': 'Follower Count Update', - 'properties': { - 'username': username, - 'follower_count': follower_count, - 'timestamp': current_time - }, - 'timestamp': current_time - } - - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Basic {SEGMENT_WRITE_KEY.encode("utf-8").decode("utf-8")}' - } - - try: - response = requests.post(SEGMENT_ENDPOINT, headers=headers, data=json.dumps(data)) - - if response.status_code == 200: - print(f'Successfully sent data to Segment for {username}') - else: - print(f'Failed to send data to Segment. Status code: {response.status_code}, Response: {response.text}') - except requests.exceptions.RequestException as e: - print(f'Error sending data to Segment: {e}') - - -follower_count = get_follower_count(USERNAME) -if follower_count is not None: - send_data_to_segment(USERNAME, follower_count) -else: - print('Failed to retrieve follower count.') From adaf69c1278cbdb58968b7ec4e8999215bf8c2f9 Mon Sep 17 00:00:00 2001 From: Leon Luithlen Date: Tue, 12 Nov 2024 09:05:51 +0100 Subject: [PATCH 20/20] Readd infer_data_ontology models --- .../models/__pycache__/models.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 cognee/tasks/infer_data_ontology/models/__pycache__/models.py diff --git a/cognee/tasks/infer_data_ontology/models/__pycache__/models.py b/cognee/tasks/infer_data_ontology/models/__pycache__/models.py new file mode 100644 index 000000000..5b1108e6a --- /dev/null +++ b/cognee/tasks/infer_data_ontology/models/__pycache__/models.py @@ -0,0 +1,31 @@ +from typing import Any, Dict, List, Optional, Union +from pydantic import BaseModel, Field + +class RelationshipModel(BaseModel): + type: str + source: str + target: str + +class NodeModel(BaseModel): + node_id: str + name: str + default_relationship: Optional[RelationshipModel] = None + children: List[Union[Dict[str, Any], "NodeModel"]] = Field(default_factory=list) + +NodeModel.model_rebuild() + + +class OntologyNode(BaseModel): + id: str = Field(..., description = "Unique identifier made from node name.") + name: str + description: str + +class OntologyEdge(BaseModel): + id: str + source_id: str + target_id: str + relationship_type: str + +class GraphOntology(BaseModel): + nodes: list[OntologyNode] + edges: list[OntologyEdge] \ No newline at end of file
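The re-added ontology models above are plain Pydantic models; a minimal usage sketch follows, with sample values and an assumed import path (both illustrative, not taken from the patch):

    # Assumed import path; the patch stores the file under models/__pycache__/, so the real path may differ.
    from cognee.tasks.infer_data_ontology.models.models import GraphOntology, OntologyEdge, OntologyNode

    # Sample values only, to show how the nodes and edges fit together.
    ontology = GraphOntology(
        nodes = [
            OntologyNode(id = "person", name = "Person", description = "A person mentioned in the data."),
            OntologyNode(id = "company", name = "Company", description = "An organization mentioned in the data."),
        ],
        edges = [
            OntologyEdge(
                id = "person_works_at_company",
                source_id = "person",
                target_id = "company",
                relationship_type = "works_at",
            ),
        ],
    )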