Merge remote-tracking branch 'origin/main'

Boris Arzentar 2024-12-04 11:16:16 +01:00
commit 4678aaef52
21 changed files with 858 additions and 697 deletions


@@ -14,7 +14,7 @@ GRAPH_DATABASE_URL=
GRAPH_DATABASE_USERNAME=
GRAPH_DATABASE_PASSWORD=
# "qdrant", "pgvector", "weaviate" or "lancedb"
# "qdrant", "pgvector", "weaviate", "milvus" or "lancedb"
VECTOR_DB_PROVIDER="lancedb"
# Not needed if using "lancedb" or "pgvector"
VECTOR_DB_URL=
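For reference, a minimal `.env` selection of the new Milvus option might look like the sketch below; the local `milvus.db` file path is an assumption (Milvus Lite accepts a file path), not a value taken from this diff.

```bash
# Hypothetical .env sketch: pick Milvus as the vector store.
VECTOR_DB_PROVIDER="milvus"
# A local file path runs Milvus Lite; a server URL also works here.
VECTOR_DB_URL="./databases/milvus.db"
```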

.github/workflows/test_milvus.yml (new file, 64 lines)

@@ -0,0 +1,64 @@
name: test | milvus
on:
workflow_dispatch:
pull_request:
branches:
- main
types: [labeled, synchronize]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
env:
RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'
jobs:
get_docs_changes:
name: docs changes
uses: ./.github/workflows/get_docs_changes.yml
run_milvus:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
defaults:
run:
shell: bash
steps:
- name: Check out
uses: actions/checkout@master
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11.x'
- name: Install Poetry
# https://github.com/snok/install-poetry#running-on-windows
uses: snok/install-poetry@v1.3.2
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
- name: Install dependencies
run: poetry install -E milvus --no-interaction
- name: Run default basic pipeline
env:
ENV: 'dev'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: poetry run python ./cognee/tests/test_milvus.py
- name: Clean up disk space
run: |
sudo rm -rf ~/.cache
sudo rm -rf /tmp/*
df -h


@@ -46,7 +46,7 @@ jobs:
installer-parallel: true
- name: Install dependencies
run: poetry install --no-interaction
run: poetry install -E neo4j --no-interaction
- name: Run default Neo4j
env:


@@ -47,7 +47,7 @@ jobs:
installer-parallel: true
- name: Install dependencies
run: poetry install --no-interaction
run: poetry install -E qdrant --no-interaction
- name: Run default Qdrant
env:


@@ -47,7 +47,7 @@ jobs:
installer-parallel: true
- name: Install dependencies
run: poetry install --no-interaction
run: poetry install -E weaviate --no-interaction
- name: Run default Weaviate
env:


@@ -20,30 +20,57 @@ If you have questions, join our <a href="https://discord.gg/NQPKmU5CCg">Discord
## 📦 Installation
You can install Cognee using either **pip** or **poetry**.
Support for various databases and vector stores is available through extras.
### With pip
```bash
pip install cognee
```
### With pip with PostgreSQL support
```bash
pip install 'cognee[postgres]'
```
### With poetry
```bash
poetry add cognee
```
### With poetry with PostgreSQL support
### With pip with specific database support
To install Cognee with support for specific databases, use the command below.
```bash
poetry add cognee -E postgres
pip install 'cognee[<database>]'
```
Replace \<database> with any of the following databases:
- postgres
- weaviate
- qdrant
- neo4j
- milvus
For example, to install Cognee with PostgreSQL and Neo4j support:
```bash
pip install 'cognee[postgres, neo4j]'
```
### With poetry with specific database support
To install Cognee with support for specific databases, use the command below.
```bash
poetry add cognee -E <database>
```
Replace \<database> with any of the following databases:
- postgres
- weaviate
- qdrant
- neo4j
- milvus
For example, to install Cognee with PostgreSQL and Neo4j support:
```bash
poetry add cognee -E postgres -E neo4j
```
## 💻 Basic Usage
@@ -317,12 +344,13 @@ pip install cognee
}
</style>
| Name | Type | Current state | Known Issues |
|------------------|--------------------|-------------------|---------------------------------------|
| Qdrant | Vector | Stable &#x2705; | |
| Weaviate | Vector | Stable &#x2705; | |
| LanceDB | Vector | Stable &#x2705; | |
| Neo4j | Graph | Stable &#x2705; | |
| NetworkX | Graph | Stable &#x2705; | |
| FalkorDB | Vector/Graph | Unstable &#x274C; | |
| PGVector | Vector | Unstable &#x274C; | Postgres DB returns the Timeout error |
| Name | Type | Current state | Known Issues |
|----------|--------------------|-------------------|--------------|
| Qdrant | Vector | Stable &#x2705; | |
| Weaviate | Vector | Stable &#x2705; | |
| LanceDB | Vector | Stable &#x2705; | |
| Neo4j | Graph | Stable &#x2705; | |
| NetworkX | Graph | Stable &#x2705; | |
| FalkorDB | Vector/Graph | Unstable &#x274C; | |
| PGVector | Vector | Stable &#x2705; | |
| Milvus | Vector | Stable &#x2705; | |


@@ -1,11 +1,13 @@
from typing import Dict
class VectorConfig(Dict):
vector_db_url: str
vector_db_port: str
vector_db_key: str
vector_db_provider: str
def create_vector_engine(config: VectorConfig, embedding_engine):
if config["vector_db_provider"] == "weaviate":
from .weaviate_db import WeaviateAdapter
@@ -16,24 +18,37 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
return WeaviateAdapter(
config["vector_db_url"],
config["vector_db_key"],
embedding_engine = embedding_engine
embedding_engine=embedding_engine
)
elif config["vector_db_provider"] == "qdrant":
if not (config["vector_db_url"] and config["vector_db_key"]):
raise EnvironmentError("Missing requred Qdrant credentials!")
from .qdrant.QDrantAdapter import QDrantAdapter
return QDrantAdapter(
url = config["vector_db_url"],
api_key = config["vector_db_key"],
embedding_engine = embedding_engine
url=config["vector_db_url"],
api_key=config["vector_db_key"],
embedding_engine=embedding_engine
)
elif config['vector_db_provider'] == 'milvus':
from .milvus.MilvusAdapter import MilvusAdapter
if not config["vector_db_url"]:
raise EnvironmentError("Missing required Milvus credentials!")
return MilvusAdapter(
url=config["vector_db_url"],
api_key=config['vector_db_key'],
embedding_engine=embedding_engine
)
elif config["vector_db_provider"] == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config
# Get configuration for postgres database
relational_config = get_relational_config()
db_username = relational_config.db_username
@@ -52,8 +67,8 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
from .pgvector.PGVectorAdapter import PGVectorAdapter
return PGVectorAdapter(
connection_string,
config["vector_db_key"],
connection_string,
config["vector_db_key"],
embedding_engine,
)
@@ -64,16 +79,16 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
return FalkorDBAdapter(
database_url = config["vector_db_url"],
database_port = config["vector_db_port"],
embedding_engine = embedding_engine,
database_url=config["vector_db_url"],
database_port=config["vector_db_port"],
embedding_engine=embedding_engine,
)
else:
from .lancedb.LanceDBAdapter import LanceDBAdapter
return LanceDBAdapter(
url = config["vector_db_url"],
api_key = config["vector_db_key"],
embedding_engine = embedding_engine,
url=config["vector_db_url"],
api_key=config["vector_db_key"],
embedding_engine=embedding_engine,
)
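To make the new Milvus branch concrete, here is a hedged sketch of how this factory could be called to obtain a Milvus-backed engine. The import path is inferred from the file's relative imports and the embedding engine is left as a caller-supplied parameter; only the config keys and the `create_vector_engine(config, embedding_engine)` signature come from the code above.

```python
# Hypothetical usage sketch, not part of this commit.
# Import path assumed from the file's relative imports.
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine


def build_milvus_engine(embedding_engine):
    """Return a vector engine backed by a local Milvus Lite database file."""
    config = {
        "vector_db_provider": "milvus",
        "vector_db_url": "./databases/milvus.db",  # the milvus branch requires a non-empty URL
        "vector_db_key": "",                       # empty key -> MilvusAdapter connects without a token
        "vector_db_port": "",
    }
    return create_vector_engine(config, embedding_engine)
```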


@@ -0,0 +1,252 @@
from __future__ import annotations
import asyncio
import logging
from typing import List, Optional
from uuid import UUID
from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine
logger = logging.getLogger("MilvusAdapter")
class IndexSchema(DataPoint):
text: str
_metadata: dict = {
"index_fields": ["text"]
}
class MilvusAdapter(VectorDBInterface):
name = "Milvus"
url: str
api_key: Optional[str]
embedding_engine: EmbeddingEngine = None
def __init__(self, url: str, api_key: Optional[str], embedding_engine: EmbeddingEngine):
self.url = url
self.api_key = api_key
self.embedding_engine = embedding_engine
def get_milvus_client(self) -> "MilvusClient":
from pymilvus import MilvusClient
if self.api_key:
client = MilvusClient(uri=self.url, token=self.api_key)
else:
client = MilvusClient(uri=self.url)
return client
async def embed_data(self, data: List[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data)
async def has_collection(self, collection_name: str) -> bool:
future = asyncio.Future()
client = self.get_milvus_client()
future.set_result(client.has_collection(collection_name=collection_name))
return await future
async def create_collection(
self,
collection_name: str,
payload_schema=None,
):
from pymilvus import DataType, MilvusException
client = self.get_milvus_client()
if client.has_collection(collection_name=collection_name):
logger.info(f"Collection '{collection_name}' already exists.")
return True
try:
dimension = self.embedding_engine.get_vector_size()
assert dimension > 0, "Embedding dimension must be greater than 0."
schema = client.create_schema(
auto_id=False,
enable_dynamic_field=False,
)
schema.add_field(
field_name="id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=36
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=dimension
)
schema.add_field(
field_name="text",
datatype=DataType.VARCHAR,
max_length=60535
)
index_params = client.prepare_index_params()
index_params.add_index(
field_name="vector",
metric_type="COSINE"
)
client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params
)
client.load_collection(collection_name)
logger.info(f"Collection '{collection_name}' created successfully.")
return True
except MilvusException as e:
logger.error(f"Error creating collection '{collection_name}': {str(e)}")
raise e
async def create_data_points(
self,
collection_name: str,
data_points: List[DataPoint]
):
from pymilvus import MilvusException
client = self.get_milvus_client()
data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
)
insert_data = [
{
"id": str(data_point.id),
"vector": data_vectors[index],
"text": data_point.text,
}
for index, data_point in enumerate(data_points)
]
try:
result = client.insert(
collection_name=collection_name,
data=insert_data
)
logger.info(
f"Inserted {result.get('insert_count', 0)} data points into collection '{collection_name}'."
)
return result
except MilvusException as e:
logger.error(f"Error inserting data points into collection '{collection_name}': {str(e)}")
raise e
async def create_vector_index(self, index_name: str, index_property_name: str):
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(self, index_name: str, index_property_name: str, data_points: List[DataPoint]):
formatted_data_points = [
IndexSchema(
id=data_point.id,
text=getattr(data_point, data_point._metadata["index_fields"][0]),
)
for data_point in data_points
]
collection_name = f"{index_name}_{index_property_name}"
await self.create_data_points(collection_name, formatted_data_points)
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
from pymilvus import MilvusException
client = self.get_milvus_client()
try:
filter_expression = f"""id in [{", ".join(f'"{id}"' for id in data_point_ids)}]"""
results = client.query(
collection_name=collection_name,
expr=filter_expression,
output_fields=["*"],
)
return results
except MilvusException as e:
logger.error(f"Error retrieving data points from collection '{collection_name}': {str(e)}")
raise e
async def search(
self,
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 5,
with_vector: bool = False,
):
from pymilvus import MilvusException
client = self.get_milvus_client()
if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!")
try:
query_vector = query_vector or (await self.embed_data([query_text]))[0]
output_fields = ["id", "text"]
if with_vector:
output_fields.append("vector")
results = client.search(
collection_name=collection_name,
data=[query_vector],
anns_field="vector",
limit=limit,
output_fields=output_fields,
search_params={
"metric_type": "COSINE",
},
)
return [
ScoredResult(
id=UUID(result["id"]),
score=result["distance"],
payload=result.get("entity", {}),
)
for result in results[0]
]
except MilvusException as e:
logger.error(f"Error during search in collection '{collection_name}': {str(e)}")
raise e
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False):
query_vectors = await self.embed_data(query_texts)
return await asyncio.gather(
*[self.search(collection_name=collection_name,
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
) for query_vector in query_vectors]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
from pymilvus import MilvusException
client = self.get_milvus_client()
try:
filter_expression = f"""id in [{", ".join(f'"{id}"' for id in data_point_ids)}]"""
delete_result = client.delete(
collection_name=collection_name,
filter=filter_expression
)
logger.info(f"Deleted data points with IDs {data_point_ids} from collection '{collection_name}'.")
return delete_result
except MilvusException as e:
logger.error(f"Error deleting data points from collection '{collection_name}': {str(e)}")
raise e
async def prune(self):
client = self.get_milvus_client()
if client:
collections = client.list_collections()
for collection_name in collections:
client.drop_collection(collection_name=collection_name)
client.close()
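As a rough usage sketch of the adapter above (assumptions: `embedding_engine` is an already configured `EmbeddingEngine`, the function sits alongside the `MilvusAdapter` and `IndexSchema` definitions, and a file-path URL starts Milvus Lite; none of this is part of the commit itself):

```python
# Hypothetical smoke test for MilvusAdapter; everything except the adapter's own
# methods and the IndexSchema model defined above is an assumption.
from uuid import uuid4


async def milvus_smoke_test(embedding_engine):
    adapter = MilvusAdapter(url="./milvus.db", api_key=None, embedding_engine=embedding_engine)

    # Creates the collection with the id/vector/text schema and a COSINE index.
    await adapter.create_collection("documents_text")

    # IndexSchema (defined at the top of this file) exposes "text" as the embeddable field.
    await adapter.create_data_points(
        "documents_text",
        [IndexSchema(id=uuid4(), text="Quantum computers use qubits.")],
    )

    # Embeds the query text internally and returns ScoredResult objects.
    results = await adapter.search("documents_text", query_text="qubits", limit=3)
    for result in results:
        print(result.id, result.score)

    # Drops every collection and closes the client.
    await adapter.prune()
```

Run it with `asyncio.run(milvus_smoke_test(engine))` once an embedding engine is configured.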


@@ -0,0 +1 @@
from .MilvusAdapter import MilvusAdapter


@@ -1,6 +1,5 @@
import asyncio
from uuid import UUID
from pgvector.sqlalchemy import Vector
from typing import List, Optional, get_type_hints
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete
@@ -70,6 +69,8 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
vector_size = self.embedding_engine.get_vector_size()
if not await self.has_collection(collection_name):
from pgvector.sqlalchemy import Vector
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
@@ -107,6 +108,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
vector_size = self.embedding_engine.get_vector_size()
from pgvector.sqlalchemy import Vector
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}


@@ -1,41 +1,29 @@
from typing import Callable
from pydantic_core import PydanticUndefined
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model
def get_model_instance_from_graph(
nodes: list[DataPoint],
edges: list[tuple[str, str, str, dict[str, str]]],
entity_id: str,
):
node_map = {node.id: node for node in nodes}
def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
node_map = {}
for source_node_id, target_node_id, edge_label, edge_properties in edges:
source_node = node_map[source_node_id]
target_node = node_map[target_node_id]
for node in nodes:
node_map[node.id] = node
for edge in edges:
source_node = node_map[edge[0]]
target_node = node_map[edge[1]]
edge_label = edge[2]
edge_properties = edge[3] if len(edge) == 4 else {}
edge_metadata = edge_properties.get("metadata", {})
edge_type = edge_metadata.get("type", "default")
edge_type = edge_metadata.get("type")
if edge_type == "list":
NewModel = copy_model(
type(source_node),
{edge_label: (list[type(target_node)], PydanticUndefined)},
)
source_node_dict = source_node.model_dump()
source_node_edge_label_values = source_node_dict.get(edge_label, [])
source_node_dict[edge_label] = source_node_edge_label_values + [target_node]
NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
node_map[source_node_id] = NewModel(**source_node_dict)
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
else:
NewModel = copy_model(
type(source_node), {edge_label: (type(target_node), PydanticUndefined)}
)
NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })
node_map[target_node_id] = NewModel(
**source_node.model_dump(), **{edge_label: target_node}
)
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })
return node_map[entity_id]
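The rewritten function is still used together with `get_graph_from_model`; a minimal sketch of that round trip, loosely mirroring the fixtures removed later in this commit (the `Car`/`CarType` models here are illustrative stand-ins, not part of the diff):

```python
# Hypothetical round-trip sketch: decompose a DataPoint into (nodes, edges),
# then rebuild it with get_model_instance_from_graph.
import asyncio

from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model, get_model_instance_from_graph


class CarType(DataPoint):
    id: str
    name: str
    _metadata: dict = dict(index_fields=["name"])


class Car(DataPoint):
    id: str
    brand: str
    is_type: CarType


async def round_trip():
    car = Car(id="car1", brand="Toyota", is_type=CarType(id="sedan", name="Sedan"))
    nodes, edges = await get_graph_from_model(car)
    # edges are (source_id, target_id, label, properties) tuples, matching the indexing above
    rebuilt = get_model_instance_from_graph(nodes, edges, "car1")
    print(rebuilt)


if __name__ == "__main__":
    asyncio.run(round_trip())
```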


@@ -17,7 +17,7 @@ from uuid import uuid4
import pathlib
# Analytics Proxy Url, currently hosted by Vercel
vercel_url = "https://proxyanalytics.vercel.app"
proxy_url = "https://test.prometh.ai"
def get_anonymous_id():
"""Creates or reads a anonymous user id"""
@@ -57,7 +57,7 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
},
}
response = requests.post(vercel_url, json=payload)
response = requests.post(proxy_url, json=payload)
if response.status_code != 200:
print(f"Error sending telemetry through proxy: {response.status_code}")


@@ -0,0 +1,84 @@
import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
logging.basicConfig(level=logging.DEBUG)
async def main():
cognee.config.set_vector_db_provider("milvus")
data_directory_path = str(
pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_milvus")).resolve())
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(
pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_milvus")).resolve())
cognee.config.system_root_directory(cognee_directory_path)
cognee.config.set_vector_db_config(
{
"vector_db_url": os.path.join(cognee_directory_path, "databases/milvus.db"),
"vector_db_key": "",
"vector_db_provider": "milvus"
}
)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
dataset_name = "cs_explanations"
explanation_file_path = os.path.join(pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt")
await cognee.add([explanation_file_path], dataset_name)
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible.
The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly.
Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate.
In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
"""
await cognee.add([text], dataset_name)
await cognee.cognify([dataset_name])
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query_text=random_node_name)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted INSIGHTS are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(SearchType.CHUNKS, query_text=random_node_name)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted CHUNKS are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(SearchType.SUMMARIES, query_text=random_node_name)
assert len(search_results) != 0, "The search results list is empty."
print("\nExtracted SUMMARIES are:\n")
for result in search_results:
print(f"{result}\n")
history = await cognee.get_search_history()
assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
milvus_client = get_vector_engine().get_milvus_client()
collections = milvus_client.list_collections()
assert len(collections) == 0, "Milvus vector database is not empty"
if __name__ == "__main__":
import asyncio
asyncio.run(main())
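The new CI job above runs this script directly; locally, the equivalent would be roughly the following (the OpenAI key placeholder is yours to fill in):

```bash
# Run the Milvus integration test the same way the workflow does.
ENV=dev LLM_API_KEY=<your-openai-api-key> poetry run python ./cognee/tests/test_milvus.py
```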


@@ -1,68 +0,0 @@
from enum import Enum
from typing import Optional
import pytest
from cognee.infrastructure.engine import DataPoint
class CarTypeName(Enum):
Pickup = "Pickup"
Sedan = "Sedan"
SUV = "SUV"
Coupe = "Coupe"
Convertible = "Convertible"
Hatchback = "Hatchback"
Wagon = "Wagon"
Minivan = "Minivan"
Van = "Van"
class CarType(DataPoint):
id: str
name: CarTypeName
_metadata: dict = dict(index_fields=["name"])
class Car(DataPoint):
id: str
brand: str
model: str
year: int
color: str
is_type: CarType
class Person(DataPoint):
id: str
name: str
age: int
owns_car: list[Car]
driving_license: Optional[dict]
_metadata: dict = dict(index_fields=["name"])
@pytest.fixture(scope="function")
def boris():
boris = Person(
id="boris",
name="Boris",
age=30,
owns_car=[
Car(
id="car1",
brand="Toyota",
model="Camry",
year=2020,
color="Blue",
is_type=CarType(id="sedan", name=CarTypeName.Sedan),
)
],
driving_license={
"issued_by": "PU Vrsac",
"issued_on": "2025-11-06",
"number": "1234567890",
"expires_on": "2025-11-06",
},
)
return boris


@@ -1,37 +0,0 @@
import warnings
import pytest
from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
count_society,
create_organization_recursive,
)
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
async def test_society_nodes_and_edges(recursive_depth):
import sys
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth
)
n_organizations, n_persons = count_society(society)
society_counts_total = n_organizations + n_persons
nodes, edges = await get_graph_from_model(society)
assert (
len(nodes) == society_counts_total
), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found"
assert len(edges) == (
len(nodes) - 1
), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node"
else:
warnings.warn(
"The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
)


@@ -1,89 +0,0 @@
from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth
CAR_SEDAN_EDGE = (
"car1",
"sedan",
"is_type",
{
"source_node_id": "car1",
"target_node_id": "sedan",
"relationship_name": "is_type",
},
)
BORIS_CAR_EDGE_GROUND_TRUTH = (
"boris",
"car1",
"owns_car",
{
"source_node_id": "boris",
"target_node_id": "car1",
"relationship_name": "owns_car",
"metadata": {"type": "list"},
},
)
CAR_TYPE_GROUND_TRUTH = {"id": "sedan"}
CAR_GROUND_TRUTH = {
"id": "car1",
"brand": "Toyota",
"model": "Camry",
"year": 2020,
"color": "Blue",
}
PERSON_GROUND_TRUTH = {
"id": "boris",
"name": "Boris",
"age": 30,
"driving_license": {
"issued_by": "PU Vrsac",
"issued_on": "2025-11-06",
"number": "1234567890",
"expires_on": "2025-11-06",
},
}
async def test_extracted_car_type(boris):
nodes, _ = await get_graph_from_model(boris)
assert len(nodes) == 3
car_type = nodes[0]
run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH)
async def test_extracted_car(boris):
nodes, _ = await get_graph_from_model(boris)
assert len(nodes) == 3
car = nodes[1]
run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
async def test_extracted_person(boris):
nodes, _ = await get_graph_from_model(boris)
assert len(nodes) == 3
person = nodes[2]
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
async def test_extracted_car_sedan_edge(boris):
_, edges = await get_graph_from_model(boris)
edge = edges[0]
assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }"
for key, ground_truth in CAR_SEDAN_EDGE[3].items():
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
async def test_extracted_boris_car_edge(boris):
_, edges = await get_graph_from_model(boris)
edge = edges[1]
assert (
BORIS_CAR_EDGE_GROUND_TRUTH[:3] == edge[:3]
), f"{BORIS_CAR_EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }"
for key, ground_truth in BORIS_CAR_EDGE_GROUND_TRUTH[3].items():
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"


@@ -1,33 +0,0 @@
import warnings
import pytest
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
create_organization_recursive,
show_first_difference,
)
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
async def test_society_nodes_and_edges(recursive_depth):
import sys
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth
)
nodes, edges = await get_graph_from_model(society)
parsed_society = get_model_instance_from_graph(nodes, edges, "society")
assert str(society) == (str(parsed_society)), show_first_difference(
str(society), str(parsed_society), "society", "parsed_society"
)
else:
warnings.warn(
"The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
)


@@ -1,35 +0,0 @@
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)
from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth
PARSED_PERSON_GROUND_TRUTH = {
"id": "boris",
"name": "Boris",
"age": 30,
"driving_license": {
"issued_by": "PU Vrsac",
"issued_on": "2025-11-06",
"number": "1234567890",
"expires_on": "2025-11-06",
},
}
CAR_GROUND_TRUTH = {
"id": "car1",
"brand": "Toyota",
"model": "Camry",
"year": 2020,
"color": "Blue",
}
async def test_parsed_person(boris):
nodes, edges = await get_graph_from_model(boris)
parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
run_test_against_ground_truth(
"parsed_person", parsed_person, PARSED_PERSON_GROUND_TRUTH
)
run_test_against_ground_truth("car", parsed_person.owns_car[0], CAR_GROUND_TRUTH)


@@ -1,150 +0,0 @@
import random
import string
from datetime import datetime, timezone
from typing import Any, Dict, Optional
from cognee.infrastructure.engine import DataPoint
def run_test_against_ground_truth(
test_target_item_name: str, test_target_item: Any, ground_truth_dict: Dict[str, Any]
):
"""Validates test target item attributes against ground truth values.
Args:
test_target_item_name: Name of the item being tested (for error messages)
test_target_item: Object whose attributes are being validated
ground_truth_dict: Dictionary containing expected values
Raises:
AssertionError: If any attribute doesn't match ground truth or if update timestamp is too old
"""
for key, ground_truth in ground_truth_dict.items():
if isinstance(ground_truth, dict):
for key2, ground_truth2 in ground_truth.items():
assert (
ground_truth2 == getattr(test_target_item, key)[key2]
), f"{test_target_item_name}/{key = }/{key2 = }: {ground_truth2 = } != {getattr(test_target_item, key)[key2] = }"
elif isinstance(ground_truth, list):
raise NotImplementedError("Currently not implemented for 'list'")
else:
assert ground_truth == getattr(
test_target_item, key
), f"{test_target_item_name}/{key = }: {ground_truth = } != {getattr(test_target_item, key) = }"
time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at")
assert time_delta.total_seconds() < 60, f"{ time_delta.total_seconds() = }"
class Organization(DataPoint):
id: str
name: str
members: Optional[list["SocietyPerson"]]
class SocietyPerson(DataPoint):
id: str
name: str
memberships: Optional[list[Organization]]
SocietyPerson.model_rebuild()
Organization.model_rebuild()
ORGANIZATION_NAMES = [
"ChessClub",
"RowingClub",
"TheatreTroupe",
"PoliticalParty",
"Charity",
"FanClub",
"FilmClub",
"NeighborhoodGroup",
"LocalCouncil",
"Band",
]
PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"]
def create_society_person_recursive(id, name, organization_names, max_depth, depth=0):
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
if depth < max_depth:
memberships = [
create_organization_recursive(
f"{org_name}-{depth}-{id_suffix}",
org_name.lower(),
PERSON_NAMES,
max_depth,
depth + 1,
)
for org_name in organization_names
]
else:
memberships = None
return SocietyPerson(id=id, name=f"{name}{depth}", memberships=memberships)
def create_organization_recursive(id, name, member_names, max_depth, depth=0):
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
if depth < max_depth:
members = [
create_society_person_recursive(
f"{member_name}-{depth}-{id_suffix}",
member_name.lower(),
ORGANIZATION_NAMES,
max_depth,
depth + 1,
)
for member_name in member_names
]
else:
members = None
return Organization(id=id, name=f"{name}{depth}", members=members)
def count_society(obj):
if isinstance(obj, SocietyPerson):
if obj.memberships is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.memberships]
)
organization_count = sum(organization_counts)
society_person_count = sum(society_person_counts) + 1
return (organization_count, society_person_count)
else:
return (0, 1)
if isinstance(obj, Organization):
if obj.members is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.members]
)
organization_count = sum(organization_counts) + 1
society_person_count = sum(society_person_counts)
return (organization_count, society_person_count)
else:
return (1, 0)
else:
raise Exception("Not allowed")
def show_first_difference(str1, str2, str1_name, str2_name, context=30):
for i, (c1, c2) in enumerate(zip(str1, str2)):
if c1 != c2:
start = max(0, i - context)
end1 = min(len(str1), i + context + 1)
end2 = min(len(str2), i + context + 1)
if i > 0:
return f"identical: '{str1[start:i-1]}' | {str1_name}: '{str1[i-1:end1]}'... != {str2_name}: '{str2[i-1:end2]}'..."
else:
return f"{str1_name} and {str2_name} have no overlap in characters"
if len(str1) > len(str2):
return f"{str2_name} is identical up to the {i}th character, missing afterwards '{str1[i:i+context]}'..."
if len(str2) > len(str1):
return f"{str1_name} is identical up to the {i}th character, missing afterwards '{str2[i:i+context]}'..."
else:
return f"{str1_name} and {str2_name} are identical."

poetry.lock (generated): 543 changed lines; diff not shown because it is too large.


@@ -30,7 +30,7 @@ typing_extensions = "4.12.2"
nest_asyncio = "1.6.0"
numpy = "1.26.4"
datasets = "3.1.0"
falkordb = "1.0.9"
falkordb = {version = "1.0.9", optional = true}
boto3 = "^1.26.125"
botocore="^1.35.54"
gunicorn = "^20.1.0"
@@ -43,55 +43,59 @@ filetype = "^1.2.0"
nltk = "^3.8.1"
dlt = {extras = ["sqlalchemy"], version = "^1.4.1"}
aiofiles = "^23.2.1"
qdrant-client = "^1.9.0"
qdrant-client = {version = "^1.9.0", optional = true}
graphistry = "^0.33.5"
tenacity = "^8.4.1"
weaviate-client = "4.6.7"
weaviate-client = {version = "4.6.7", optional = true}
scikit-learn = "^1.5.0"
pypdf = "^4.1.0"
neo4j = "^5.20.0"
neo4j = {version = "^5.20.0", optional = true}
jinja2 = "^3.1.3"
matplotlib = "^3.8.3"
tiktoken = "0.7.0"
langchain_text_splitters = "0.3.2"
langsmith = "0.1.139"
langchain_text_splitters = {version = "0.3.2", optional = true}
langsmith = {version = "0.1.139", optional = true}
langdetect = "1.0.9"
posthog = "^3.5.0"
posthog = {version = "^3.5.0", optional = true}
lancedb = "0.15.0"
litellm = "1.49.1"
groq = "0.8.0"
langfuse = "^2.32.0"
groq = {version = "0.8.0", optional = true}
langfuse = {version = "^2.32.0", optional = true}
pydantic-settings = "^2.2.1"
anthropic = "^0.26.1"
sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"}
fastapi-users = {version = "*", extras = ["sqlalchemy"]}
alembic = "^1.13.3"
asyncpg = "0.30.0"
pgvector = "^0.3.5"
asyncpg = {version = "0.30.0", optional = true}
pgvector = {version = "^0.3.5", optional = true}
psycopg2 = {version = "^2.9.10", optional = true}
llama-index-core = {version = "^0.11.22", optional = true}
deepeval = {version = "^2.0.1", optional = true}
transformers = "^4.46.3"
pymilvus = {version = "^2.5.0", optional = true}
[tool.poetry.extras]
filesystem = ["s3fs", "botocore"]
cli = ["pipdeptree", "cron-descriptor"]
weaviate = ["weaviate-client"]
qdrant = ["qdrant-client"]
neo4j = ["neo4j"]
postgres = ["psycopg2", "pgvector", "asyncpg"]
notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
notebook = ["notebook", "ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
langchain = ["langsmith", "langchain_text_splitters"]
llama-index = ["llama-index-core"]
deepeval = ["deepeval"]
posthog = ["posthog"]
falkordb = ["falkordb"]
groq = ["groq"]
langfuse = ["langfuse"]
milvus = ["pymilvus"]
[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"
coverage = "^7.3.2"
mypy = "^1.7.1"
notebook = "^7.1.1"
notebook = {version = "^7.1.1", optional = true}
deptry = "^0.20.0"
debugpy = "1.8.2"
pylint = "^3.0.3"