Merge remote-tracking branch 'origin/main' into feat/COG-553-graph-memory-projection

hajdul88 committed 2024-11-11 18:56:17 +01:00
commit 3e7df33c15
29 changed files with 276 additions and 211 deletions

View file

@@ -5,7 +5,7 @@ on:
   pull_request:
     branches:
       - main
-    types: [labeled]
+    types: [labeled, synchronize]
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -22,7 +22,7 @@ jobs:
   run_neo4j_integration_test:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
+    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     defaults:

View file

@@ -5,7 +5,7 @@ on:
   pull_request:
     branches:
       - main
-    types: [labeled]
+    types: [labeled, synchronize]
 concurrency:
@@ -23,7 +23,7 @@ jobs:
   run_notebook_test:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
+    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     defaults:
       run:

View file

@@ -5,7 +5,7 @@ on:
   pull_request:
     branches:
       - main
-    types: [labeled]
+    types: [labeled, synchronize]
 concurrency:
@@ -23,7 +23,7 @@ jobs:
   run_pgvector_integration_test:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
+    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     defaults:
       run:

View file

@@ -1,10 +1,11 @@
 name: test | python 3.10
 on:
+  workflow_dispatch:
   pull_request:
     branches:
       - main
-  workflow_dispatch:
+    types: [labeled, synchronize]
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -21,7 +22,7 @@ jobs:
   run_common:
     name: test
    needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
+    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false

View file

@@ -1,10 +1,11 @@
 name: test | python 3.11
 on:
+  workflow_dispatch:
   pull_request:
     branches:
       - main
-  workflow_dispatch:
+    types: [labeled, synchronize]
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -21,7 +22,7 @@ jobs:
   run_common:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
+    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false

View file

@@ -1,10 +1,11 @@
 name: test | python 3.9
 on:
+  workflow_dispatch:
   pull_request:
     branches:
       - main
-  workflow_dispatch:
+    types: [labeled, synchronize]
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -21,7 +22,7 @@ jobs:
   run_common:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
+    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false

View file

@@ -5,7 +5,7 @@ on:
   pull_request:
     branches:
       - main
-    types: [labeled]
+    types: [labeled, synchronize]
 concurrency:
@@ -23,7 +23,7 @@ jobs:
   run_qdrant_integration_test:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
+    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     defaults:

View file

@@ -5,7 +5,7 @@ on:
   pull_request:
     branches:
       - main
-    types: [labeled]
+    types: [labeled, synchronize]
 concurrency:
@@ -23,7 +23,7 @@ jobs:
   run_weaviate_integration_test:
     name: test
     needs: get_docs_changes
-    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
+    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
     runs-on: ubuntu-latest
     defaults:

.gitignore (vendored): 1 change
View file

@@ -177,5 +177,6 @@ cognee/cache/
 # Default cognee system directory, used in development
 .cognee_system/
 .data_storage/
+.anon_id
 node_modules/

View file

@@ -109,24 +109,34 @@ import asyncio
 from cognee.api.v1.search import SearchType

 async def main():
-    await cognee.prune.prune_data() # Reset cognee data
-    await cognee.prune.prune_system(metadata=True) # Reset cognee system state
+    # Reset cognee data
+    await cognee.prune.prune_data()
+    # Reset cognee system state
+    await cognee.prune.prune_system(metadata=True)

     text = """
     Natural language processing (NLP) is an interdisciplinary
     subfield of computer science and information retrieval.
     """

-    await cognee.add(text) # Add text to cognee
-    await cognee.cognify() # Use LLMs and cognee to create knowledge graph
-    search_results = await cognee.search( # Search cognee for insights
+    # Add text to cognee
+    await cognee.add(text)
+    # Use LLMs and cognee to create knowledge graph
+    await cognee.cognify()
+    # Search cognee for insights
+    search_results = await cognee.search(
         SearchType.INSIGHTS,
-        {'query': 'Tell me about NLP'}
+        "Tell me about NLP",
     )
-    for result_text in search_results: # Display results
+    # Display results
+    for result_text in search_results:
         print(result_text)
+        # natural_language_processing is_a field
+        # natural_language_processing is_subfield_of computer_science
+        # natural_language_processing is_subfield_of information_retrieval

 asyncio.run(main())
 ```

View file

@@ -1,6 +1,7 @@
 """ Neo4j Adapter for Graph Database"""
 import logging
 import asyncio
+from textwrap import dedent
 from typing import Optional, Any, List, Dict
 from contextlib import asynccontextmanager
 from uuid import UUID
@@ -18,7 +19,7 @@ class Neo4jAdapter(GraphDBInterface):
         graph_database_url: str,
         graph_database_username: str,
         graph_database_password: str,
-        driver: Optional[Any] = None
+        driver: Optional[Any] = None,
     ):
         self.driver = driver or AsyncGraphDatabase.driver(
             graph_database_url,
@@ -26,6 +27,9 @@ class Neo4jAdapter(GraphDBInterface):
             max_connection_lifetime = 120
         )

+    async def close(self) -> None:
+        await self.driver.close()
+
     @asynccontextmanager
     async def get_session(self) -> AsyncSession:
         async with self.driver.session() as session:
@@ -59,11 +63,10 @@ class Neo4jAdapter(GraphDBInterface):
     async def add_node(self, node: DataPoint):
         serialized_properties = self.serialize_properties(node.model_dump())

-        query = """MERGE (node {id: $node_id})
-                ON CREATE SET node += $properties
-                ON MATCH SET node += $properties
-                ON MATCH SET node.updated_at = timestamp()
-                RETURN ID(node) AS internal_id, node.id AS nodeId"""
+        query = dedent("""MERGE (node {id: $node_id})
+                ON CREATE SET node += $properties, node.updated_at = timestamp()
+                ON MATCH SET node += $properties, node.updated_at = timestamp()
+                RETURN ID(node) AS internal_id, node.id AS nodeId""")

         params = {
             "node_id": str(node.id),
@@ -76,9 +79,8 @@ class Neo4jAdapter(GraphDBInterface):
         query = """
             UNWIND $nodes AS node
             MERGE (n {id: node.node_id})
-            ON CREATE SET n += node.properties
-            ON MATCH SET n += node.properties
-            ON MATCH SET n.updated_at = timestamp()
+            ON CREATE SET n += node.properties, n.updated_at = timestamp()
+            ON MATCH SET n += node.properties, n.updated_at = timestamp()
             WITH n, node.node_id AS label
             CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
             RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
@@ -133,12 +135,19 @@ class Neo4jAdapter(GraphDBInterface):
         return await self.query(query, params)

     async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
-        query = f"""
-            MATCH (from_node:`{str(from_node)}`)-[relationship:`{edge_label}`]->(to_node:`{str(to_node)}`)
-            RETURN COUNT(relationship) > 0 AS edge_exists
-        """
-
-        edge_exists = await self.query(query)
+        query = """
+            MATCH (from_node)-[relationship]->(to_node)
+            WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
+            RETURN COUNT(relationship) > 0 AS edge_exists
+        """
+
+        params = {
+            "from_node_id": str(from_node),
+            "to_node_id": str(to_node),
+            "edge_label": edge_label,
+        }
+
+        edge_exists = await self.query(query, params)
         return edge_exists

     async def has_edges(self, edges):
@@ -165,22 +174,21 @@ class Neo4jAdapter(GraphDBInterface):
             raise error

-    async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
+    async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
         serialized_properties = self.serialize_properties(edge_properties)
-        from_node = from_node.replace(":", "_")
-        to_node = to_node.replace(":", "_")

-        query = f"""MATCH (from_node:`{str(from_node)}`
-            {{id: $from_node}}),
-            (to_node:`{str(to_node)}` {{id: $to_node}})
-            MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
-            ON CREATE SET r += $properties, r.updated_at = timestamp()
-            ON MATCH SET r += $properties, r.updated_at = timestamp()
-            RETURN r"""
+        query = dedent("""MATCH (from_node {id: $from_node}),
+            (to_node {id: $to_node})
+            MERGE (from_node)-[r]->(to_node)
+            ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
+            ON MATCH SET r += $properties, r.updated_at = timestamp()
+            RETURN r
+            """)

         params = {
             "from_node": str(from_node),
             "to_node": str(to_node),
+            "relationship_name": relationship_name,
             "properties": serialized_properties
         }
@@ -347,8 +355,8 @@ class Neo4jAdapter(GraphDBInterface):
         """

         predecessors, successors = await asyncio.gather(
-            self.query(predecessors_query, dict(node_id = node_id)),
-            self.query(successors_query, dict(node_id = node_id)),
+            self.query(predecessors_query, dict(node_id = str(node_id))),
+            self.query(successors_query, dict(node_id = str(node_id))),
         )

         connections = []
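
The rewritten `has_edge` and `add_edge` now pass node ids and the relationship name as Cypher parameters instead of interpolating them into the query string, and the adapter gains an explicit `close()`. A minimal usage sketch of the new call shape, assuming a reachable Neo4j instance, pre-existing nodes with matching `id` properties, and an import path that may differ from the real module location:

```
import asyncio
from uuid import uuid4

# Import path is assumed for illustration; adjust to the real module location.
from cognee.infrastructure.databases.graph.neo4j_driver.adapter import Neo4jAdapter

async def demo():
    adapter = Neo4jAdapter(
        graph_database_url = "bolt://localhost:7687",   # placeholder connection values
        graph_database_username = "neo4j",
        graph_database_password = "password",
    )
    node_a, node_b = uuid4(), uuid4()  # ids of nodes assumed to already exist in the graph

    # Ids and the relationship name travel as query parameters,
    # so they no longer have to be sanitized into label strings.
    await adapter.add_edge(node_a, node_b, "is_part_of", {"weight": 1})
    print(await adapter.has_edge(node_a, node_b, "is_part_of"))

    await adapter.close()  # close() is the new method added in this commit

asyncio.run(demo())
```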

View file

@@ -270,7 +270,7 @@ class NetworkXAdapter(GraphDBInterface):
             except:
                 pass

-            if "updated_at" in node:
+            if "updated_at" in edge:
                 edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")

         self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)

View file

@@ -112,18 +112,10 @@ class LanceDBAdapter(VectorDBInterface):
             for (data_point_index, data_point) in enumerate(data_points)
         ]

-        # TODO: This enables us to work with pydantic version but shouldn't
-        # stay like this, existing rows should be updated
-        await collection.delete("id IS NOT NULL")
-        original_size = await collection.count_rows()
-        await collection.add(lance_data_points)
-        new_size = await collection.count_rows()
-
-        if new_size <= original_size:
-            raise ValueError(
-                "LanceDB create_datapoints error: data points did not get added.")
+        await collection.merge_insert("id") \
+            .when_matched_update_all() \
+            .when_not_matched_insert_all() \
+            .execute(lance_data_points)

     async def retrieve(self, collection_name: str, data_point_ids: list[str]):
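
The delete-and-re-add workaround is replaced by LanceDB's `merge_insert` upsert, which updates rows whose `id` already exists and inserts the rest in one call. A standalone sketch of the same pattern against the synchronous `lancedb` client (database path, table name, and rows are made up for the example):

```
import lancedb

# Illustrative only: path, table name and rows are invented for the example.
db = lancedb.connect("/tmp/lancedb_demo")
table = db.create_table(
    "points",
    data = [{"id": "1", "text": "hello", "vector": [0.1, 0.2]}],
    exist_ok = True,
)

rows = [
    {"id": "1", "text": "hello again", "vector": [0.1, 0.3]},  # matches existing id, gets updated
    {"id": "2", "text": "world", "vector": [0.9, 0.1]},        # new id, gets inserted
]

# Upsert keyed on "id": no more deleting the whole table before re-adding.
table.merge_insert("id") \
    .when_matched_update_all() \
    .when_not_matched_insert_all() \
    .execute(rows)

print(table.count_rows())  # 2
```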

View file

@@ -54,7 +54,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         vector_size = self.embedding_engine.get_vector_size()

         if not await self.has_collection(collection_name):
-
             class PGVectorDataPoint(Base):
                 __tablename__ = collection_name
                 __table_args__ = {"extend_existing": True}
@@ -80,44 +79,44 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
     async def create_data_points(
         self, collection_name: str, data_points: List[DataPoint]
     ):
-        async with self.get_async_session() as session:
-            if not await self.has_collection(collection_name):
-                await self.create_collection(
-                    collection_name=collection_name,
-                    payload_schema=type(data_points[0]),
-                )
-
-            data_vectors = await self.embed_data(
-                [data_point.get_embeddable_data() for data_point in data_points]
-            )
-
-            vector_size = self.embedding_engine.get_vector_size()
-
-            class PGVectorDataPoint(Base):
-                __tablename__ = collection_name
-                __table_args__ = {"extend_existing": True}
-                # PGVector requires one column to be the primary key
-                primary_key: Mapped[int] = mapped_column(
-                    primary_key=True, autoincrement=True
-                )
-                id: Mapped[type(data_points[0].id)]
-                payload = Column(JSON)
-                vector = Column(Vector(vector_size))
-
-                def __init__(self, id, payload, vector):
-                    self.id = id
-                    self.payload = payload
-                    self.vector = vector
-
-            pgvector_data_points = [
-                PGVectorDataPoint(
-                    id=data_point.id,
-                    vector=data_vectors[data_index],
-                    payload=serialize_data(data_point.model_dump()),
-                )
-                for (data_index, data_point) in enumerate(data_points)
-            ]
-
+        if not await self.has_collection(collection_name):
+            await self.create_collection(
+                collection_name = collection_name,
+                payload_schema = type(data_points[0]),
+            )
+
+        data_vectors = await self.embed_data(
+            [data_point.get_embeddable_data() for data_point in data_points]
+        )
+
+        vector_size = self.embedding_engine.get_vector_size()
+
+        class PGVectorDataPoint(Base):
+            __tablename__ = collection_name
+            __table_args__ = {"extend_existing": True}
+            # PGVector requires one column to be the primary key
+            primary_key: Mapped[int] = mapped_column(
+                primary_key=True, autoincrement=True
+            )
+            id: Mapped[type(data_points[0].id)]
+            payload = Column(JSON)
+            vector = Column(Vector(vector_size))
+
+            def __init__(self, id, payload, vector):
+                self.id = id
+                self.payload = payload
+                self.vector = vector
+
+        pgvector_data_points = [
+            PGVectorDataPoint(
+                id = data_point.id,
+                vector = data_vectors[data_index],
+                payload = serialize_data(data_point.model_dump()),
+            )
+            for (data_index, data_point) in enumerate(data_points)
+        ]
+
+        async with self.get_async_session() as session:
             session.add_all(pgvector_data_points)
             await session.commit()
@@ -128,7 +127,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         await self.create_data_points(f"{index_name}_{index_property_name}", [
             IndexSchema(
                 id = data_point.id,
-                text = getattr(data_point, data_point._metadata["index_fields"][0]),
+                text = data_point.get_embeddable_data(),
             ) for data_point in data_points
         ])
@@ -146,10 +145,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             raise ValueError(f"Table '{collection_name}' not found.")

     async def retrieve(self, collection_name: str, data_point_ids: List[str]):
-        async with self.get_async_session() as session:
-            # Get PGVectorDataPoint Table from database
-            PGVectorDataPoint = await self.get_table(collection_name)
+        # Get PGVectorDataPoint Table from database
+        PGVectorDataPoint = await self.get_table(collection_name)

+        async with self.get_async_session() as session:
             results = await session.execute(
                 select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
             )
@@ -177,11 +176,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         if query_text and not query_vector:
             query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

+        # Get PGVectorDataPoint Table from database
+        PGVectorDataPoint = await self.get_table(collection_name)
+
+        closest_items = []
+
         # Use async session to connect to the database
         async with self.get_async_session() as session:
-            # Get PGVectorDataPoint Table from database
-            PGVectorDataPoint = await self.get_table(collection_name)
-
             # Find closest vectors to query_vector
             closest_items = await session.execute(
                 select(
@@ -194,20 +195,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                 .limit(limit)
             )

-            vector_list = []
-            # Extract distances and find min/max for normalization
-            for vector in closest_items:
-                # TODO: Add normalization of similarity score
-                vector_list.append(vector)
-
-            # Create and return ScoredResult objects
-            return [
-                ScoredResult(
-                    id = UUID(str(row.id)),
-                    payload = row.payload,
-                    score = row.similarity
-                ) for row in vector_list
-            ]
+        vector_list = []
+
+        # Extract distances and find min/max for normalization
+        for vector in closest_items:
+            # TODO: Add normalization of similarity score
+            vector_list.append(vector)
+
+        # Create and return ScoredResult objects
+        return [
+            ScoredResult(
+                id = UUID(str(row.id)),
+                payload = row.payload,
+                score = row.similarity
+            ) for row in vector_list
+        ]

     async def batch_search(
         self,

View file

@@ -1,7 +1,9 @@
 import logging
+from uuid import UUID
 from typing import List, Dict, Optional
 from qdrant_client import AsyncQdrantClient, models
+from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
 from cognee.infrastructure.engine import DataPoint
 from ..vector_db_interface import VectorDBInterface
 from ..embeddings.EmbeddingEngine import EmbeddingEngine
@@ -153,7 +155,7 @@ class QDrantAdapter(VectorDBInterface):
         client = self.get_qdrant_client()

-        result = await client.search(
+        results = await client.search(
             collection_name = collection_name,
             query_vector = models.NamedVector(
                 name = "text",
@@ -165,7 +167,16 @@ class QDrantAdapter(VectorDBInterface):

         await client.close()

-        return result
+        return [
+            ScoredResult(
+                id = UUID(result.id),
+                payload = {
+                    **result.payload,
+                    "id": UUID(result.id),
+                },
+                score = 1 - result.score,
+            ) for result in results
+        ]

     async def batch_search(self, collection_name: str, query_texts: List[str], limit: int = None, with_vectors: bool = False):

View file

@@ -11,7 +11,6 @@ from ..embeddings.EmbeddingEngine import EmbeddingEngine
 logger = logging.getLogger("WeaviateAdapter")

 class IndexSchema(DataPoint):
-    uuid: str
     text: str

     _metadata: dict = {
@@ -58,18 +57,21 @@ class WeaviateAdapter(VectorDBInterface):
         future = asyncio.Future()

-        future.set_result(
-            self.client.collections.create(
-                name=collection_name,
-                properties=[
-                    wvcc.Property(
-                        name="text",
-                        data_type=wvcc.DataType.TEXT,
-                        skip_vectorization=True
-                    )
-                ]
-            )
-        )
+        if not self.client.collections.exists(collection_name):
+            future.set_result(
+                self.client.collections.create(
+                    name = collection_name,
+                    properties = [
+                        wvcc.Property(
+                            name = "text",
+                            data_type = wvcc.DataType.TEXT,
+                            skip_vectorization = True
+                        )
+                    ]
+                )
+            )
+        else:
+            future.set_result(self.get_collection(collection_name))

         return await future
@@ -80,13 +82,16 @@ class WeaviateAdapter(VectorDBInterface):
         from weaviate.classes.data import DataObject

         data_vectors = await self.embed_data(
-            list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
+            [data_point.get_embeddable_data() for data_point in data_points]
+        )

         def convert_to_weaviate_data_points(data_point: DataPoint):
             vector = data_vectors[data_points.index(data_point)]
             properties = data_point.model_dump()
-            properties["uuid"] = properties["id"]
-            del properties["id"]
+
+            if "id" in properties:
+                properties["uuid"] = str(data_point.id)
+                del properties["id"]

             return DataObject(
                 uuid = data_point.id,
@@ -94,22 +99,28 @@ class WeaviateAdapter(VectorDBInterface):
                 vector = vector
             )

-        data_points = list(map(convert_to_weaviate_data_points, data_points))
+        data_points = [convert_to_weaviate_data_points(data_point) for data_point in data_points]

         collection = self.get_collection(collection_name)

         try:
             if len(data_points) > 1:
-                return collection.data.insert_many(data_points)
+                with collection.batch.dynamic() as batch:
+                    for data_point in data_points:
+                        batch.add_object(
+                            uuid = data_point.uuid,
+                            vector = data_point.vector,
+                            properties = data_point.properties,
+                            references = data_point.references,
+                        )
             else:
-                return collection.data.insert(data_points[0])
-
-            # with collection.batch.dynamic() as batch:
-            #     for point in data_points:
-            #         batch.add_object(
-            #             uuid = point.uuid,
-            #             properties = point.properties,
-            #             vector = point.vector
-            #         )
+                data_point: DataObject = data_points[0]
+                return collection.data.update(
+                    uuid = data_point.uuid,
+                    vector = data_point.vector,
+                    properties = data_point.properties,
+                    references = data_point.references,
+                )
         except Exception as error:
             logger.error("Error creating data points: %s", str(error))
             raise error
@@ -120,8 +131,8 @@ class WeaviateAdapter(VectorDBInterface):
     async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
         await self.create_data_points(f"{index_name}_{index_property_name}", [
             IndexSchema(
-                uuid = str(data_point.id),
-                text = getattr(data_point, data_point._metadata["index_fields"][0]),
+                id = data_point.id,
+                text = data_point.get_embeddable_data(),
             ) for data_point in data_points
         ])
@@ -168,9 +179,9 @@ class WeaviateAdapter(VectorDBInterface):
         return [
             ScoredResult(
-                id = UUID(result.uuid),
+                id = UUID(str(result.uuid)),
                 payload = result.properties,
-                score = float(result.metadata.score)
+                score = 1 - float(result.metadata.score)
             ) for result in search_result.objects
         ]

View file

@@ -1,2 +1,3 @@
 from .generate_node_id import generate_node_id
 from .generate_node_name import generate_node_name
+from .generate_edge_name import generate_edge_name

View file

@@ -0,0 +1,2 @@
+def generate_edge_name(name: str) -> str:
+    return name.lower().replace(" ", "_").replace("'", "")
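
The new helper keeps the space-to-underscore slugging that `generate_node_name` drops in the next file, so relationship names stay safe to embed in edge keys. A quick illustration, using the import that was added to `cognee/modules/engine/utils/__init__.py` above:

```
from cognee.modules.engine.utils import generate_edge_name, generate_node_name

# Edge names are lowercased, space-slugged and stripped of apostrophes...
assert generate_edge_name("is part of Alice's team") == "is_part_of_alices_team"

# ...while node names now keep their spaces and only lose apostrophes.
assert generate_node_name("Alice's team") == "alices team"
```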

View file

@@ -1,2 +1,2 @@
 def generate_node_name(name: str) -> str:
-    return name.lower().replace(" ", "_").replace("'", "")
+    return name.lower().replace("'", "")

View file

@@ -9,9 +9,29 @@ import matplotlib.pyplot as plt
 import tiktoken
 import nltk
 from posthog import Posthog
 from cognee.base_config import get_base_config
 from cognee.infrastructure.databases.graph import get_graph_engine
+from uuid import uuid4
+import pathlib
+
+
+def get_anonymous_id():
+    """Creates or reads a anonymous user id"""
+    home_dir = str(pathlib.Path(pathlib.Path(__file__).parent.parent.parent.resolve()))
+
+    if not os.path.isdir(home_dir):
+        os.makedirs(home_dir, exist_ok=True)
+    anonymous_id_file = os.path.join(home_dir, ".anon_id")
+    if not os.path.isfile(anonymous_id_file):
+        anonymous_id = str(uuid4())
+        with open(anonymous_id_file, "w", encoding="utf-8") as f:
+            f.write(anonymous_id)
+    else:
+        with open(anonymous_id_file, "r", encoding="utf-8") as f:
+            anonymous_id = f.read()
+    return anonymous_id
+
 def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
     if os.getenv("TELEMETRY_DISABLED"):
         return
@@ -28,11 +48,15 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
     current_time = datetime.now(timezone.utc)
     properties = {
         "time": current_time.strftime("%m/%d/%Y"),
+        "user_id": user_id,
         **additional_properties,
     }

+    # Needed to forward properties to PostHog along with id
+    posthog.identify(get_anonymous_id(), properties)
+
     try:
-        posthog.capture(user_id, event_name, properties)
+        posthog.capture(get_anonymous_id(), event_name, properties)
     except Exception as e:
         print("ERROR sending telemetric data to Posthog. See exception: %s", e)

View file

@@ -5,7 +5,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
 from cognee.modules.data.extraction.knowledge_graph import extract_content_graph
 from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
 from cognee.modules.engine.models import EntityType, Entity
-from cognee.modules.engine.utils import generate_node_id, generate_node_name
+from cognee.modules.engine.utils import generate_edge_name, generate_node_id, generate_node_name
 from cognee.tasks.storage import add_data_points

 async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
@@ -95,7 +95,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model:
         for edge in graph.edges:
             source_node_id = generate_node_id(edge.source_node_id)
             target_node_id = generate_node_id(edge.target_node_id)
-            relationship_name = generate_node_name(edge.relationship_name)
+            relationship_name = generate_edge_name(edge.relationship_name)

             edge_key = str(source_node_id) + str(target_node_id) + relationship_name
@@ -105,7 +105,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model:
                 target_node_id,
                 edge.relationship_name,
                 dict(
-                    relationship_name = generate_node_name(edge.relationship_name),
+                    relationship_name = generate_edge_name(edge.relationship_name),
                     source_node_id = source_node_id,
                     target_node_id = target_node_id,
                 ),
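
Edge keys are now built from `generate_edge_name` output rather than node-name slugs. A minimal sketch of the deduplication key construction, with an illustrative stand-in for the graph model's edge type:

```
from dataclasses import dataclass
from cognee.modules.engine.utils import generate_edge_name, generate_node_id

@dataclass
class Edge:  # stand-in for the graph model's edge objects, for illustration only
    source_node_id: str
    target_node_id: str
    relationship_name: str

edge = Edge("Natural Language Processing", "Computer Science", "is subfield of")

source_node_id = generate_node_id(edge.source_node_id)
target_node_id = generate_node_id(edge.target_node_id)
relationship_name = generate_edge_name(edge.relationship_name)

# Duplicate edges collapse onto the same key before they are added to the graph.
edge_key = str(source_node_id) + str(target_node_id) + relationship_name
print(edge_key)
```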

View file

@@ -37,7 +37,7 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s
         return []

     node_connections_results = await asyncio.gather(
-        *[graph_engine.get_connections(str(result.id)) for result in relevant_results]
+        *[graph_engine.get_connections(result.id) for result in relevant_results]
     )

     node_connections = []

View file

@@ -36,7 +36,7 @@ async def main():
     from cognee.infrastructure.databases.vector import get_vector_engine
     vector_engine = get_vector_engine()

-    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
+    random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
     random_node_name = random_node.payload["text"]

     search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)

View file

@@ -41,7 +41,7 @@ async def main():
     cognee.config.system_root_directory(cognee_directory_path)

     await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
+    await cognee.prune.prune_system(metadata = True)

     dataset_name = "cs_explanations"
@@ -65,7 +65,7 @@ async def main():
     from cognee.infrastructure.databases.vector import get_vector_engine
     vector_engine = get_vector_engine()

-    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
+    random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
     random_node_name = random_node.payload["text"]

     search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)

View file

@@ -37,7 +37,7 @@ async def main():
     from cognee.infrastructure.databases.vector import get_vector_engine
     vector_engine = get_vector_engine()

-    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
+    random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
     random_node_name = random_node.payload["text"]

     search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)

View file

@@ -35,7 +35,7 @@ async def main():
     from cognee.infrastructure.databases.vector import get_vector_engine
     vector_engine = get_vector_engine()

-    random_node = (await vector_engine.search("Entity_name", "AI"))[0]
+    random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
     random_node_name = random_node.payload["text"]

     search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)

View file

@@ -265,7 +265,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 13,
+    "execution_count": null,
     "id": "df16431d0f48b006",
     "metadata": {
      "ExecuteTime": {
@@ -304,7 +304,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 14,
+    "execution_count": null,
     "id": "9086abf3af077ab4",
     "metadata": {
      "ExecuteTime": {
@@ -349,7 +349,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 15,
+    "execution_count": null,
     "id": "a9de0cc07f798b7f",
     "metadata": {
      "ExecuteTime": {
@@ -393,7 +393,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 16,
+    "execution_count": null,
     "id": "185ff1c102d06111",
     "metadata": {
      "ExecuteTime": {
@@ -437,7 +437,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 17,
+    "execution_count": null,
     "id": "d55ce4c58f8efb67",
     "metadata": {
      "ExecuteTime": {
@@ -479,7 +479,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 18,
+    "execution_count": null,
     "id": "ca4ecc32721ad332",
     "metadata": {
      "ExecuteTime": {
@@ -572,7 +572,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 26,
+    "execution_count": null,
     "id": "9f1a1dbd",
     "metadata": {},
     "outputs": [],
@@ -758,7 +758,7 @@
     "from cognee.infrastructure.databases.vector import get_vector_engine\n",
     "\n",
     "vector_engine = get_vector_engine()\n",
-    "results = await search(vector_engine, \"entities\", \"sarah.nguyen@example.com\")\n",
+    "results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n",
     "for result in results:\n",
     "    print(result)"
    ]
@@ -788,8 +788,8 @@
    "source": [
     "from cognee.api.v1.search import SearchType\n",
     "\n",
-    "node = (await vector_engine.search(\"entities\", \"sarah.nguyen@example.com\"))[0]\n",
-    "node_name = node.payload[\"name\"]\n",
+    "node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
+    "node_name = node.payload[\"text\"]\n",
     "\n",
     "search_results = await cognee.search(SearchType.SUMMARIES, query = node_name)\n",
     "print(\"\\n\\Extracted summaries are:\\n\")\n",
@@ -881,7 +881,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.6"
+   "version": "3.11.8"
   }
  },
 "nbformat": 4,

poetry.lock (generated): 81 changes
View file

@ -597,17 +597,17 @@ css = ["tinycss2 (>=1.1.0,<1.5)"]
[[package]] [[package]]
name = "boto3" name = "boto3"
version = "1.35.55" version = "1.35.57"
description = "The AWS SDK for Python" description = "The AWS SDK for Python"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "boto3-1.35.55-py3-none-any.whl", hash = "sha256:c7a0a0bc5ae3bed5d38e8bfe5a56b31621e79bdd7c1ea6e5ba4326d820cde3a5"}, {file = "boto3-1.35.57-py3-none-any.whl", hash = "sha256:9edf49640c79a05b0a72f4c2d1e24dfc164344b680535a645f455ac624dc3680"},
{file = "boto3-1.35.55.tar.gz", hash = "sha256:82fa8cdb00731aeffe7a5829821ae78d75c7ae959b638c15ff3b4681192ace90"}, {file = "boto3-1.35.57.tar.gz", hash = "sha256:db58348849a5af061f0f5ec9c3b699da5221ca83354059fdccb798e3ddb6b62a"},
] ]
[package.dependencies] [package.dependencies]
botocore = ">=1.35.55,<1.36.0" botocore = ">=1.35.57,<1.36.0"
jmespath = ">=0.7.1,<2.0.0" jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.10.0,<0.11.0" s3transfer = ">=0.10.0,<0.11.0"
@ -616,13 +616,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]] [[package]]
name = "botocore" name = "botocore"
version = "1.35.55" version = "1.35.57"
description = "Low-level, data-driven core of boto 3." description = "Low-level, data-driven core of boto 3."
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "botocore-1.35.55-py3-none-any.whl", hash = "sha256:3d54739e498534c9d7a6e9732ae2d17ed29c7d5e29fe36c956d8488b859538b0"}, {file = "botocore-1.35.57-py3-none-any.whl", hash = "sha256:92ddd02469213766872cb2399269dd20948f90348b42bf08379881d5e946cc34"},
{file = "botocore-1.35.55.tar.gz", hash = "sha256:61ae18f688250372d7b6046e35c86f8fd09a7c0f0064b52688f3490b4d6c9d6b"}, {file = "botocore-1.35.57.tar.gz", hash = "sha256:d96306558085baf0bcb3b022d7a8c39c93494f031edb376694d2b2dcd0e81327"},
] ]
[package.dependencies] [package.dependencies]
@ -2606,22 +2606,22 @@ colors = ["colorama (>=0.4.6)"]
[[package]] [[package]]
name = "jedi" name = "jedi"
version = "0.19.1" version = "0.19.2"
description = "An autocompletion tool for Python that can be used for text editors." description = "An autocompletion tool for Python that can be used for text editors."
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
files = [ files = [
{file = "jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0"}, {file = "jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9"},
{file = "jedi-0.19.1.tar.gz", hash = "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd"}, {file = "jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0"},
] ]
[package.dependencies] [package.dependencies]
parso = ">=0.8.3,<0.9.0" parso = ">=0.8.4,<0.9.0"
[package.extras] [package.extras]
docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"] docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"]
qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] testing = ["Django", "attrs", "colorama", "docopt", "pytest (<9.0.0)"]
[[package]] [[package]]
name = "jinja2" name = "jinja2"
@@ -2755,15 +2755,18 @@ files = [

 [[package]]
 name = "json5"
-version = "0.9.25"
+version = "0.9.27"
 description = "A Python implementation of the JSON5 data format."
 optional = false
-python-versions = ">=3.8"
+python-versions = ">=3.8.0"
 files = [
-    {file = "json5-0.9.25-py3-none-any.whl", hash = "sha256:34ed7d834b1341a86987ed52f3f76cd8ee184394906b6e22a1e0deb9ab294e8f"},
-    {file = "json5-0.9.25.tar.gz", hash = "sha256:548e41b9be043f9426776f05df8635a00fe06104ea51ed24b67f908856e151ae"},
+    {file = "json5-0.9.27-py3-none-any.whl", hash = "sha256:17b43d78d3a6daeca4d7030e9bf22092dba29b1282cc2d0cfa56f6febee8dc93"},
+    {file = "json5-0.9.27.tar.gz", hash = "sha256:5a19de4a6ca24ba664dc7d50307eb73ba9a16dea5d6bde85677ae85d3ed2d8e0"},
 ]

+[package.extras]
+dev = ["build (==1.2.1)", "coverage (==7.5.3)", "mypy (==1.10.0)", "pip (==24.1)", "pylint (==3.2.3)", "ruff (==0.5.1)", "twine (==5.1.1)", "uv (==0.2.13)"]
+
 [[package]]
 name = "jsonpatch"
 version = "1.33"
@@ -4360,13 +4363,13 @@ files = [

 [[package]]
 name = "packaging"
-version = "24.1"
+version = "24.2"
 description = "Core utilities for Python packages"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"},
-    {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"},
+    {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"},
+    {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"},
 ]

 [[package]]
@@ -4432,8 +4435,8 @@ files = [
 [package.dependencies]
 numpy = [
     {version = ">=1.20.3", markers = "python_version < \"3.10\""},
-    {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
     {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
+    {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
 ]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
@ -5007,7 +5010,7 @@ test = ["pytest", "pytest-xdist", "setuptools"]
name = "psycopg2" name = "psycopg2"
version = "2.9.10" version = "2.9.10"
description = "psycopg2 - Python-PostgreSQL Database Adapter" description = "psycopg2 - Python-PostgreSQL Database Adapter"
optional = true optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "psycopg2-2.9.10-cp310-cp310-win32.whl", hash = "sha256:5df2b672140f95adb453af93a7d669d7a7bf0a56bcd26f1502329166f4a61716"}, {file = "psycopg2-2.9.10-cp310-cp310-win32.whl", hash = "sha256:5df2b672140f95adb453af93a7d669d7a7bf0a56bcd26f1502329166f4a61716"},
@ -7032,13 +7035,13 @@ test = ["vcrpy (>=1.10.3)"]
[[package]] [[package]]
name = "typer" name = "typer"
version = "0.12.5" version = "0.13.0"
description = "Typer, build great CLIs. Easy to code. Based on Python type hints." description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "typer-0.12.5-py3-none-any.whl", hash = "sha256:62fe4e471711b147e3365034133904df3e235698399bc4de2b36c8579298d52b"}, {file = "typer-0.13.0-py3-none-any.whl", hash = "sha256:d85fe0b777b2517cc99c8055ed735452f2659cd45e451507c76f48ce5c1d00e2"},
{file = "typer-0.12.5.tar.gz", hash = "sha256:f592f089bedcc8ec1b974125d64851029c3b1af145f04aca64d69410f0c9b722"}, {file = "typer-0.13.0.tar.gz", hash = "sha256:f1c7198347939361eec90139ffa0fd8b3df3a2259d5852a0f7400e476d95985c"},
] ]
[package.dependencies] [package.dependencies]
@ -7333,19 +7336,15 @@ validators = "0.33.0"
[[package]] [[package]]
name = "webcolors" name = "webcolors"
version = "24.8.0" version = "24.11.1"
description = "A library for working with the color formats defined by HTML and CSS." description = "A library for working with the color formats defined by HTML and CSS."
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.9"
files = [ files = [
{file = "webcolors-24.8.0-py3-none-any.whl", hash = "sha256:fc4c3b59358ada164552084a8ebee637c221e4059267d0f8325b3b560f6c7f0a"}, {file = "webcolors-24.11.1-py3-none-any.whl", hash = "sha256:515291393b4cdf0eb19c155749a096f779f7d909f7cceea072791cb9095b92e9"},
{file = "webcolors-24.8.0.tar.gz", hash = "sha256:08b07af286a01bcd30d583a7acadf629583d1f79bfef27dd2c2c5c263817277d"}, {file = "webcolors-24.11.1.tar.gz", hash = "sha256:ecb3d768f32202af770477b8b65f318fa4f566c22948673a977b00d589dd80f6"},
] ]
[package.extras]
docs = ["furo", "sphinx", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-notfound-page", "sphinxext-opengraph"]
tests = ["coverage[toml]"]
[[package]] [[package]]
name = "webencodings" name = "webencodings"
version = "0.5.1" version = "0.5.1"
@ -7375,13 +7374,13 @@ test = ["websockets"]
[[package]] [[package]]
name = "wheel" name = "wheel"
version = "0.44.0" version = "0.45.0"
description = "A built-package format for Python" description = "A built-package format for Python"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "wheel-0.44.0-py3-none-any.whl", hash = "sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f"}, {file = "wheel-0.45.0-py3-none-any.whl", hash = "sha256:52f0baa5e6522155090a09c6bd95718cc46956d1b51d537ea5454249edb671c7"},
{file = "wheel-0.44.0.tar.gz", hash = "sha256:a29c3f2817e95ab89aa4660681ad547c0e9547f20e75b0562fe7723c9a2a9d49"}, {file = "wheel-0.45.0.tar.gz", hash = "sha256:a57353941a3183b3d5365346b567a260a0602a0f8a635926a7dede41b94c674a"},
] ]
[package.extras] [package.extras]
@ -7718,13 +7717,13 @@ propcache = ">=0.2.0"
[[package]] [[package]]
name = "zipp" name = "zipp"
version = "3.20.2" version = "3.21.0"
description = "Backport of pathlib-compatible object wrapper for zip files" description = "Backport of pathlib-compatible object wrapper for zip files"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.9"
files = [ files = [
{file = "zipp-3.20.2-py3-none-any.whl", hash = "sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350"}, {file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"},
{file = "zipp-3.20.2.tar.gz", hash = "sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29"}, {file = "zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"},
] ]
[package.extras] [package.extras]
@@ -7740,11 +7739,11 @@ cli = []
 filesystem = ["botocore"]
 neo4j = ["neo4j"]
 notebook = []
-postgres = ["psycopg2"]
+postgres = ["asyncpg", "pgvector", "psycopg2"]
 qdrant = ["qdrant-client"]
 weaviate = ["weaviate-client"]

 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9.0,<3.12"
-content-hash = "426fa990f2bdd15fa5be55392beb4cf77dba320f2e95cc503d1c0549d9758d64"
+content-hash = "fb09733ff7a70fb91c5f72ff0c8a8137b857557930a7aa025aad3154de4d8ceb"

View file

@ -42,7 +42,7 @@ aiosqlite = "^0.20.0"
pandas = "2.0.3" pandas = "2.0.3"
filetype = "^1.2.0" filetype = "^1.2.0"
nltk = "^3.8.1" nltk = "^3.8.1"
dlt = {extras = ["sqlalchemy"], version = "^1.2.0"} dlt = {extras = ["sqlalchemy"], version = "^1.3.0"}
aiofiles = "^23.2.1" aiofiles = "^23.2.1"
qdrant-client = "^1.9.0" qdrant-client = "^1.9.0"
graphistry = "^0.33.5" graphistry = "^0.33.5"
@ -66,10 +66,10 @@ pydantic-settings = "^2.2.1"
anthropic = "^0.26.1" anthropic = "^0.26.1"
sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"} sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"}
fastapi-users = {version = "*", extras = ["sqlalchemy"]} fastapi-users = {version = "*", extras = ["sqlalchemy"]}
asyncpg = "^0.29.0"
alembic = "^1.13.3" alembic = "^1.13.3"
asyncpg = "^0.29.0"
pgvector = "^0.3.5" pgvector = "^0.3.5"
psycopg2 = {version = "^2.9.10", optional = true} psycopg2 = "^2.9.10"
[tool.poetry.extras] [tool.poetry.extras]
filesystem = ["s3fs", "botocore"] filesystem = ["s3fs", "botocore"]
@ -77,9 +77,10 @@ cli = ["pipdeptree", "cron-descriptor"]
weaviate = ["weaviate-client"] weaviate = ["weaviate-client"]
qdrant = ["qdrant-client"] qdrant = ["qdrant-client"]
neo4j = ["neo4j"] neo4j = ["neo4j"]
postgres = ["psycopg2"] postgres = ["psycopg2", "pgvector", "asyncpg"]
notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"] notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
pytest = "^7.4.0" pytest = "^7.4.0"
pytest-asyncio = "^0.21.1" pytest-asyncio = "^0.21.1"