Merge remote-tracking branch 'origin/main' into feat/COG-553-graph-memory-projection

hajdul88 2024-11-11 18:56:17 +01:00
commit 3e7df33c15
29 changed files with 276 additions and 211 deletions

View file

@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@ -22,7 +22,7 @@ jobs:
run_neo4j_integration_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:

View file

@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]
concurrency:
@ -23,7 +23,7 @@ jobs:
run_notebook_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:
run:

View file

@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]
concurrency:
@ -23,7 +23,7 @@ jobs:
run_pgvector_integration_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:
run:

View file

@ -1,10 +1,11 @@
name: test | python 3.10
on:
workflow_dispatch:
pull_request:
branches:
- main
workflow_dispatch:
types: [labeled, synchronize]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@ -21,7 +22,7 @@ jobs:
run_common:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false

View file

@ -1,10 +1,11 @@
name: test | python 3.11
on:
workflow_dispatch:
pull_request:
branches:
- main
workflow_dispatch:
types: [labeled, synchronize]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@ -21,7 +22,7 @@ jobs:
run_common:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false

View file

@ -1,10 +1,11 @@
name: test | python 3.9
on:
workflow_dispatch:
pull_request:
branches:
- main
workflow_dispatch:
types: [labeled, synchronize]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@ -21,7 +22,7 @@ jobs:
run_common:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false

View file

@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]
concurrency:
@ -23,7 +23,7 @@ jobs:
run_qdrant_integration_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:

View file

@ -5,7 +5,7 @@ on:
pull_request:
branches:
- main
types: [labeled]
types: [labeled, synchronize]
concurrency:
@ -23,7 +23,7 @@ jobs:
run_weaviate_integration_test:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' | ${{ github.event.label.name == 'run-checks' }}
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
defaults:

.gitignore vendored (1 line changed)
View file

@ -177,5 +177,6 @@ cognee/cache/
# Default cognee system directory, used in development
.cognee_system/
.data_storage/
.anon_id
node_modules/

View file

@ -109,24 +109,34 @@ import asyncio
from cognee.api.v1.search import SearchType
async def main():
await cognee.prune.prune_data() # Reset cognee data
await cognee.prune.prune_system(metadata=True) # Reset cognee system state
# Reset cognee data
await cognee.prune.prune_data()
# Reset cognee system state
await cognee.prune.prune_system(metadata=True)
text = """
Natural language processing (NLP) is an interdisciplinary
subfield of computer science and information retrieval.
"""
await cognee.add(text) # Add text to cognee
await cognee.cognify() # Use LLMs and cognee to create knowledge graph
# Add text to cognee
await cognee.add(text)
search_results = await cognee.search( # Search cognee for insights
# Use LLMs and cognee to create knowledge graph
await cognee.cognify()
# Search cognee for insights
search_results = await cognee.search(
SearchType.INSIGHTS,
{'query': 'Tell me about NLP'}
"Tell me about NLP",
)
for result_text in search_results: # Display results
# Display results
for result_text in search_results:
print(result_text)
# natural_language_processing is_a field
# natural_language_processing is_subfield_of computer_science
# natural_language_processing is_subfield_of information_retrieval
asyncio.run(main())
```

View file

@ -1,6 +1,7 @@
""" Neo4j Adapter for Graph Database"""
import logging
import asyncio
from textwrap import dedent
from typing import Optional, Any, List, Dict
from contextlib import asynccontextmanager
from uuid import UUID
@ -18,7 +19,7 @@ class Neo4jAdapter(GraphDBInterface):
graph_database_url: str,
graph_database_username: str,
graph_database_password: str,
driver: Optional[Any] = None
driver: Optional[Any] = None,
):
self.driver = driver or AsyncGraphDatabase.driver(
graph_database_url,
@ -26,6 +27,9 @@ class Neo4jAdapter(GraphDBInterface):
max_connection_lifetime = 120
)
async def close(self) -> None:
await self.driver.close()
@asynccontextmanager
async def get_session(self) -> AsyncSession:
async with self.driver.session() as session:
@ -59,11 +63,10 @@ class Neo4jAdapter(GraphDBInterface):
async def add_node(self, node: DataPoint):
serialized_properties = self.serialize_properties(node.model_dump())
query = """MERGE (node {id: $node_id})
ON CREATE SET node += $properties
ON MATCH SET node += $properties
ON MATCH SET node.updated_at = timestamp()
RETURN ID(node) AS internal_id, node.id AS nodeId"""
query = dedent("""MERGE (node {id: $node_id})
ON CREATE SET node += $properties, node.updated_at = timestamp()
ON MATCH SET node += $properties, node.updated_at = timestamp()
RETURN ID(node) AS internal_id, node.id AS nodeId""")
params = {
"node_id": str(node.id),
@ -76,9 +79,8 @@ class Neo4jAdapter(GraphDBInterface):
query = """
UNWIND $nodes AS node
MERGE (n {id: node.node_id})
ON CREATE SET n += node.properties
ON MATCH SET n += node.properties
ON MATCH SET n.updated_at = timestamp()
ON CREATE SET n += node.properties, n.updated_at = timestamp()
ON MATCH SET n += node.properties, n.updated_at = timestamp()
WITH n, node.node_id AS label
CALL apoc.create.addLabels(n, [label]) YIELD node AS labeledNode
RETURN ID(labeledNode) AS internal_id, labeledNode.id AS nodeId
@ -133,12 +135,19 @@ class Neo4jAdapter(GraphDBInterface):
return await self.query(query, params)
async def has_edge(self, from_node: UUID, to_node: UUID, edge_label: str) -> bool:
query = f"""
MATCH (from_node:`{str(from_node)}`)-[relationship:`{edge_label}`]->(to_node:`{str(to_node)}`)
query = """
MATCH (from_node)-[relationship]->(to_node)
WHERE from_node.id = $from_node_id AND to_node.id = $to_node_id AND type(relationship) = $edge_label
RETURN COUNT(relationship) > 0 AS edge_exists
"""
edge_exists = await self.query(query)
params = {
"from_node_id": str(from_node),
"to_node_id": str(to_node),
"edge_label": edge_label,
}
edge_exists = await self.query(query, params)
return edge_exists
async def has_edges(self, edges):
@ -165,22 +174,21 @@ class Neo4jAdapter(GraphDBInterface):
raise error
async def add_edge(self, from_node: str, to_node: str, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
async def add_edge(self, from_node: UUID, to_node: UUID, relationship_name: str, edge_properties: Optional[Dict[str, Any]] = {}):
serialized_properties = self.serialize_properties(edge_properties)
from_node = from_node.replace(":", "_")
to_node = to_node.replace(":", "_")
query = f"""MATCH (from_node:`{str(from_node)}`
{{id: $from_node}}),
(to_node:`{str(to_node)}` {{id: $to_node}})
MERGE (from_node)-[r:`{relationship_name}`]->(to_node)
ON CREATE SET r += $properties, r.updated_at = timestamp()
ON MATCH SET r += $properties, r.updated_at = timestamp()
RETURN r"""
query = dedent("""MATCH (from_node {id: $from_node}),
(to_node {id: $to_node})
MERGE (from_node)-[r]->(to_node)
ON CREATE SET r += $properties, r.updated_at = timestamp(), r.type = $relationship_name
ON MATCH SET r += $properties, r.updated_at = timestamp()
RETURN r
""")
params = {
"from_node": str(from_node),
"to_node": str(to_node),
"relationship_name": relationship_name,
"properties": serialized_properties
}
@ -347,8 +355,8 @@ class Neo4jAdapter(GraphDBInterface):
"""
predecessors, successors = await asyncio.gather(
self.query(predecessors_query, dict(node_id = node_id)),
self.query(successors_query, dict(node_id = node_id)),
self.query(predecessors_query, dict(node_id = str(node_id))),
self.query(successors_query, dict(node_id = str(node_id))),
)
connections = []
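
For orientation, a minimal usage sketch of the updated `add_edge` signature, assuming an already-initialised `Neo4jAdapter` instance named `adapter`; the ids, relationship name, and properties below are hypothetical. The change replaces label interpolation with UUID parameters passed to the query.

```python
from uuid import UUID

async def link_nodes(adapter):
    # Hypothetical node ids; add_edge now accepts UUIDs and forwards them as
    # Cypher query parameters rather than interpolating them as node labels.
    source_id = UUID("11111111-1111-1111-1111-111111111111")
    target_id = UUID("22222222-2222-2222-2222-222222222222")
    await adapter.add_edge(source_id, target_id, "is_subfield_of", {"source": "example"})
```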

View file

@ -270,7 +270,7 @@ class NetworkXAdapter(GraphDBInterface):
except:
pass
if "updated_at" in node:
if "updated_at" in edge:
edge["updated_at"] = datetime.strptime(edge["updated_at"], "%Y-%m-%dT%H:%M:%S.%f%z")
self.graph = nx.readwrite.json_graph.node_link_graph(graph_data)

View file

@ -112,18 +112,10 @@ class LanceDBAdapter(VectorDBInterface):
for (data_point_index, data_point) in enumerate(data_points)
]
# TODO: This enables us to work with pydantic version but shouldn't
# stay like this, existing rows should be updated
await collection.delete("id IS NOT NULL")
original_size = await collection.count_rows()
await collection.add(lance_data_points)
new_size = await collection.count_rows()
if new_size <= original_size:
raise ValueError(
"LanceDB create_datapoints error: data points did not get added.")
await collection.merge_insert("id") \
.when_matched_update_all() \
.when_not_matched_insert_all() \
.execute(lance_data_points)
async def retrieve(self, collection_name: str, data_point_ids: list[str]):

View file

@ -54,7 +54,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
vector_size = self.embedding_engine.get_vector_size()
if not await self.has_collection(collection_name):
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
@ -80,44 +79,44 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def create_data_points(
self, collection_name: str, data_points: List[DataPoint]
):
async with self.get_async_session() as session:
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name=collection_name,
payload_schema=type(data_points[0]),
)
data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name = collection_name,
payload_schema = type(data_points[0]),
)
vector_size = self.embedding_engine.get_vector_size()
data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
)
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True
)
id: Mapped[type(data_points[0].id)]
payload = Column(JSON)
vector = Column(Vector(vector_size))
vector_size = self.embedding_engine.get_vector_size()
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True
)
id: Mapped[type(data_points[0].id)]
payload = Column(JSON)
vector = Column(Vector(vector_size))
pgvector_data_points = [
PGVectorDataPoint(
id=data_point.id,
vector=data_vectors[data_index],
payload=serialize_data(data_point.model_dump()),
)
for (data_index, data_point) in enumerate(data_points)
]
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
pgvector_data_points = [
PGVectorDataPoint(
id = data_point.id,
vector = data_vectors[data_index],
payload = serialize_data(data_point.model_dump()),
)
for (data_index, data_point) in enumerate(data_points)
]
async with self.get_async_session() as session:
session.add_all(pgvector_data_points)
await session.commit()
@ -128,7 +127,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema(
id = data_point.id,
text = getattr(data_point, data_point._metadata["index_fields"][0]),
text = data_point.get_embeddable_data(),
) for data_point in data_points
])
@ -146,10 +145,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
raise ValueError(f"Table '{collection_name}' not found.")
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
async with self.get_async_session() as session:
results = await session.execute(
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
@ -177,11 +176,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
closest_items = []
# Use async session to connect to the database
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
# Find closest vectors to query_vector
closest_items = await session.execute(
select(
@ -194,20 +195,21 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
.limit(limit)
)
vector_list = []
# Extract distances and find min/max for normalization
for vector in closest_items:
# TODO: Add normalization of similarity score
vector_list.append(vector)
vector_list = []
# Create and return ScoredResult objects
return [
ScoredResult(
id = UUID(str(row.id)),
payload = row.payload,
score = row.similarity
) for row in vector_list
]
# Extract distances and find min/max for normalization
for vector in closest_items:
# TODO: Add normalization of similarity score
vector_list.append(vector)
# Create and return ScoredResult objects
return [
ScoredResult(
id = UUID(str(row.id)),
payload = row.payload,
score = row.similarity
) for row in vector_list
]
async def batch_search(
self,

View file

@ -1,7 +1,9 @@
import logging
from uuid import UUID
from typing import List, Dict, Optional
from qdrant_client import AsyncQdrantClient, models
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..embeddings.EmbeddingEngine import EmbeddingEngine
@ -153,7 +155,7 @@ class QDrantAdapter(VectorDBInterface):
client = self.get_qdrant_client()
result = await client.search(
results = await client.search(
collection_name = collection_name,
query_vector = models.NamedVector(
name = "text",
@ -165,7 +167,16 @@ class QDrantAdapter(VectorDBInterface):
await client.close()
return result
return [
ScoredResult(
id = UUID(result.id),
payload = {
**result.payload,
"id": UUID(result.id),
},
score = 1 - result.score,
) for result in results
]
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int = None, with_vectors: bool = False):

View file

@ -11,7 +11,6 @@ from ..embeddings.EmbeddingEngine import EmbeddingEngine
logger = logging.getLogger("WeaviateAdapter")
class IndexSchema(DataPoint):
uuid: str
text: str
_metadata: dict = {
@ -58,18 +57,21 @@ class WeaviateAdapter(VectorDBInterface):
future = asyncio.Future()
future.set_result(
self.client.collections.create(
name=collection_name,
properties=[
wvcc.Property(
name="text",
data_type=wvcc.DataType.TEXT,
skip_vectorization=True
)
]
if not self.client.collections.exists(collection_name):
future.set_result(
self.client.collections.create(
name = collection_name,
properties = [
wvcc.Property(
name = "text",
data_type = wvcc.DataType.TEXT,
skip_vectorization = True
)
]
)
)
)
else:
future.set_result(self.get_collection(collection_name))
return await future
@ -80,13 +82,16 @@ class WeaviateAdapter(VectorDBInterface):
from weaviate.classes.data import DataObject
data_vectors = await self.embed_data(
list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
[data_point.get_embeddable_data() for data_point in data_points]
)
def convert_to_weaviate_data_points(data_point: DataPoint):
vector = data_vectors[data_points.index(data_point)]
properties = data_point.model_dump()
properties["uuid"] = properties["id"]
del properties["id"]
if "id" in properties:
properties["uuid"] = str(data_point.id)
del properties["id"]
return DataObject(
uuid = data_point.id,
@ -94,22 +99,28 @@ class WeaviateAdapter(VectorDBInterface):
vector = vector
)
data_points = list(map(convert_to_weaviate_data_points, data_points))
data_points = [convert_to_weaviate_data_points(data_point) for data_point in data_points]
collection = self.get_collection(collection_name)
try:
if len(data_points) > 1:
return collection.data.insert_many(data_points)
with collection.batch.dynamic() as batch:
for data_point in data_points:
batch.add_object(
uuid = data_point.uuid,
vector = data_point.vector,
properties = data_point.properties,
references = data_point.references,
)
else:
return collection.data.insert(data_points[0])
# with collection.batch.dynamic() as batch:
# for point in data_points:
# batch.add_object(
# uuid = point.uuid,
# properties = point.properties,
# vector = point.vector
# )
data_point: DataObject = data_points[0]
return collection.data.update(
uuid = data_point.uuid,
vector = data_point.vector,
properties = data_point.properties,
references = data_point.references,
)
except Exception as error:
logger.error("Error creating data points: %s", str(error))
raise error
@ -120,8 +131,8 @@ class WeaviateAdapter(VectorDBInterface):
async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema(
uuid = str(data_point.id),
text = getattr(data_point, data_point._metadata["index_fields"][0]),
id = data_point.id,
text = data_point.get_embeddable_data(),
) for data_point in data_points
])
@ -168,9 +179,9 @@ class WeaviateAdapter(VectorDBInterface):
return [
ScoredResult(
id = UUID(result.uuid),
id = UUID(str(result.uuid)),
payload = result.properties,
score = float(result.metadata.score)
score = 1 - float(result.metadata.score)
) for result in search_result.objects
]

View file

@ -1,2 +1,3 @@
from .generate_node_id import generate_node_id
from .generate_node_name import generate_node_name
from .generate_edge_name import generate_edge_name

View file

@ -0,0 +1,2 @@
def generate_edge_name(name: str) -> str:
return name.lower().replace(" ", "_").replace("'", "")

View file

@ -1,2 +1,2 @@
def generate_node_name(name: str) -> str:
return name.lower().replace(" ", "_").replace("'", "")
return name.lower().replace("'", "")
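
As a quick illustration of the two helpers after this change (inputs are made up): `generate_edge_name` still lowercases and snake_cases the name, while `generate_node_name` now only lowercases and strips apostrophes, keeping spaces.

```python
from cognee.modules.engine.utils import generate_edge_name, generate_node_name

# Made-up inputs showing the behaviour after this commit.
print(generate_edge_name("Is Subfield Of"))                 # is_subfield_of
print(generate_node_name("Natural Language Processing"))    # natural language processing
```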

View file

@ -9,9 +9,29 @@ import matplotlib.pyplot as plt
import tiktoken
import nltk
from posthog import Posthog
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.graph import get_graph_engine
from uuid import uuid4
import pathlib
def get_anonymous_id():
"""Creates or reads a anonymous user id"""
home_dir = str(pathlib.Path(pathlib.Path(__file__).parent.parent.parent.resolve()))
if not os.path.isdir(home_dir):
os.makedirs(home_dir, exist_ok=True)
anonymous_id_file = os.path.join(home_dir, ".anon_id")
if not os.path.isfile(anonymous_id_file):
anonymous_id = str(uuid4())
with open(anonymous_id_file, "w", encoding="utf-8") as f:
f.write(anonymous_id)
else:
with open(anonymous_id_file, "r", encoding="utf-8") as f:
anonymous_id = f.read()
return anonymous_id
def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
if os.getenv("TELEMETRY_DISABLED"):
return
@ -28,11 +48,15 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
current_time = datetime.now(timezone.utc)
properties = {
"time": current_time.strftime("%m/%d/%Y"),
"user_id": user_id,
**additional_properties,
}
# Needed to forward properties to PostHog along with id
posthog.identify(get_anonymous_id(), properties)
try:
posthog.capture(user_id, event_name, properties)
posthog.capture(get_anonymous_id(), event_name, properties)
except Exception as e:
print("ERROR sending telemetric data to Posthog. See exception: %s", e)

View file

@ -5,7 +5,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.data.extraction.knowledge_graph import extract_content_graph
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.engine.models import EntityType, Entity
from cognee.modules.engine.utils import generate_node_id, generate_node_name
from cognee.modules.engine.utils import generate_edge_name, generate_node_id, generate_node_name
from cognee.tasks.storage import add_data_points
async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
@ -95,7 +95,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model:
for edge in graph.edges:
source_node_id = generate_node_id(edge.source_node_id)
target_node_id = generate_node_id(edge.target_node_id)
relationship_name = generate_node_name(edge.relationship_name)
relationship_name = generate_edge_name(edge.relationship_name)
edge_key = str(source_node_id) + str(target_node_id) + relationship_name
@ -105,7 +105,7 @@ async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model:
target_node_id,
edge.relationship_name,
dict(
relationship_name = generate_node_name(edge.relationship_name),
relationship_name = generate_edge_name(edge.relationship_name),
source_node_id = source_node_id,
target_node_id = target_node_id,
),

View file

@ -37,7 +37,7 @@ async def query_graph_connections(query: str, exploration_levels = 1) -> list[(s
return []
node_connections_results = await asyncio.gather(
*[graph_engine.get_connections(str(result.id)) for result in relevant_results]
*[graph_engine.get_connections(result.id) for result in relevant_results]
)
node_connections = []

View file

@ -36,7 +36,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)

View file

@ -41,7 +41,7 @@ async def main():
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await cognee.prune.prune_system(metadata = True)
dataset_name = "cs_explanations"
@ -65,7 +65,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query=random_node_name)

View file

@ -37,7 +37,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)

View file

@ -35,7 +35,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query = random_node_name)

View file

@ -265,7 +265,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"id": "df16431d0f48b006",
"metadata": {
"ExecuteTime": {
@ -304,7 +304,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"id": "9086abf3af077ab4",
"metadata": {
"ExecuteTime": {
@ -349,7 +349,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"id": "a9de0cc07f798b7f",
"metadata": {
"ExecuteTime": {
@ -393,7 +393,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"id": "185ff1c102d06111",
"metadata": {
"ExecuteTime": {
@ -437,7 +437,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"id": "d55ce4c58f8efb67",
"metadata": {
"ExecuteTime": {
@ -479,7 +479,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"id": "ca4ecc32721ad332",
"metadata": {
"ExecuteTime": {
@ -572,7 +572,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": null,
"id": "9f1a1dbd",
"metadata": {},
"outputs": [],
@ -758,7 +758,7 @@
"from cognee.infrastructure.databases.vector import get_vector_engine\n",
"\n",
"vector_engine = get_vector_engine()\n",
"results = await search(vector_engine, \"entities\", \"sarah.nguyen@example.com\")\n",
"results = await search(vector_engine, \"Entity_name\", \"sarah.nguyen@example.com\")\n",
"for result in results:\n",
" print(result)"
]
@ -788,8 +788,8 @@
"source": [
"from cognee.api.v1.search import SearchType\n",
"\n",
"node = (await vector_engine.search(\"entities\", \"sarah.nguyen@example.com\"))[0]\n",
"node_name = node.payload[\"name\"]\n",
"node = (await vector_engine.search(\"Entity_name\", \"sarah.nguyen@example.com\"))[0]\n",
"node_name = node.payload[\"text\"]\n",
"\n",
"search_results = await cognee.search(SearchType.SUMMARIES, query = node_name)\n",
"print(\"\\n\\Extracted summaries are:\\n\")\n",
@ -881,7 +881,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
"version": "3.11.8"
}
},
"nbformat": 4,

poetry.lock generated (81 lines changed)
View file

@ -597,17 +597,17 @@ css = ["tinycss2 (>=1.1.0,<1.5)"]
[[package]]
name = "boto3"
version = "1.35.55"
version = "1.35.57"
description = "The AWS SDK for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "boto3-1.35.55-py3-none-any.whl", hash = "sha256:c7a0a0bc5ae3bed5d38e8bfe5a56b31621e79bdd7c1ea6e5ba4326d820cde3a5"},
{file = "boto3-1.35.55.tar.gz", hash = "sha256:82fa8cdb00731aeffe7a5829821ae78d75c7ae959b638c15ff3b4681192ace90"},
{file = "boto3-1.35.57-py3-none-any.whl", hash = "sha256:9edf49640c79a05b0a72f4c2d1e24dfc164344b680535a645f455ac624dc3680"},
{file = "boto3-1.35.57.tar.gz", hash = "sha256:db58348849a5af061f0f5ec9c3b699da5221ca83354059fdccb798e3ddb6b62a"},
]
[package.dependencies]
botocore = ">=1.35.55,<1.36.0"
botocore = ">=1.35.57,<1.36.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.10.0,<0.11.0"
@ -616,13 +616,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
version = "1.35.55"
version = "1.35.57"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">=3.8"
files = [
{file = "botocore-1.35.55-py3-none-any.whl", hash = "sha256:3d54739e498534c9d7a6e9732ae2d17ed29c7d5e29fe36c956d8488b859538b0"},
{file = "botocore-1.35.55.tar.gz", hash = "sha256:61ae18f688250372d7b6046e35c86f8fd09a7c0f0064b52688f3490b4d6c9d6b"},
{file = "botocore-1.35.57-py3-none-any.whl", hash = "sha256:92ddd02469213766872cb2399269dd20948f90348b42bf08379881d5e946cc34"},
{file = "botocore-1.35.57.tar.gz", hash = "sha256:d96306558085baf0bcb3b022d7a8c39c93494f031edb376694d2b2dcd0e81327"},
]
[package.dependencies]
@ -2606,22 +2606,22 @@ colors = ["colorama (>=0.4.6)"]
[[package]]
name = "jedi"
version = "0.19.1"
version = "0.19.2"
description = "An autocompletion tool for Python that can be used for text editors."
optional = false
python-versions = ">=3.6"
files = [
{file = "jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0"},
{file = "jedi-0.19.1.tar.gz", hash = "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd"},
{file = "jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9"},
{file = "jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0"},
]
[package.dependencies]
parso = ">=0.8.3,<0.9.0"
parso = ">=0.8.4,<0.9.0"
[package.extras]
docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"]
qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"]
testing = ["Django", "attrs", "colorama", "docopt", "pytest (<9.0.0)"]
[[package]]
name = "jinja2"
@ -2755,15 +2755,18 @@ files = [
[[package]]
name = "json5"
version = "0.9.25"
version = "0.9.27"
description = "A Python implementation of the JSON5 data format."
optional = false
python-versions = ">=3.8"
python-versions = ">=3.8.0"
files = [
{file = "json5-0.9.25-py3-none-any.whl", hash = "sha256:34ed7d834b1341a86987ed52f3f76cd8ee184394906b6e22a1e0deb9ab294e8f"},
{file = "json5-0.9.25.tar.gz", hash = "sha256:548e41b9be043f9426776f05df8635a00fe06104ea51ed24b67f908856e151ae"},
{file = "json5-0.9.27-py3-none-any.whl", hash = "sha256:17b43d78d3a6daeca4d7030e9bf22092dba29b1282cc2d0cfa56f6febee8dc93"},
{file = "json5-0.9.27.tar.gz", hash = "sha256:5a19de4a6ca24ba664dc7d50307eb73ba9a16dea5d6bde85677ae85d3ed2d8e0"},
]
[package.extras]
dev = ["build (==1.2.1)", "coverage (==7.5.3)", "mypy (==1.10.0)", "pip (==24.1)", "pylint (==3.2.3)", "ruff (==0.5.1)", "twine (==5.1.1)", "uv (==0.2.13)"]
[[package]]
name = "jsonpatch"
version = "1.33"
@ -4360,13 +4363,13 @@ files = [
[[package]]
name = "packaging"
version = "24.1"
version = "24.2"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"},
{file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"},
{file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"},
{file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"},
]
[[package]]
@ -4432,8 +4435,8 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.20.3", markers = "python_version < \"3.10\""},
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@ -5007,7 +5010,7 @@ test = ["pytest", "pytest-xdist", "setuptools"]
name = "psycopg2"
version = "2.9.10"
description = "psycopg2 - Python-PostgreSQL Database Adapter"
optional = true
optional = false
python-versions = ">=3.8"
files = [
{file = "psycopg2-2.9.10-cp310-cp310-win32.whl", hash = "sha256:5df2b672140f95adb453af93a7d669d7a7bf0a56bcd26f1502329166f4a61716"},
@ -7032,13 +7035,13 @@ test = ["vcrpy (>=1.10.3)"]
[[package]]
name = "typer"
version = "0.12.5"
version = "0.13.0"
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
optional = false
python-versions = ">=3.7"
files = [
{file = "typer-0.12.5-py3-none-any.whl", hash = "sha256:62fe4e471711b147e3365034133904df3e235698399bc4de2b36c8579298d52b"},
{file = "typer-0.12.5.tar.gz", hash = "sha256:f592f089bedcc8ec1b974125d64851029c3b1af145f04aca64d69410f0c9b722"},
{file = "typer-0.13.0-py3-none-any.whl", hash = "sha256:d85fe0b777b2517cc99c8055ed735452f2659cd45e451507c76f48ce5c1d00e2"},
{file = "typer-0.13.0.tar.gz", hash = "sha256:f1c7198347939361eec90139ffa0fd8b3df3a2259d5852a0f7400e476d95985c"},
]
[package.dependencies]
@ -7333,19 +7336,15 @@ validators = "0.33.0"
[[package]]
name = "webcolors"
version = "24.8.0"
version = "24.11.1"
description = "A library for working with the color formats defined by HTML and CSS."
optional = false
python-versions = ">=3.8"
python-versions = ">=3.9"
files = [
{file = "webcolors-24.8.0-py3-none-any.whl", hash = "sha256:fc4c3b59358ada164552084a8ebee637c221e4059267d0f8325b3b560f6c7f0a"},
{file = "webcolors-24.8.0.tar.gz", hash = "sha256:08b07af286a01bcd30d583a7acadf629583d1f79bfef27dd2c2c5c263817277d"},
{file = "webcolors-24.11.1-py3-none-any.whl", hash = "sha256:515291393b4cdf0eb19c155749a096f779f7d909f7cceea072791cb9095b92e9"},
{file = "webcolors-24.11.1.tar.gz", hash = "sha256:ecb3d768f32202af770477b8b65f318fa4f566c22948673a977b00d589dd80f6"},
]
[package.extras]
docs = ["furo", "sphinx", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-notfound-page", "sphinxext-opengraph"]
tests = ["coverage[toml]"]
[[package]]
name = "webencodings"
version = "0.5.1"
@ -7375,13 +7374,13 @@ test = ["websockets"]
[[package]]
name = "wheel"
version = "0.44.0"
version = "0.45.0"
description = "A built-package format for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "wheel-0.44.0-py3-none-any.whl", hash = "sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f"},
{file = "wheel-0.44.0.tar.gz", hash = "sha256:a29c3f2817e95ab89aa4660681ad547c0e9547f20e75b0562fe7723c9a2a9d49"},
{file = "wheel-0.45.0-py3-none-any.whl", hash = "sha256:52f0baa5e6522155090a09c6bd95718cc46956d1b51d537ea5454249edb671c7"},
{file = "wheel-0.45.0.tar.gz", hash = "sha256:a57353941a3183b3d5365346b567a260a0602a0f8a635926a7dede41b94c674a"},
]
[package.extras]
@ -7718,13 +7717,13 @@ propcache = ">=0.2.0"
[[package]]
name = "zipp"
version = "3.20.2"
version = "3.21.0"
description = "Backport of pathlib-compatible object wrapper for zip files"
optional = false
python-versions = ">=3.8"
python-versions = ">=3.9"
files = [
{file = "zipp-3.20.2-py3-none-any.whl", hash = "sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350"},
{file = "zipp-3.20.2.tar.gz", hash = "sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29"},
{file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"},
{file = "zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"},
]
[package.extras]
@ -7740,11 +7739,11 @@ cli = []
filesystem = ["botocore"]
neo4j = ["neo4j"]
notebook = []
postgres = ["psycopg2"]
postgres = ["asyncpg", "pgvector", "psycopg2"]
qdrant = ["qdrant-client"]
weaviate = ["weaviate-client"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9.0,<3.12"
content-hash = "426fa990f2bdd15fa5be55392beb4cf77dba320f2e95cc503d1c0549d9758d64"
content-hash = "fb09733ff7a70fb91c5f72ff0c8a8137b857557930a7aa025aad3154de4d8ceb"

View file

@ -42,7 +42,7 @@ aiosqlite = "^0.20.0"
pandas = "2.0.3"
filetype = "^1.2.0"
nltk = "^3.8.1"
dlt = {extras = ["sqlalchemy"], version = "^1.2.0"}
dlt = {extras = ["sqlalchemy"], version = "^1.3.0"}
aiofiles = "^23.2.1"
qdrant-client = "^1.9.0"
graphistry = "^0.33.5"
@ -66,10 +66,10 @@ pydantic-settings = "^2.2.1"
anthropic = "^0.26.1"
sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"}
fastapi-users = {version = "*", extras = ["sqlalchemy"]}
asyncpg = "^0.29.0"
alembic = "^1.13.3"
asyncpg = "^0.29.0"
pgvector = "^0.3.5"
psycopg2 = {version = "^2.9.10", optional = true}
psycopg2 = "^2.9.10"
[tool.poetry.extras]
filesystem = ["s3fs", "botocore"]
@ -77,9 +77,10 @@ cli = ["pipdeptree", "cron-descriptor"]
weaviate = ["weaviate-client"]
qdrant = ["qdrant-client"]
neo4j = ["neo4j"]
postgres = ["psycopg2"]
postgres = ["psycopg2", "pgvector", "asyncpg"]
notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"