diff --git a/.github/workflows/test_neo4j.yml b/.github/workflows/test_neo4j.yml
index 55b0f4ee4..0b47a55fc 100644
--- a/.github/workflows/test_neo4j.yml
+++ b/.github/workflows/test_neo4j.yml
@@ -5,7 +5,7 @@ on:
   pull_request:
     branches:
       - main
-    types: [labeled]
+    types: [labeled, synchronize]
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -22,7 +22,7 @@ jobs:
   run_neo4j_integration_test:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
+    if: ${{ needs.get_docs_changes.outputs.changes_outside_docs == 'true' && github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     defaults:
diff --git a/.github/workflows/test_notebook.yml b/.github/workflows/test_notebook.yml
index 20f51a6e2..e5d10f0f5 100644
--- a/.github/workflows/test_notebook.yml
+++ b/.github/workflows/test_notebook.yml
@@ -5,7 +5,7 @@ on:
   pull_request:
     branches:
       - main
-    types: [labeled]
+    types: [labeled, synchronize]
 
 concurrency:
@@ -23,7 +23,7 @@ jobs:
   run_notebook_test:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
+    if: ${{ needs.get_docs_changes.outputs.changes_outside_docs == 'true' && github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     defaults:
       run:
diff --git a/.github/workflows/test_pgvector.yml b/.github/workflows/test_pgvector.yml
index c9dfc2c35..52df86c79 100644
--- a/.github/workflows/test_pgvector.yml
+++ b/.github/workflows/test_pgvector.yml
@@ -5,7 +5,7 @@ on:
   pull_request:
     branches:
       - main
-    types: [labeled]
+    types: [labeled, synchronize]
 
 concurrency:
@@ -23,7 +23,7 @@ jobs:
   run_pgvector_integration_test:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
+    if: ${{ needs.get_docs_changes.outputs.changes_outside_docs == 'true' && github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     defaults:
       run:
diff --git a/.github/workflows/test_python_3_10.yml b/.github/workflows/test_python_3_10.yml
index 0ee5bd2cd..5a7954033 100644
--- a/.github/workflows/test_python_3_10.yml
+++ b/.github/workflows/test_python_3_10.yml
@@ -1,10 +1,11 @@
 name: test | python 3.10
 
 on:
+  workflow_dispatch:
   pull_request:
     branches:
       - main
-  workflow_dispatch:
+    types: [labeled, synchronize]
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -21,7 +22,7 @@ jobs:
   run_common:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
+    if: ${{ needs.get_docs_changes.outputs.changes_outside_docs == 'true' && github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
diff --git a/.github/workflows/test_python_3_11.yml b/.github/workflows/test_python_3_11.yml
index 4327312ec..22cdad320 100644
--- a/.github/workflows/test_python_3_11.yml
+++ b/.github/workflows/test_python_3_11.yml
@@ -1,10 +1,11 @@
 name: test | python 3.11
 
 on:
+  workflow_dispatch:
   pull_request:
     branches:
       - main
-  workflow_dispatch:
+    types: [labeled, synchronize]
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -21,7 +22,7 @@ jobs:
   run_common:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
+    if: ${{ needs.get_docs_changes.outputs.changes_outside_docs == 'true' && github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
diff --git a/.github/workflows/test_python_3_9.yml b/.github/workflows/test_python_3_9.yml
index 154e0eab3..d6e7f8b97 100644
--- a/.github/workflows/test_python_3_9.yml
+++ b/.github/workflows/test_python_3_9.yml
@@ -1,10 +1,11 @@
 name: test | python 3.9
 
 on:
+  workflow_dispatch:
   pull_request:
     branches:
       - main
-  workflow_dispatch:
+    types: [labeled, synchronize]
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -21,7 +22,7 @@ jobs:
   run_common:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
+    if: ${{ needs.get_docs_changes.outputs.changes_outside_docs == 'true' && github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
diff --git a/.github/workflows/test_qdrant.yml b/.github/workflows/test_qdrant.yml
index 595325672..a6347bd0d 100644
--- a/.github/workflows/test_qdrant.yml
+++ b/.github/workflows/test_qdrant.yml
@@ -5,7 +5,7 @@ on:
   pull_request:
     branches:
       - main
-    types: [labeled]
+    types: [labeled, synchronize]
 
 concurrency:
@@ -23,7 +23,7 @@ jobs:
   run_qdrant_integration_test:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
+    if: ${{ needs.get_docs_changes.outputs.changes_outside_docs == 'true' && github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     defaults:
diff --git a/.github/workflows/test_weaviate.yml b/.github/workflows/test_weaviate.yml
index 9353d1747..490f9075a 100644
--- a/.github/workflows/test_weaviate.yml
+++ b/.github/workflows/test_weaviate.yml
@@ -5,7 +5,7 @@ on:
   pull_request:
     branches:
       - main
-    types: [labeled]
+    types: [labeled, synchronize]
 
 concurrency:
@@ -23,7 +23,7 @@ jobs:
   run_weaviate_integration_test:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
+    if: ${{ needs.get_docs_changes.outputs.changes_outside_docs == 'true' && github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     defaults:
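Taken together, the eight workflows now share one gate: skip docs-only PRs, and require the `run-checks` label event. Note that `github.event.label.name` is only populated on `labeled` deliveries, so plain `synchronize` pushes still fall through this condition. A minimal sketch of the predicate, for illustration only (not repo code):

```python
from typing import Optional

def should_run_checks(changes_outside_docs: bool, label_name: Optional[str]) -> bool:
    # Mirrors the workflow gate: the job runs only when the PR touches
    # files outside docs AND the 'run-checks' label event fired.
    return changes_outside_docs and label_name == "run-checks"

assert should_run_checks(True, "run-checks")       # labeled code change: runs
assert not should_run_checks(False, "run-checks")  # docs-only PR: skipped
assert not should_run_checks(True, None)           # unlabeled event: skipped
```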
""" - await cognee.add(text) # Add text to cognee - await cognee.cognify() # Use LLMs and cognee to create knowledge graph + # Add text to cognee + await cognee.add(text) - search_results = await cognee.search( # Search cognee for insights + # Use LLMs and cognee to create knowledge graph + await cognee.cognify() + + # Search cognee for insights + search_results = await cognee.search( SearchType.INSIGHTS, - {'query': 'Tell me about NLP'} + "Tell me about NLP", ) - for result_text in search_results: # Display results + # Display results + for result_text in search_results: print(result_text) + # natural_language_processing is_a field + # natural_language_processing is_subfield_of computer_science + # natural_language_processing is_subfield_of information_retrieval asyncio.run(main()) ``` diff --git a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py index 10b1d029f..26bbb5819 100644 --- a/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +++ b/cognee/infrastructure/databases/graph/neo4j_driver/adapter.py @@ -1,6 +1,7 @@ """ Neo4j Adapter for Graph Database""" import logging import asyncio +from textwrap import dedent from typing import Optional, Any, List, Dict from contextlib import asynccontextmanager from uuid import UUID @@ -18,7 +19,7 @@ class Neo4jAdapter(GraphDBInterface): graph_database_url: str, graph_database_username: str, graph_database_password: str, - driver: Optional[Any] = None + driver: Optional[Any] = None, ): self.driver = driver or AsyncGraphDatabase.driver( graph_database_url, @@ -26,6 +27,9 @@ class Neo4jAdapter(GraphDBInterface): max_connection_lifetime = 120 ) + async def close(self) -> None: + await self.driver.close() + @asynccontextmanager async def get_session(self) -> AsyncSession: async with self.driver.session() as session: @@ -59,11 +63,10 @@ class Neo4jAdapter(GraphDBInterface): async def add_node(self, node: DataPoint): serialized_properties = self.serialize_properties(node.model_dump()) - query = """MERGE (node {id: $node_id}) - ON CREATE SET node += $properties - ON MATCH SET node += $properties - ON MATCH SET node.updated_at = timestamp() - RETURN ID(node) AS internal_id, node.id AS nodeId""" + query = dedent("""MERGE (node {id: $node_id}) + ON CREATE SET node += $properties, node.updated_at = timestamp() + ON MATCH SET node += $properties, node.updated_at = timestamp() + RETURN ID(node) AS internal_id, node.id AS nodeId""") params = { "node_id": str(node.id), @@ -76,9 +79,8 @@ class Neo4jAdapter(GraphDBInterface): query = """ UNWIND $nodes AS node MERGE (n {id: node.node_id}) - ON CREATE SET n += node.properties - ON MATCH SET n += node.properties - ON MATCH SET n.updated_at = timestamp() + ON CREATE SET n += node.properties, n.updated_at = timestamp() + ON MATCH SET n += node.properties, n.updated_at = timestamp() WITH n, node.node_id AS label CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId @@ -133,12 +135,19 @@ class Neo4jAdapter(GraphDBInterface): return await self.query(query, params) async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool: - query = f""" - MATCH (from_node:`{str(from_node)}`)-[relationship:`{edge_label}`]->(to_node:`{str(to_node)}`) + query = """ + MATCH (from_node)-[relationship]->(to_node) + WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label RETURN COUNT(relationship) > 0 AS edge_exists 
""" - edge_exists = await self.query(query) + params = { + "from_node_id": str(from_node), + "to_node_id": str(to_node), + "edge_label": edge_label, + } + + edge_exists = await self.query(query, params) return edge_exists async def has_edges(self, edges): @@ -165,22 +174,21 @@ class Neo4jAdapter(GraphDBInterface): raise error - async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}): + async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}): serialized_properties = self.serialize_properties(edge_properties) - from_node = from_node.replace(":", "_") - to_node = to_node.replace(":", "_") - query = f"""MATCH (from_node:`{str(from_node)}` - {{id: $from_node}}), - (to_node:`{str(to_node)}` {{id: $to_node}}) - MERGE (from_node)-[r:`{relationship_name}`]->(to_node) - ON CREATE SET r += $properties, r.updated_at = timestamp() - ON MATCH SET r += $properties, r.updated_at = timestamp() - RETURN r""" + query = dedent("""MATCH (from_node {id: $from_node}), + (to_node {id: $to_node}) + MERGE (from_node)-[r]->(to_node) + ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name + ON MATCH SET r += $properties, r.updated_at = timestamp() + RETURN r + """) params = { "from_node": str(from_node), "to_node": str(to_node), + "relationship_name": relationship_name, "properties": serialized_properties } @@ -347,8 +355,8 @@ class Neo4jAdapter(GraphDBInterface): """ predecessors, successors = await asyncio.gather( - self.query(predecessors_query, dict(node_id = node_id)), - self.query(successors_query, dict(node_id = node_id)), + self.query(predecessors_query, dict(node_id = str(node_id))), + self.query(successors_query, dict(node_id = str(node_id))), ) connections = [] diff --git a/cognee/infrastructure/databases/graph/networkx/adapter.py b/cognee/infrastructure/databases/graph/networkx/adapter.py index 6c7abd498..65aeea289 100644 --- a/cognee/infrastructure/databases/graph/networkx/adapter.py +++ b/cognee/infrastructure/databases/graph/networkx/adapter.py @@ -270,7 +270,7 @@ class NetworkXAdapter(GraphDBInterface): except: pass - if "updated_at" in node: + if "updated_at" in edge: edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z") self.graph = nx.readwrite.json_graph.node_link_graph(graph_data) diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 96f026b4f..d883a29e7 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -112,18 +112,10 @@ class LanceDBAdapter(VectorDBInterface): for (data_point_index, data_point) in enumerate(data_points) ] - # TODO: This enables us to work with pydantic version but shouldn't - # stay like this, existing rows should be updated - - await collection.delete("id IS NOT NULL") - - original_size = await collection.count_rows() - await collection.add(lance_data_points) - new_size = await collection.count_rows() - - if new_size <= original_size: - raise ValueError( - "LanceDB create_datapoints error: data points did not get added.") + await collection.merge_insert("id") \ + .when_matched_update_all() \ + .when_not_matched_insert_all() \ + .execute(lance_data_points) async def retrieve(self, collection_name: str, data_point_ids: list[str]): diff --git 
diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
index 2e9a3764b..01691714b 100644
--- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
+++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
@@ -54,7 +54,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         vector_size = self.embedding_engine.get_vector_size()
 
         if not await self.has_collection(collection_name):
-
             class PGVectorDataPoint(Base):
                 __tablename__ = collection_name
                 __table_args__ = {"extend_existing": True}
@@ -80,44 +79,44 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
     async def create_data_points(
         self, collection_name: str, data_points: List[DataPoint]
     ):
-        async with self.get_async_session() as session:
-            if not await self.has_collection(collection_name):
-                await self.create_collection(
-                    collection_name=collection_name,
-                    payload_schema=type(data_points[0]),
-                )
-
-            data_vectors = await self.embed_data(
-                [data_point.get_embeddable_data() for data_point in data_points]
+        if not await self.has_collection(collection_name):
+            await self.create_collection(
+                collection_name = collection_name,
+                payload_schema = type(data_points[0]),
             )
 
-            vector_size = self.embedding_engine.get_vector_size()
+        data_vectors = await self.embed_data(
+            [data_point.get_embeddable_data() for data_point in data_points]
+        )
 
-            class PGVectorDataPoint(Base):
-                __tablename__ = collection_name
-                __table_args__ = {"extend_existing": True}
-                # PGVector requires one column to be the primary key
-                primary_key: Mapped[int] = mapped_column(
-                    primary_key=True, autoincrement=True
-                )
-                id: Mapped[type(data_points[0].id)]
-                payload = Column(JSON)
-                vector = Column(Vector(vector_size))
+        vector_size = self.embedding_engine.get_vector_size()
 
-            def __init__(self, id, payload, vector):
-                self.id = id
-                self.payload = payload
-                self.vector = vector
+        class PGVectorDataPoint(Base):
+            __tablename__ = collection_name
+            __table_args__ = {"extend_existing": True}
+            # PGVector requires one column to be the primary key
+            primary_key: Mapped[int] = mapped_column(
+                primary_key=True, autoincrement=True
+            )
+            id: Mapped[type(data_points[0].id)]
+            payload = Column(JSON)
+            vector = Column(Vector(vector_size))
 
-            pgvector_data_points = [
-                PGVectorDataPoint(
-                    id=data_point.id,
-                    vector=data_vectors[data_index],
-                    payload=serialize_data(data_point.model_dump()),
-                )
-                for (data_index, data_point) in enumerate(data_points)
-            ]
+            def __init__(self, id, payload, vector):
+                self.id = id
+                self.payload = payload
+                self.vector = vector
 
+        pgvector_data_points = [
+            PGVectorDataPoint(
+                id = data_point.id,
+                vector = data_vectors[data_index],
+                payload = serialize_data(data_point.model_dump()),
+            )
+            for (data_index, data_point) in enumerate(data_points)
+        ]
+
+        async with self.get_async_session() as session:
             session.add_all(pgvector_data_points)
             await session.commit()
@@ -128,7 +127,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         await self.create_data_points(f"{index_name}_{index_property_name}", [
             IndexSchema(
                 id = data_point.id,
-                text = getattr(data_point, data_point._metadata["index_fields"][0]),
+                text = data_point.get_embeddable_data(),
             ) for data_point in data_points
         ])
@@ -146,10 +145,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             raise ValueError(f"Table '{collection_name}' not found.")
 
     async def retrieve(self, collection_name: str, data_point_ids: List[str]):
-        async with self.get_async_session() as session:
-            # Get PGVectorDataPoint Table from database
-            PGVectorDataPoint = await self.get_table(collection_name)
+        # Get PGVectorDataPoint Table from database
+        PGVectorDataPoint = await self.get_table(collection_name)
 
+        async with self.get_async_session() as session:
             results = await session.execute(
                 select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
             )
@@ -177,11 +176,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         if query_text and not query_vector:
             query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
 
+        # Get PGVectorDataPoint Table from database
+        PGVectorDataPoint = await self.get_table(collection_name)
+
+        closest_items = []
+
+        # Use async session to connect to the database
         async with self.get_async_session() as session:
-            # Get PGVectorDataPoint Table from database
-            PGVectorDataPoint = await self.get_table(collection_name)
-
             # Find closest vectors to query_vector
             closest_items = await session.execute(
                 select(
@@ -194,20 +195,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                 .limit(limit)
             )
 
-        vector_list = []
-        # Extract distances and find min/max for normalization
-        for vector in closest_items:
-            # TODO: Add normalization of similarity score
-            vector_list.append(vector)
+        vector_list = []
 
-        # Create and return ScoredResult objects
-        return [
-            ScoredResult(
-                id = UUID(str(row.id)),
-                payload = row.payload,
-                score = row.similarity
-            ) for row in vector_list
-        ]
+        # Extract distances and find min/max for normalization
+        for vector in closest_items:
+            # TODO: Add normalization of similarity score
+            vector_list.append(vector)
+
+        # Create and return ScoredResult objects
+        return [
+            ScoredResult(
+                id = UUID(str(row.id)),
+                payload = row.payload,
+                score = row.similarity
+            ) for row in vector_list
+        ]
 
     async def batch_search(
         self,
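The reshuffle in `create_data_points` is about connection hygiene: embedding, table-class construction, and payload serialization now happen before the session opens, so a pooled connection is held only for the actual write. The general shape, sketched with SQLAlchemy's async session and a hypothetical `prepare` step:

```python
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

# Placeholder DSN; the point is the session scope, not the connection details.
engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/db")
make_session = async_sessionmaker(engine, expire_on_commit = False)

async def insert_rows(rows):
    # Do all CPU- and network-bound preparation (embedding, serialization)
    # before opening a session...
    prepared = [prepare(row) for row in rows]  # hypothetical helper

    # ...so the database connection is held only for the write itself.
    async with make_session() as session:
        session.add_all(prepared)
        await session.commit()
```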
diff --git a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py
index 87d673a03..1efcd47b3 100644
--- a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py
+++ b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py
@@ -1,7 +1,9 @@
 import logging
+from uuid import UUID
 from typing import List, Dict, Optional
 
 from qdrant_client import AsyncQdrantClient, models
 
+from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
 from cognee.infrastructure.engine import DataPoint
 from ..vector_db_interface import VectorDBInterface
 from ..embeddings.EmbeddingEngine import EmbeddingEngine
@@ -153,7 +155,7 @@ class QDrantAdapter(VectorDBInterface):
 
         client = self.get_qdrant_client()
 
-        result = await client.search(
+        results = await client.search(
             collection_name = collection_name,
             query_vector = models.NamedVector(
                 name = "text",
@@ -165,7 +167,16 @@ class QDrantAdapter(VectorDBInterface):
 
         await client.close()
 
-        return result
+        return [
+            ScoredResult(
+                id = UUID(result.id),
+                payload = {
+                    **result.payload,
+                    "id": UUID(result.id),
+                },
+                score = 1 - result.score,
+            ) for result in results
+        ]
 
     async def batch_search(self, collection_name: str, query_texts: List[str], limit: int = None, with_vectors: bool = False):
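Returning `ScoredResult` gives Qdrant the same result contract as the other adapters, and `1 - result.score` flips cosine similarity (higher is closer) into a distance-style score (lower is closer). The reshaping in isolation, on plain values:

```python
from uuid import UUID

def to_scored_dict(point_id: str, payload: dict, similarity: float) -> dict:
    # Same reshaping as above: the payload regains its id, and cosine
    # similarity becomes a distance-like score where lower means closer.
    return {
        "id": UUID(point_id),
        "payload": {**payload, "id": UUID(point_id)},
        "score": 1 - similarity,
    }

print(to_scored_dict("1b4e28ba-2fa1-11d2-883f-0016d3cca427", {"text": "NLP"}, 0.92))
# {'id': UUID('1b4e28ba-...'), 'payload': {...}, 'score': 0.08}
```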
diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py
index b5cabc56c..be356740f 100644
--- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py
+++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py
@@ -11,7 +11,6 @@ from ..embeddings.EmbeddingEngine import EmbeddingEngine
 logger = logging.getLogger("WeaviateAdapter")
 
 class IndexSchema(DataPoint):
-    uuid: str
     text: str
 
     _metadata: dict = {
@@ -58,18 +57,21 @@ class WeaviateAdapter(VectorDBInterface):
 
         future = asyncio.Future()
 
-        future.set_result(
-            self.client.collections.create(
-                name=collection_name,
-                properties=[
-                    wvcc.Property(
-                        name="text",
-                        data_type=wvcc.DataType.TEXT,
-                        skip_vectorization=True
-                    )
-                ]
+        if not self.client.collections.exists(collection_name):
+            future.set_result(
+                self.client.collections.create(
+                    name = collection_name,
+                    properties = [
+                        wvcc.Property(
+                            name = "text",
+                            data_type = wvcc.DataType.TEXT,
+                            skip_vectorization = True
+                        )
+                    ]
+                )
             )
-        )
+        else:
+            future.set_result(self.get_collection(collection_name))
 
         return await future
 
@@ -80,13 +82,16 @@ class WeaviateAdapter(VectorDBInterface):
         from weaviate.classes.data import DataObject
 
         data_vectors = await self.embed_data(
-            list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
+            [data_point.get_embeddable_data() for data_point in data_points]
+        )
 
         def convert_to_weaviate_data_points(data_point: DataPoint):
             vector = data_vectors[data_points.index(data_point)]
             properties = data_point.model_dump()
-            properties["uuid"] = properties["id"]
-            del properties["id"]
+
+            if "id" in properties:
+                properties["uuid"] = str(data_point.id)
+                del properties["id"]
 
             return DataObject(
                 uuid = data_point.id,
@@ -94,22 +99,28 @@ class WeaviateAdapter(VectorDBInterface):
                 vector = vector
             )
 
-        data_points = list(map(convert_to_weaviate_data_points, data_points))
+        data_points = [convert_to_weaviate_data_points(data_point) for data_point in data_points]
 
         collection = self.get_collection(collection_name)
 
         try:
             if len(data_points) > 1:
-                return collection.data.insert_many(data_points)
+                with collection.batch.dynamic() as batch:
+                    for data_point in data_points:
+                        batch.add_object(
+                            uuid = data_point.uuid,
+                            vector = data_point.vector,
+                            properties = data_point.properties,
+                            references = data_point.references,
+                        )
             else:
-                return collection.data.insert(data_points[0])
-            # with collection.batch.dynamic() as batch:
-            #     for point in data_points:
-            #         batch.add_object(
-            #             uuid = point.uuid,
-            #             properties = point.properties,
-            #             vector = point.vector
-            #         )
+                data_point: DataObject = data_points[0]
+                return collection.data.update(
+                    uuid = data_point.uuid,
+                    vector = data_point.vector,
+                    properties = data_point.properties,
+                    references = data_point.references,
+                )
         except Exception as error:
             logger.error("Error creating data points: %s", str(error))
             raise error
@@ -120,8 +131,8 @@ class WeaviateAdapter(VectorDBInterface):
     async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
         await self.create_data_points(f"{index_name}_{index_property_name}", [
             IndexSchema(
-                uuid = str(data_point.id),
-                text = getattr(data_point, data_point._metadata["index_fields"][0]),
+                id = data_point.id,
+                text = data_point.get_embeddable_data(),
             ) for data_point in data_points
         ])
@@ -168,9 +179,9 @@ class WeaviateAdapter(VectorDBInterface):
 
         return [
             ScoredResult(
-                id = UUID(result.uuid),
+                id = UUID(str(result.uuid)),
                 payload = result.properties,
-                score = float(result.metadata.score)
+                score = 1 - float(result.metadata.score)
             ) for result in search_result.objects
         ]
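For multi-object writes the adapter now streams through a dynamic batch instead of `insert_many`, so objects are flushed in sized chunks and a single bad object surfaces in `batch.failed_objects` rather than failing the whole call. A standalone sketch, assuming a local instance and the `weaviate-client` v4 API:

```python
from uuid import uuid4
import weaviate

client = weaviate.connect_to_local()  # placeholder local instance
collection = client.collections.get("Example")

objects = [
    {"uuid": uuid4(), "properties": {"text": "first"}, "vector": [0.1, 0.2, 0.3]},
    {"uuid": uuid4(), "properties": {"text": "second"}, "vector": [0.4, 0.5, 0.6]},
]

# batch.dynamic() sizes and flushes batches automatically; errors are
# collected per object instead of aborting the whole write.
with collection.batch.dynamic() as batch:
    for obj in objects:
        batch.add_object(
            uuid = obj["uuid"],
            properties = obj["properties"],
            vector = obj["vector"],
        )

print(collection.batch.failed_objects)  # [] when everything landed
client.close()
```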
diff --git a/cognee/modules/engine/utils/__init__.py b/cognee/modules/engine/utils/__init__.py
index 9cc2bc573..4d4ab02e7 100644
--- a/cognee/modules/engine/utils/__init__.py
+++ b/cognee/modules/engine/utils/__init__.py
@@ -1,2 +1,3 @@
 from .generate_node_id import generate_node_id
 from .generate_node_name import generate_node_name
+from .generate_edge_name import generate_edge_name
diff --git a/cognee/modules/engine/utils/generate_edge_name.py b/cognee/modules/engine/utils/generate_edge_name.py
new file mode 100644
index 000000000..49ab5e8a3
--- /dev/null
+++ b/cognee/modules/engine/utils/generate_edge_name.py
@@ -0,0 +1,2 @@
+def generate_edge_name(name: str) -> str:
+    return name.lower().replace(" ", "_").replace("'", "")
diff --git a/cognee/modules/engine/utils/generate_node_name.py b/cognee/modules/engine/utils/generate_node_name.py
index 84b266198..a2871875b 100644
--- a/cognee/modules/engine/utils/generate_node_name.py
+++ b/cognee/modules/engine/utils/generate_node_name.py
@@ -1,2 +1,2 @@
 def generate_node_name(name: str) -> str:
-    return name.lower().replace(" ", "_").replace("'", "")
+    return name.lower().replace("'", "")
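With the snake-casing removed from `generate_node_name`, the two helpers now serve different purposes: node names keep their spaces for display, while edge names are still flattened into identifier-safe strings, since they end up as relationship names in the graph:

```python
from cognee.modules.engine.utils import generate_edge_name, generate_node_name

print(generate_node_name("Natural Language Processing"))  # natural language processing
print(generate_edge_name("is subfield of"))               # is_subfield_of
```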
diff --git a/cognee/shared/utils.py b/cognee/shared/utils.py
index 14578f202..42a95b88b 100644
--- a/cognee/shared/utils.py
+++ b/cognee/shared/utils.py
@@ -9,9 +9,29 @@ import matplotlib.pyplot as plt
 import tiktoken
 import nltk
 from posthog import Posthog
+
 from cognee.base_config import get_base_config
 from cognee.infrastructure.databases.graph import get_graph_engine
 
+from uuid import uuid4
+import pathlib
+
+def get_anonymous_id():
+    """Creates or reads an anonymous user id"""
+    home_dir = str(pathlib.Path(pathlib.Path(__file__).parent.parent.parent.resolve()))
+
+    if not os.path.isdir(home_dir):
+        os.makedirs(home_dir, exist_ok=True)
+    anonymous_id_file = os.path.join(home_dir, ".anon_id")
+    if not os.path.isfile(anonymous_id_file):
+        anonymous_id = str(uuid4())
+        with open(anonymous_id_file, "w", encoding="utf-8") as f:
+            f.write(anonymous_id)
+    else:
+        with open(anonymous_id_file, "r", encoding="utf-8") as f:
+            anonymous_id = f.read()
+    return anonymous_id
+
 def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
     if os.getenv("TELEMETRY_DISABLED"):
         return
@@ -28,11 +48,15 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
     current_time = datetime.now(timezone.utc)
     properties = {
         "time": current_time.strftime("%m/%d/%Y"),
+        "user_id": user_id,
         **additional_properties,
     }
 
+    # Needed to forward properties to PostHog along with id
+    posthog.identify(get_anonymous_id(), properties)
+
     try:
-        posthog.capture(user_id, event_name, properties)
+        posthog.capture(get_anonymous_id(), event_name, properties)
     except Exception as e:
         print("ERROR sending telemetric data to Posthog. See exception: %s", e)
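Telemetry is now keyed on a stable per-installation id persisted in `.anon_id` at the package root (hence the new `.gitignore` entry), while the caller's `user_id` rides along as an event property. Usage, assuming telemetry is configured and `TELEMETRY_DISABLED` is unset:

```python
from cognee.shared.utils import get_anonymous_id, send_telemetry

# The first call writes a fresh UUID to .anon_id; every later call, in any
# process, reads the same value back.
assert get_anonymous_id() == get_anonymous_id()

# PostHog groups this event under the install id; "user-123" travels along
# only as a property.
send_telemetry("example_event", user_id = "user-123", additional_properties = {"source": "demo"})
```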
diff --git a/cognee/tasks/graph/extract_graph_from_data.py b/cognee/tasks/graph/extract_graph_from_data.py
index 36cc3e2fc..9e6edcabd 100644
--- a/cognee/tasks/graph/extract_graph_from_data.py
+++ b/cognee/tasks/graph/extract_graph_from_data.py
@@ -5,7 +5,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.modules.data.extraction.knowledge_graph import extract_content_graph
 from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
 from cognee.modules.engine.models import EntityType, Entity
-from cognee.modules.engine.utils import generate_node_id, generate_node_name
+from cognee.modules.engine.utils import generate_edge_name, generate_node_id, generate_node_name
 from cognee.tasks.storage import add_data_points
 
 async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
@@ -95,7 +95,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model:
     for edge in graph.edges:
         source_node_id = generate_node_id(edge.source_node_id)
         target_node_id = generate_node_id(edge.target_node_id)
-        relationship_name = generate_node_name(edge.relationship_name)
+        relationship_name = generate_edge_name(edge.relationship_name)
 
         edge_key = str(source_node_id) + str(target_node_id) + relationship_name
 
@@ -105,7 +105,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model:
                 target_node_id,
                 edge.relationship_name,
                 dict(
-                    relationship_name = generate_node_name(edge.relationship_name),
+                    relationship_name = generate_edge_name(edge.relationship_name),
                     source_node_id = source_node_id,
                     target_node_id = target_node_id,
                 ),
diff --git a/cognee/tasks/graph/query_graph_connections.py b/cognee/tasks/graph/query_graph_connections.py
index c64abc31b..cd4d76a5e 100644
--- a/cognee/tasks/graph/query_graph_connections.py
+++ b/cognee/tasks/graph/query_graph_connections.py
@@ -37,7 +37,7 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s
         return []
 
     node_connections_results = await asyncio.gather(
        *[graph_engine.get_connections(result.id) for result in relevant_results]
-        *[graph_engine.get_connections(str(result.id)) for result in relevant_results]
+        *[graph_engine.get_connections(result.id) for result in relevant_results]
     )
 
     node_connections = []
diff --git a/cognee/tests/test_neo4j.py b/cognee/tests/test_neo4j.py
index 0783e973a..9cf1c53dd 100644
--- a/cognee/tests/test_neo4j.py
+++ b/cognee/tests/test_neo4j.py
@@ -36,7 +36,7 @@ async def main():
     from cognee.infrastructure.databases.vector import get_vector_engine
     vector_engine = get_vector_engine()
 
-    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
+    random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
     random_node_name = random_node.payload["text"]
 
     search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
diff --git a/cognee/tests/test_pgvector.py b/cognee/tests/test_pgvector.py
index 802aa3fcb..ac4d08fbb 100644
--- a/cognee/tests/test_pgvector.py
+++ b/cognee/tests/test_pgvector.py
@@ -41,7 +41,7 @@ async def main():
     cognee.config.system_root_directory(cognee_directory_path)
 
     await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
+    await cognee.prune.prune_system(metadata = True)
 
     dataset_name = "cs_explanations"
@@ -65,7 +65,7 @@ async def main():
     from cognee.infrastructure.databases.vector import get_vector_engine
     vector_engine = get_vector_engine()
 
-    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
+    random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
     random_node_name = random_node.payload["text"]
 
     search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)
diff --git a/cognee/tests/test_qdrant.py b/cognee/tests/test_qdrant.py
index faa2cbcf4..784b3f27a 100644
--- a/cognee/tests/test_qdrant.py
+++ b/cognee/tests/test_qdrant.py
@@ -37,7 +37,7 @@ async def main():
     from cognee.infrastructure.databases.vector import get_vector_engine
     vector_engine = get_vector_engine()
 
-    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
+    random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
     random_node_name = random_node.payload["text"]
 
     search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
diff --git a/cognee/tests/test_weaviate.py b/cognee/tests/test_weaviate.py
index 121c1749e..f788f9973 100644
--- a/cognee/tests/test_weaviate.py
+++ b/cognee/tests/test_weaviate.py
@@ -35,7 +35,7 @@ async def main():
     from cognee.infrastructure.databases.vector import get_vector_engine
     vector_engine = get_vector_engine()
 
-    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
+    random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
     random_node_name = random_node.payload["text"]
 
    search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)
diff --git a/notebooks/cognee_demo.ipynb b/notebooks/cognee_demo.ipynb
index 396d7b980..06cd2a86a 100644
--- a/notebooks/cognee_demo.ipynb
+++ b/notebooks/cognee_demo.ipynb
@@ -265,7 +265,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 13,
+    "execution_count": null,
     "id": "df16431d0f48b006",
     "metadata": {
      "ExecuteTime": {
@@ -304,7 +304,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 14,
+    "execution_count": null,
     "id": "9086abf3af077ab4",
     "metadata": {
      "ExecuteTime": {
@@ -349,7 +349,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 15,
+    "execution_count": null,
     "id": "a9de0cc07f798b7f",
     "metadata": {
      "ExecuteTime": {
@@ -393,7 +393,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 16,
+    "execution_count": null,
     "id": "185ff1c102d06111",
     "metadata": {
      "ExecuteTime": {
@@ -437,7 +437,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 17,
+    "execution_count": null,
     "id": "d55ce4c58f8efb67",
     "metadata": {
      "ExecuteTime": {
@@ -479,7 +479,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 18,
+    "execution_count": null,
     "id": "ca4ecc32721ad332",
     "metadata": {
      "ExecuteTime": {
@@ -572,7 +572,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 26,
+    "execution_count": null,
     "id": "9f1a1dbd",
     "metadata": {},
    "outputs": [],
@@ -758,7 +758,7 @@
     "from cognee.infrastructure.databases.vector import get_vector_engine\n",
     "\n",
     "vector_engine = get_vector_engine()\n",
-    "results = await search(vector_engine, \"entities\", \"sarah.nguyen@example.com\")\n",
+    "results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n",
     "for result in results:\n",
     "    print(result)"
    ]
@@ -788,8 +788,8 @@
    "source": [
     "from cognee.api.v1.search import SearchType\n",
     "\n",
-    "node = (await vector_engine.search(\"entities\", \"sarah.nguyen@example.com\"))[0]\n",
-    "node_name = node.payload[\"name\"]\n",
+    "node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
+    "node_name = node.payload[\"text\"]\n",
     "\n",
     "search_results = await cognee.search(SearchType.SUMMARIES, query = node_name)\n",
     "print(\"\\n\\Extracted summaries are:\\n\")\n",
@@ -881,7 +881,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.6"
+   "version": "3.11.8"
   }
  },
 "nbformat": 4,
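The tests and the notebook were querying collections that no longer exist under those names: index collections are named `<index>_<property>`, so entity names live in `Entity_name`, and `index_data_points` stores the embeddable value under the `text` payload key (not `name`). The corrected lookup pattern:

```python
from cognee.infrastructure.databases.vector import get_vector_engine

async def find_entity_name(query: str) -> str:
    vector_engine = get_vector_engine()
    # Entity names are indexed in the "Entity_name" collection, and the
    # payload carries the indexed value under "text".
    results = await vector_engine.search("Entity_name", query)
    return results[0].payload["text"]
```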
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.6" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/poetry.lock b/poetry.lock index d30aa907b..270e66027 100644 --- a/poetry.lock +++ b/poetry.lock @@ -597,17 +597,17 @@ css = ["tinycss2 (>=1.1.0,<1.5)"] [[package]] name = "boto3" -version = "1.35.55" +version = "1.35.57" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.35.55-py3-none-any.whl", hash = "sha256:c7a0a0bc5ae3bed5d38e8bfe5a56b31621e79bdd7c1ea6e5ba4326d820cde3a5"}, - {file = "boto3-1.35.55.tar.gz", hash = "sha256:82fa8cdb00731aeffe7a5829821ae78d75c7ae959b638c15ff3b4681192ace90"}, + {file = "boto3-1.35.57-py3-none-any.whl", hash = "sha256:9edf49640c79a05b0a72f4c2d1e24dfc164344b680535a645f455ac624dc3680"}, + {file = "boto3-1.35.57.tar.gz", hash = "sha256:db58348849a5af061f0f5ec9c3b699da5221ca83354059fdccb798e3ddb6b62a"}, ] [package.dependencies] -botocore = ">=1.35.55,<1.36.0" +botocore = ">=1.35.57,<1.36.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -616,13 +616,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.35.55" +version = "1.35.57" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.35.55-py3-none-any.whl", hash = "sha256:3d54739e498534c9d7a6e9732ae2d17ed29c7d5e29fe36c956d8488b859538b0"}, - {file = "botocore-1.35.55.tar.gz", hash = "sha256:61ae18f688250372d7b6046e35c86f8fd09a7c0f0064b52688f3490b4d6c9d6b"}, + {file = "botocore-1.35.57-py3-none-any.whl", hash = "sha256:92ddd02469213766872cb2399269dd20948f90348b42bf08379881d5e946cc34"}, + {file = "botocore-1.35.57.tar.gz", hash = "sha256:d96306558085baf0bcb3b022d7a8c39c93494f031edb376694d2b2dcd0e81327"}, ] [package.dependencies] @@ -2606,22 +2606,22 @@ colors = ["colorama (>=0.4.6)"] [[package]] name = "jedi" -version = "0.19.1" +version = "0.19.2" description = "An autocompletion tool for Python that can be used for text editors." 
optional = false python-versions = ">=3.6" files = [ - {file = "jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0"}, - {file = "jedi-0.19.1.tar.gz", hash = "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd"}, + {file = "jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9"}, + {file = "jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0"}, ] [package.dependencies] -parso = ">=0.8.3,<0.9.0" +parso = ">=0.8.4,<0.9.0" [package.extras] docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"] qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] -testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] +testing = ["Django", "attrs", "colorama", "docopt", "pytest (<9.0.0)"] [[package]] name = "jinja2" @@ -2755,15 +2755,18 @@ files = [ [[package]] name = "json5" -version = "0.9.25" +version = "0.9.27" description = "A Python implementation of the JSON5 data format." optional = false -python-versions = ">=3.8" +python-versions = ">=3.8.0" files = [ - {file = "json5-0.9.25-py3-none-any.whl", hash = "sha256:34ed7d834b1341a86987ed52f3f76cd8ee184394906b6e22a1e0deb9ab294e8f"}, - {file = "json5-0.9.25.tar.gz", hash = "sha256:548e41b9be043f9426776f05df8635a00fe06104ea51ed24b67f908856e151ae"}, + {file = "json5-0.9.27-py3-none-any.whl", hash = "sha256:17b43d78d3a6daeca4d7030e9bf22092dba29b1282cc2d0cfa56f6febee8dc93"}, + {file = "json5-0.9.27.tar.gz", hash = "sha256:5a19de4a6ca24ba664dc7d50307eb73ba9a16dea5d6bde85677ae85d3ed2d8e0"}, ] +[package.extras] +dev = ["build (==1.2.1)", "coverage (==7.5.3)", "mypy (==1.10.0)", "pip (==24.1)", "pylint (==3.2.3)", "ruff (==0.5.1)", "twine (==5.1.1)", "uv (==0.2.13)"] + [[package]] name = "jsonpatch" version = "1.33" @@ -4360,13 +4363,13 @@ files = [ [[package]] name = "packaging" -version = "24.1" +version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, - {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, + {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, + {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, ] [[package]] @@ -4432,8 +4435,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.2" 
pytz = ">=2020.1" @@ -5007,7 +5010,7 @@ test = ["pytest", "pytest-xdist", "setuptools"] name = "psycopg2" version = "2.9.10" description = "psycopg2 - Python-PostgreSQL Database Adapter" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "psycopg2-2.9.10-cp310-cp310-win32.whl", hash = "sha256:5df2b672140f95adb453af93a7d669d7a7bf0a56bcd26f1502329166f4a61716"}, @@ -7032,13 +7035,13 @@ test = ["vcrpy (>=1.10.3)"] [[package]] name = "typer" -version = "0.12.5" +version = "0.13.0" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." optional = false python-versions = ">=3.7" files = [ - {file = "typer-0.12.5-py3-none-any.whl", hash = "sha256:62fe4e471711b147e3365034133904df3e235698399bc4de2b36c8579298d52b"}, - {file = "typer-0.12.5.tar.gz", hash = "sha256:f592f089bedcc8ec1b974125d64851029c3b1af145f04aca64d69410f0c9b722"}, + {file = "typer-0.13.0-py3-none-any.whl", hash = "sha256:d85fe0b777b2517cc99c8055ed735452f2659cd45e451507c76f48ce5c1d00e2"}, + {file = "typer-0.13.0.tar.gz", hash = "sha256:f1c7198347939361eec90139ffa0fd8b3df3a2259d5852a0f7400e476d95985c"}, ] [package.dependencies] @@ -7333,19 +7336,15 @@ validators = "0.33.0" [[package]] name = "webcolors" -version = "24.8.0" +version = "24.11.1" description = "A library for working with the color formats defined by HTML and CSS." optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "webcolors-24.8.0-py3-none-any.whl", hash = "sha256:fc4c3b59358ada164552084a8ebee637c221e4059267d0f8325b3b560f6c7f0a"}, - {file = "webcolors-24.8.0.tar.gz", hash = "sha256:08b07af286a01bcd30d583a7acadf629583d1f79bfef27dd2c2c5c263817277d"}, + {file = "webcolors-24.11.1-py3-none-any.whl", hash = "sha256:515291393b4cdf0eb19c155749a096f779f7d909f7cceea072791cb9095b92e9"}, + {file = "webcolors-24.11.1.tar.gz", hash = "sha256:ecb3d768f32202af770477b8b65f318fa4f566c22948673a977b00d589dd80f6"}, ] -[package.extras] -docs = ["furo", "sphinx", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-notfound-page", "sphinxext-opengraph"] -tests = ["coverage[toml]"] - [[package]] name = "webencodings" version = "0.5.1" @@ -7375,13 +7374,13 @@ test = ["websockets"] [[package]] name = "wheel" -version = "0.44.0" +version = "0.45.0" description = "A built-package format for Python" optional = false python-versions = ">=3.8" files = [ - {file = "wheel-0.44.0-py3-none-any.whl", hash = "sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f"}, - {file = "wheel-0.44.0.tar.gz", hash = "sha256:a29c3f2817e95ab89aa4660681ad547c0e9547f20e75b0562fe7723c9a2a9d49"}, + {file = "wheel-0.45.0-py3-none-any.whl", hash = "sha256:52f0baa5e6522155090a09c6bd95718cc46956d1b51d537ea5454249edb671c7"}, + {file = "wheel-0.45.0.tar.gz", hash = "sha256:a57353941a3183b3d5365346b567a260a0602a0f8a635926a7dede41b94c674a"}, ] [package.extras] @@ -7718,13 +7717,13 @@ propcache = ">=0.2.0" [[package]] name = "zipp" -version = "3.20.2" +version = "3.21.0" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "zipp-3.20.2-py3-none-any.whl", hash = "sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350"}, - {file = "zipp-3.20.2.tar.gz", hash = "sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29"}, + {file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"}, + {file = "zipp-3.21.0.tar.gz", hash = 
"sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"}, ] [package.extras] @@ -7740,11 +7739,11 @@ cli = [] filesystem = ["botocore"] neo4j = ["neo4j"] notebook = [] -postgres = ["psycopg2"] +postgres = ["asyncpg", "pgvector", "psycopg2"] qdrant = ["qdrant-client"] weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.12" -content-hash = "426fa990f2bdd15fa5be55392beb4cf77dba320f2e95cc503d1c0549d9758d64" +content-hash = "fb09733ff7a70fb91c5f72ff0c8a8137b857557930a7aa025aad3154de4d8ceb" diff --git a/pyproject.toml b/pyproject.toml index 93ec8e0ea..0bc3849b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ aiosqlite = "^0.20.0" pandas = "2.0.3" filetype = "^1.2.0" nltk = "^3.8.1" -dlt = {extras = ["sqlalchemy"], version = "^1.2.0"} +dlt = {extras = ["sqlalchemy"], version = "^1.3.0"} aiofiles = "^23.2.1" qdrant-client = "^1.9.0" graphistry = "^0.33.5" @@ -66,10 +66,10 @@ pydantic-settings = "^2.2.1" anthropic = "^0.26.1" sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"} fastapi-users = {version = "*", extras = ["sqlalchemy"]} -asyncpg = "^0.29.0" alembic = "^1.13.3" +asyncpg = "^0.29.0" pgvector = "^0.3.5" -psycopg2 = {version = "^2.9.10", optional = true} +psycopg2 = "^2.9.10" [tool.poetry.extras] filesystem = ["s3fs", "botocore"] @@ -77,9 +77,10 @@ cli = ["pipdeptree", "cron-descriptor"] weaviate = ["weaviate-client"] qdrant = ["qdrant-client"] neo4j = ["neo4j"] -postgres = ["psycopg2"] +postgres = ["psycopg2", "pgvector", "asyncpg"] notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"] + [tool.poetry.group.dev.dependencies] pytest = "^7.4.0" pytest-asyncio = "^0.21.1"