Merge branch 'main' into fix-dlt-for-metadata

Igor Ilic, 2024-12-04 11:56:41 +01:00 (committed by GitHub)
commit c505ee5f98
25 changed files with 1280 additions and 792 deletions

@@ -14,7 +14,7 @@ GRAPH_DATABASE_URL=
GRAPH_DATABASE_USERNAME=
GRAPH_DATABASE_PASSWORD=
# "qdrant", "pgvector", "weaviate" or "lancedb"
# "qdrant", "pgvector", "weaviate", "milvus" or "lancedb"
VECTOR_DB_PROVIDER="lancedb"
# Not needed if using "lancedb" or "pgvector"
VECTOR_DB_URL=

.github/workflows/test_milvus.yml (new file, 64 lines)
@@ -0,0 +1,64 @@
name: test | milvus
on:
workflow_dispatch:
pull_request:
branches:
- main
types: [labeled, synchronize]
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true
env:
RUNTIME__LOG_LEVEL: ERROR
ENV: 'dev'
jobs:
get_docs_changes:
name: docs changes
uses: ./.github/workflows/get_docs_changes.yml
run_milvus:
name: test
needs: get_docs_changes
if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' && ${{ github.event.label.name == 'run-checks' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
defaults:
run:
shell: bash
steps:
- name: Check out
uses: actions/checkout@master
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11.x'
- name: Install Poetry
# https://github.com/snok/install-poetry#running-on-windows
uses: snok/install-poetry@v1.3.2
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
- name: Install dependencies
run: poetry install -E milvus --no-interaction
- name: Run default basic pipeline
env:
ENV: 'dev'
LLM_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: poetry run python ./cognee/tests/test_milvus.py
- name: Clean up disk space
run: |
sudo rm -rf ~/.cache
sudo rm -rf /tmp/*
df -h

@@ -46,7 +46,7 @@ jobs:
installer-parallel: true
- name: Install dependencies
run: poetry install --no-interaction
run: poetry install -E neo4j --no-interaction
- name: Run default Neo4j
env:

@@ -47,7 +47,7 @@ jobs:
installer-parallel: true
- name: Install dependencies
run: poetry install --no-interaction
run: poetry install -E qdrant --no-interaction
- name: Run default Qdrant
env:

@@ -47,7 +47,7 @@ jobs:
installer-parallel: true
- name: Install dependencies
run: poetry install --no-interaction
run: poetry install -E weaviate --no-interaction
- name: Run default Weaviate
env:

@@ -13,37 +13,64 @@ We build for developers who need a reliable, production-ready data layer for AI
## What is cognee?
Cognee implements scalable, modular ECL (Extract, Cognify, Load) pipelines that allow you to interconnect and retrieve past conversations, documents, and audio transcriptions while reducing hallucinations, developer effort, and cost.
Try it in a Google Colab <a href="https://colab.research.google.com/drive/1g-Qnx6l_ecHZi0IOw23rg0qC4TYvEvWZ?usp=sharing">notebook</a> or have a look at our <a href="https://topoteretes.github.io/cognee">documentation</a>
Try it in a Google Colab <a href="https://colab.research.google.com/drive/1g-Qnx6l_ecHZi0IOw23rg0qC4TYvEvWZ?usp=sharing">notebook</a> or have a look at our <a href="https://docs.cognee.ai">documentation</a>
If you have questions, join our <a href="https://discord.gg/NQPKmU5CCg">Discord</a> community
## 📦 Installation
You can install Cognee using either **pip** or **poetry**.
Support for various databases and vector stores is available through extras.
### With pip
```bash
pip install cognee
```
### With pip with PostgreSQL support
```bash
pip install 'cognee[postgres]'
```
### With poetry
```bash
poetry add cognee
```
### With poetry with PostgreSQL support
### With pip with specific database support
To install Cognee with support for specific databases, use the appropriate command below and replace \<database> with the name of the database you need.
```bash
poetry add cognee -E postgres
pip install 'cognee[<database>]'
```
Replace \<database> with any of the following databases:
- postgres
- weaviate
- qdrant
- neo4j
- milvus
For example, to install Cognee with PostgreSQL and Neo4j support:
```bash
pip install 'cognee[postgres, neo4j]'
```
### With poetry with specific database support
To install Cognee with support for specific databases, use the appropriate command below and replace \<database> with the name of the database you need.
```bash
poetry add cognee -E <database>
```
Replace \<database> with any of the following databases:
- postgres
- weaviate
- qdrant
- neo4j
- milvus
For example, to install Cognee with PostgreSQL and Neo4j support:
```bash
poetry add cognee -E postgres -E neo4j
```
## 💻 Basic Usage
@@ -61,7 +88,7 @@ import cognee
cognee.config.set_llm_api_key("YOUR_OPENAI_API_KEY")
```
You can also set the variables by creating a .env file; here is our <a href="https://github.com/topoteretes/cognee/blob/main/.env.template">template</a>.
For information on how to use different LLM providers, check out our <a href="https://topoteretes.github.io/cognee">documentation</a>
For information on how to use different LLM providers, check out our <a href="https://docs.cognee.ai">documentation</a>
If you are using NetworkX, create an account on Graphistry to visualize results:
```
@@ -282,7 +309,7 @@ Check out our demo notebook [here](https://github.com/topoteretes/cognee/blob/ma
### Install Server
Please see the [cognee Quick Start Guide](https://topoteretes.github.io/cognee/quickstart/) for important configuration information.
Please see the [cognee Quick Start Guide](https://docs.cognee.ai/quickstart/) for important configuration information.
```bash
docker compose up
@@ -291,7 +318,7 @@ docker compose up
### Install SDK
Please see the cognee [Development Guide](https://topoteretes.github.io/cognee/quickstart/) for important beta information and usage instructions.
Please see the cognee [Development Guide](https://docs.cognee.ai/quickstart/) for important beta information and usage instructions.
```bash
pip install cognee
@@ -317,12 +344,13 @@ pip install cognee
}
</style>
| Name | Type | Current state | Known Issues |
|------------------|--------------------|-------------------|---------------------------------------|
| Qdrant | Vector | Stable &#x2705; | |
| Weaviate | Vector | Stable &#x2705; | |
| LanceDB | Vector | Stable &#x2705; | |
| Neo4j | Graph | Stable &#x2705; | |
| NetworkX | Graph | Stable &#x2705; | |
| FalkorDB | Vector/Graph | Unstable &#x274C; | |
| PGVector | Vector | Unstable &#x274C; | Postgres DB returns the Timeout error |
| Name | Type | Current state | Known Issues |
|----------|--------------------|-------------------|--------------|
| Qdrant | Vector | Stable &#x2705; | |
| Weaviate | Vector | Stable &#x2705; | |
| LanceDB | Vector | Stable &#x2705; | |
| Neo4j | Graph | Stable &#x2705; | |
| NetworkX | Graph | Stable &#x2705; | |
| FalkorDB | Vector/Graph | Unstable &#x274C; | |
| PGVector | Vector | Stable &#x2705; | |
| Milvus | Vector | Stable &#x2705; | |

@@ -1,9 +1,9 @@
from functools import lru_cache
# from functools import lru_cache
from .config import get_relational_config
from .create_relational_engine import create_relational_engine
@lru_cache
# @lru_cache
def get_relational_engine():
relational_config = get_relational_config()

@@ -1,11 +1,13 @@
from typing import Dict
class VectorConfig(Dict):
vector_db_url: str
vector_db_port: str
vector_db_key: str
vector_db_provider: str
def create_vector_engine(config: VectorConfig, embedding_engine):
if config["vector_db_provider"] == "weaviate":
from .weaviate_db import WeaviateAdapter
@@ -16,24 +18,37 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
return WeaviateAdapter(
config["vector_db_url"],
config["vector_db_key"],
embedding_engine = embedding_engine
embedding_engine=embedding_engine
)
elif config["vector_db_provider"] == "qdrant":
if not (config["vector_db_url"] and config["vector_db_key"]):
raise EnvironmentError("Missing required Qdrant credentials!")
from .qdrant.QDrantAdapter import QDrantAdapter
return QDrantAdapter(
url = config["vector_db_url"],
api_key = config["vector_db_key"],
embedding_engine = embedding_engine
url=config["vector_db_url"],
api_key=config["vector_db_key"],
embedding_engine=embedding_engine
)
elif config['vector_db_provider'] == 'milvus':
from .milvus.MilvusAdapter import MilvusAdapter
if not config["vector_db_url"]:
raise EnvironmentError("Missing required Milvus credentials!")
return MilvusAdapter(
url=config["vector_db_url"],
api_key=config['vector_db_key'],
embedding_engine=embedding_engine
)
elif config["vector_db_provider"] == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config
# Get configuration for postgres database
relational_config = get_relational_config()
db_username = relational_config.db_username
@@ -52,8 +67,8 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
from .pgvector.PGVectorAdapter import PGVectorAdapter
return PGVectorAdapter(
connection_string,
config["vector_db_key"],
connection_string,
config["vector_db_key"],
embedding_engine,
)
@@ -64,16 +79,16 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
return FalkorDBAdapter(
database_url = config["vector_db_url"],
database_port = config["vector_db_port"],
embedding_engine = embedding_engine,
database_url=config["vector_db_url"],
database_port=config["vector_db_port"],
embedding_engine=embedding_engine,
)
else:
from .lancedb.LanceDBAdapter import LanceDBAdapter
return LanceDBAdapter(
url = config["vector_db_url"],
api_key = config["vector_db_key"],
embedding_engine = embedding_engine,
url=config["vector_db_url"],
api_key=config["vector_db_key"],
embedding_engine=embedding_engine,
)
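For context, a rough sketch of how the new branch is reached: a config dict whose `vector_db_provider` is `"milvus"` falls through to the Milvus case above and yields a `MilvusAdapter`. The import path and the `embedding_engine` object below are assumptions, not something this diff establishes.

```python
# Sketch only: the module path is assumed; embedding_engine is any EmbeddingEngine implementation.
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine

def build_milvus_engine(embedding_engine):
    config = {
        "vector_db_provider": "milvus",
        "vector_db_url": "./databases/milvus.db",  # local Milvus Lite file, as in the test further down
        "vector_db_key": "",
        "vector_db_port": "",
    }
    # Falls through to the 'milvus' branch above and returns a MilvusAdapter.
    return create_vector_engine(config, embedding_engine)
```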

@@ -1,4 +1,6 @@
import asyncio
import logging
import math
from typing import List, Optional
import litellm
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
@@ -36,11 +38,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_base = self.endpoint,
api_version = self.api_version
)
except litellm.exceptions.BadRequestError as error:
return [data["embedding"] for data in response.data]
except litellm.exceptions.ContextWindowExceededError as error:
if isinstance(text, list):
parts = [text[0:math.ceil(len(text)/2)], text[math.ceil(len(text)/2):]]
parts_futures = [self.embed_text(part) for part in parts]
embeddings = await asyncio.gather(*parts_futures)
all_embeddings = []
for embeddings_part in embeddings:
all_embeddings.extend(embeddings_part)
return [data["embedding"] for data in all_embeddings]
logger.error("Context window exceeded for embedding text: %s", str(error))
raise error
except Exception as error:
logger.error("Error embedding text: %s", str(error))
raise error
return [data["embedding"] for data in response.data]
def get_vector_size(self) -> int:
return self.dimensions
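The new `except` branch amounts to a divide-and-retry strategy: when a batch of texts exceeds the model's context window, split it in half and embed each half independently, recursing until the pieces fit. A minimal sketch of that pattern, separate from the adapter itself (the `embed` callable stands in for the litellm call and is an assumption):

```python
import asyncio
import math

async def embed_with_split(embed, texts: list[str]) -> list[list[float]]:
    # `embed` is any async batch embedder that raises on oversized input
    # (in the adapter above it is litellm, and the caught error is
    # litellm.exceptions.ContextWindowExceededError).
    try:
        return await embed(texts)
    except Exception:
        if len(texts) <= 1:
            raise  # a single oversized text cannot be split further
        mid = math.ceil(len(texts) / 2)
        halves = await asyncio.gather(
            embed_with_split(embed, texts[:mid]),
            embed_with_split(embed, texts[mid:]),
        )
        return [vector for half in halves for vector in half]
```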

@@ -0,0 +1,252 @@
from __future__ import annotations
import asyncio
import logging
from typing import List, Optional
from uuid import UUID
from cognee.infrastructure.engine import DataPoint
from ..vector_db_interface import VectorDBInterface
from ..models.ScoredResult import ScoredResult
from ..embeddings.EmbeddingEngine import EmbeddingEngine
logger = logging.getLogger("MilvusAdapter")
class IndexSchema(DataPoint):
text: str
_metadata: dict = {
"index_fields": ["text"]
}
class MilvusAdapter(VectorDBInterface):
name = "Milvus"
url: str
api_key: Optional[str]
embedding_engine: EmbeddingEngine = None
def __init__(self, url: str, api_key: Optional[str], embedding_engine: EmbeddingEngine):
self.url = url
self.api_key = api_key
self.embedding_engine = embedding_engine
def get_milvus_client(self) -> "MilvusClient":
from pymilvus import MilvusClient
if self.api_key:
client = MilvusClient(uri=self.url, token=self.api_key)
else:
client = MilvusClient(uri=self.url)
return client
async def embed_data(self, data: List[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data)
async def has_collection(self, collection_name: str) -> bool:
future = asyncio.Future()
client = self.get_milvus_client()
future.set_result(client.has_collection(collection_name=collection_name))
return await future
async def create_collection(
self,
collection_name: str,
payload_schema=None,
):
from pymilvus import DataType, MilvusException
client = self.get_milvus_client()
if client.has_collection(collection_name=collection_name):
logger.info(f"Collection '{collection_name}' already exists.")
return True
try:
dimension = self.embedding_engine.get_vector_size()
assert dimension > 0, "Embedding dimension must be greater than 0."
schema = client.create_schema(
auto_id=False,
enable_dynamic_field=False,
)
schema.add_field(
field_name="id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=36
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=dimension
)
schema.add_field(
field_name="text",
datatype=DataType.VARCHAR,
max_length=60535
)
index_params = client.prepare_index_params()
index_params.add_index(
field_name="vector",
metric_type="COSINE"
)
client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params
)
client.load_collection(collection_name)
logger.info(f"Collection '{collection_name}' created successfully.")
return True
except MilvusException as e:
logger.error(f"Error creating collection '{collection_name}': {str(e)}")
raise e
async def create_data_points(
self,
collection_name: str,
data_points: List[DataPoint]
):
from pymilvus import MilvusException
client = self.get_milvus_client()
data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
)
insert_data = [
{
"id": str(data_point.id),
"vector": data_vectors[index],
"text": data_point.text,
}
for index, data_point in enumerate(data_points)
]
try:
result = client.insert(
collection_name=collection_name,
data=insert_data
)
logger.info(
f"Inserted {result.get('insert_count', 0)} data points into collection '{collection_name}'."
)
return result
except MilvusException as e:
logger.error(f"Error inserting data points into collection '{collection_name}': {str(e)}")
raise e
async def create_vector_index(self, index_name: str, index_property_name: str):
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(self, index_name: str, index_property_name: str, data_points: List[DataPoint]):
formatted_data_points = [
IndexSchema(
id=data_point.id,
text=getattr(data_point, data_point._metadata["index_fields"][0]),
)
for data_point in data_points
]
collection_name = f"{index_name}_{index_property_name}"
await self.create_data_points(collection_name, formatted_data_points)
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
from pymilvus import MilvusException
client = self.get_milvus_client()
try:
filter_expression = f"""id in [{", ".join(f'"{id}"' for id in data_point_ids)}]"""
results = client.query(
collection_name=collection_name,
expr=filter_expression,
output_fields=["*"],
)
return results
except MilvusException as e:
logger.error(f"Error retrieving data points from collection '{collection_name}': {str(e)}")
raise e
async def search(
self,
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 5,
with_vector: bool = False,
):
from pymilvus import MilvusException
client = self.get_milvus_client()
if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!")
try:
query_vector = query_vector or (await self.embed_data([query_text]))[0]
output_fields = ["id", "text"]
if with_vector:
output_fields.append("vector")
results = client.search(
collection_name=collection_name,
data=[query_vector],
anns_field="vector",
limit=limit,
output_fields=output_fields,
search_params={
"metric_type": "COSINE",
},
)
return [
ScoredResult(
id=UUID(result["id"]),
score=result["distance"],
payload=result.get("entity", {}),
)
for result in results[0]
]
except MilvusException as e:
logger.error(f"Error during search in collection '{collection_name}': {str(e)}")
raise e
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False):
query_vectors = await self.embed_data(query_texts)
return await asyncio.gather(
*[self.search(collection_name=collection_name,
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
) for query_vector in query_vectors]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
from pymilvus import MilvusException
client = self.get_milvus_client()
try:
filter_expression = f"""id in [{", ".join(f'"{id}"' for id in data_point_ids)}]"""
delete_result = client.delete(
collection_name=collection_name,
filter=filter_expression
)
logger.info(f"Deleted data points with IDs {data_point_ids} from collection '{collection_name}'.")
return delete_result
except MilvusException as e:
logger.error(f"Error deleting data points from collection '{collection_name}': {str(e)}")
raise e
async def prune(self):
client = self.get_milvus_client()
if client:
collections = client.list_collections()
for collection_name in collections:
client.drop_collection(collection_name=collection_name)
client.close()
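An adapter-level usage sketch, assuming a local Milvus Lite file and an already-constructed embedding engine; `MilvusAdapter` and `IndexSchema` are the classes from the diff above, everything else is illustrative (the cognee-level flow is exercised by `cognee/tests/test_milvus.py` further down).

```python
from uuid import uuid4

async def demo(embedding_engine):
    # Sketch only: the URL and data are illustrative, and embedding_engine is
    # assumed to implement EmbeddingEngine.
    adapter = MilvusAdapter(url="./milvus_demo.db", api_key=None, embedding_engine=embedding_engine)

    await adapter.create_collection("docs_text")
    await adapter.create_data_points(
        "docs_text",
        [IndexSchema(id=uuid4(), text="Milvus is a vector database.")],
    )

    results = await adapter.search("docs_text", query_text="vector database", limit=3)
    for result in results:
        print(result.id, result.score)

# Run with: asyncio.run(demo(my_embedding_engine))  # my_embedding_engine is hypothetical
```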

@@ -0,0 +1 @@
from .MilvusAdapter import MilvusAdapter

@@ -1,6 +1,5 @@
import asyncio
from uuid import UUID
from pgvector.sqlalchemy import Vector
from typing import List, Optional, get_type_hints
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete
@@ -70,6 +69,8 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
vector_size = self.embedding_engine.get_vector_size()
if not await self.has_collection(collection_name):
from pgvector.sqlalchemy import Vector
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
@@ -107,6 +108,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
vector_size = self.embedding_engine.get_vector_size()
from pgvector.sqlalchemy import Vector
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}

@@ -1,41 +1,29 @@
from typing import Callable
from pydantic_core import PydanticUndefined
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model
def get_model_instance_from_graph(
nodes: list[DataPoint],
edges: list[tuple[str, str, str, dict[str, str]]],
entity_id: str,
):
node_map = {node.id: node for node in nodes}
def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
node_map = {}
for source_node_id, target_node_id, edge_label, edge_properties in edges:
source_node = node_map[source_node_id]
target_node = node_map[target_node_id]
for node in nodes:
node_map[node.id] = node
for edge in edges:
source_node = node_map[edge[0]]
target_node = node_map[edge[1]]
edge_label = edge[2]
edge_properties = edge[3] if len(edge) == 4 else {}
edge_metadata = edge_properties.get("metadata", {})
edge_type = edge_metadata.get("type", "default")
edge_type = edge_metadata.get("type")
if edge_type == "list":
NewModel = copy_model(
type(source_node),
{edge_label: (list[type(target_node)], PydanticUndefined)},
)
source_node_dict = source_node.model_dump()
source_node_edge_label_values = source_node_dict.get(edge_label, [])
source_node_dict[edge_label] = source_node_edge_label_values + [target_node]
NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
node_map[source_node_id] = NewModel(**source_node_dict)
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
else:
NewModel = copy_model(
type(source_node), {edge_label: (type(target_node), PydanticUndefined)}
)
NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })
node_map[target_node_id] = NewModel(
**source_node.model_dump(), **{edge_label: target_node}
)
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })
return node_map[entity_id]

@@ -17,7 +17,7 @@ from uuid import uuid4
import pathlib
# Analytics Proxy Url, currently hosted by Vercel
vercel_url = "https://proxyanalytics.vercel.app"
proxy_url = "https://test.prometh.ai"
def get_anonymous_id():
"""Creates or reads a anonymous user id"""
@@ -57,7 +57,7 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
},
}
response = requests.post(vercel_url, json=payload)
response = requests.post(proxy_url, json=payload)
if response.status_code != 200:
print(f"Error sending telemetry through proxy: {response.status_code}")

@@ -0,0 +1,84 @@
import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
logging.basicConfig(level=logging.DEBUG)
async def main():
cognee.config.set_vector_db_provider("milvus")
data_directory_path = str(
pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_milvus")).resolve())
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(
pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_milvus")).resolve())
cognee.config.system_root_directory(cognee_directory_path)
cognee.config.set_vector_db_config(
{
"vector_db_url": os.path.join(cognee_directory_path, "databases/milvus.db"),
"vector_db_key": "",
"vector_db_provider": "milvus"
}
)
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
dataset_name = "cs_explanations"
explanation_file_path = os.path.join(pathlib.Path(__file__).parent, "test_data/Natural_language_processing.txt")
await cognee.add([explanation_file_path], dataset_name)
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the preparation and manipulation of quantum states.
Classical physics cannot explain the operation of these quantum devices, and a scalable quantum computer could perform some calculations exponentially faster (with respect to input size scaling) than any modern "classical" computer. In particular, a large-scale quantum computer could break widely used encryption schemes and aid physicists in performing physical simulations; however, the current state of the technology is largely experimental and impractical, with several obstacles to useful applications. Moreover, scalable quantum computers do not hold promise for many practical tasks, and for many important tasks quantum speedups are proven impossible.
The basic unit of information in quantum computing is the qubit, similar to the bit in traditional digital electronics. Unlike a classical bit, a qubit can exist in a superposition of its two "basis" states. When measuring a qubit, the result is a probabilistic output of a classical bit, therefore making quantum computers nondeterministic in general. If a quantum computer manipulates the qubit in a particular way, wave interference effects can amplify the desired measurement results. The design of quantum algorithms involves creating procedures that allow a quantum computer to perform calculations efficiently and quickly.
Physically engineering high-quality qubits has proven challenging. If a physical qubit is not sufficiently isolated from its environment, it suffers from quantum decoherence, introducing noise into calculations. Paradoxically, perfectly isolating qubits is also undesirable because quantum computations typically need to initialize qubits, perform controlled qubit interactions, and measure the resulting quantum states. Each of those operations introduces errors and suffers from noise, and such inaccuracies accumulate.
In principle, a non-quantum (classical) computer can solve the same computational problems as a quantum computer, given enough time. Quantum advantage comes in the form of time complexity rather than computability, and quantum complexity theory shows that some quantum algorithms for carefully selected tasks require exponentially fewer computational steps than the best known non-quantum algorithms. Such tasks can in theory be solved on a large-scale quantum computer whereas classical computers would not finish computations in any reasonable amount of time. However, quantum speedup is not universal or even typical across computational tasks, since basic tasks such as sorting are proven to not allow any asymptotic quantum speedup. Claims of quantum supremacy have drawn significant attention to the discipline, but are demonstrated on contrived tasks, while near-term practical use cases remain limited.
"""
await cognee.add([text], dataset_name)
await cognee.cognify([dataset_name])
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("entity_name", "Quantum computer"))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(SearchType.INSIGHTS, query_text=random_node_name)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted INSIGHTS are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(SearchType.CHUNKS, query_text=random_node_name)
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted CHUNKS are:\n")
for result in search_results:
print(f"{result}\n")
search_results = await cognee.search(SearchType.SUMMARIES, query_text=random_node_name)
assert len(search_results) != 0, "The search results list is empty."
print("\nExtracted SUMMARIES are:\n")
for result in search_results:
print(f"{result}\n")
history = await cognee.get_search_history()
assert len(history) == 6, "Search history is not correct."
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
await cognee.prune.prune_system(metadata=True)
milvus_client = get_vector_engine().get_milvus_client()
collections = milvus_client.list_collections()
assert len(collections) == 0, "Milvus vector database is not empty"
if __name__ == "__main__":
import asyncio
asyncio.run(main())

@@ -1,68 +0,0 @@
from enum import Enum
from typing import Optional
import pytest
from cognee.infrastructure.engine import DataPoint
class CarTypeName(Enum):
Pickup = "Pickup"
Sedan = "Sedan"
SUV = "SUV"
Coupe = "Coupe"
Convertible = "Convertible"
Hatchback = "Hatchback"
Wagon = "Wagon"
Minivan = "Minivan"
Van = "Van"
class CarType(DataPoint):
id: str
name: CarTypeName
_metadata: dict = dict(index_fields=["name"])
class Car(DataPoint):
id: str
brand: str
model: str
year: int
color: str
is_type: CarType
class Person(DataPoint):
id: str
name: str
age: int
owns_car: list[Car]
driving_license: Optional[dict]
_metadata: dict = dict(index_fields=["name"])
@pytest.fixture(scope="function")
def boris():
boris = Person(
id="boris",
name="Boris",
age=30,
owns_car=[
Car(
id="car1",
brand="Toyota",
model="Camry",
year=2020,
color="Blue",
is_type=CarType(id="sedan", name=CarTypeName.Sedan),
)
],
driving_license={
"issued_by": "PU Vrsac",
"issued_on": "2025-11-06",
"number": "1234567890",
"expires_on": "2025-11-06",
},
)
return boris

@@ -1,37 +0,0 @@
import warnings
import pytest
from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
count_society,
create_organization_recursive,
)
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
async def test_society_nodes_and_edges(recursive_depth):
import sys
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth
)
n_organizations, n_persons = count_society(society)
society_counts_total = n_organizations + n_persons
nodes, edges = await get_graph_from_model(society)
assert (
len(nodes) == society_counts_total
), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found"
assert len(edges) == (
len(nodes) - 1
), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node"
else:
warnings.warn(
"The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
)

@@ -1,89 +0,0 @@
from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth
CAR_SEDAN_EDGE = (
"car1",
"sedan",
"is_type",
{
"source_node_id": "car1",
"target_node_id": "sedan",
"relationship_name": "is_type",
},
)
BORIS_CAR_EDGE_GROUND_TRUTH = (
"boris",
"car1",
"owns_car",
{
"source_node_id": "boris",
"target_node_id": "car1",
"relationship_name": "owns_car",
"metadata": {"type": "list"},
},
)
CAR_TYPE_GROUND_TRUTH = {"id": "sedan"}
CAR_GROUND_TRUTH = {
"id": "car1",
"brand": "Toyota",
"model": "Camry",
"year": 2020,
"color": "Blue",
}
PERSON_GROUND_TRUTH = {
"id": "boris",
"name": "Boris",
"age": 30,
"driving_license": {
"issued_by": "PU Vrsac",
"issued_on": "2025-11-06",
"number": "1234567890",
"expires_on": "2025-11-06",
},
}
async def test_extracted_car_type(boris):
nodes, _ = await get_graph_from_model(boris)
assert len(nodes) == 3
car_type = nodes[0]
run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH)
async def test_extracted_car(boris):
nodes, _ = await get_graph_from_model(boris)
assert len(nodes) == 3
car = nodes[1]
run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
async def test_extracted_person(boris):
nodes, _ = await get_graph_from_model(boris)
assert len(nodes) == 3
person = nodes[2]
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
async def test_extracted_car_sedan_edge(boris):
_, edges = await get_graph_from_model(boris)
edge = edges[0]
assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }"
for key, ground_truth in CAR_SEDAN_EDGE[3].items():
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
async def test_extracted_boris_car_edge(boris):
_, edges = await get_graph_from_model(boris)
edge = edges[1]
assert (
BORIS_CAR_EDGE_GROUND_TRUTH[:3] == edge[:3]
), f"{BORIS_CAR_EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }"
for key, ground_truth in BORIS_CAR_EDGE_GROUND_TRUTH[3].items():
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"

@@ -1,33 +0,0 @@
import warnings
import pytest
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
create_organization_recursive,
show_first_difference,
)
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
async def test_society_nodes_and_edges(recursive_depth):
import sys
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth
)
nodes, edges = await get_graph_from_model(society)
parsed_society = get_model_instance_from_graph(nodes, edges, "society")
assert str(society) == (str(parsed_society)), show_first_difference(
str(society), str(parsed_society), "society", "parsed_society"
)
else:
warnings.warn(
"The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
)

@@ -1,35 +0,0 @@
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)
from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth
PARSED_PERSON_GROUND_TRUTH = {
"id": "boris",
"name": "Boris",
"age": 30,
"driving_license": {
"issued_by": "PU Vrsac",
"issued_on": "2025-11-06",
"number": "1234567890",
"expires_on": "2025-11-06",
},
}
CAR_GROUND_TRUTH = {
"id": "car1",
"brand": "Toyota",
"model": "Camry",
"year": 2020,
"color": "Blue",
}
async def test_parsed_person(boris):
nodes, edges = await get_graph_from_model(boris)
parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
run_test_against_ground_truth(
"parsed_person", parsed_person, PARSED_PERSON_GROUND_TRUTH
)
run_test_against_ground_truth("car", parsed_person.owns_car[0], CAR_GROUND_TRUTH)

@@ -1,150 +0,0 @@
import random
import string
from datetime import datetime, timezone
from typing import Any, Dict, Optional
from cognee.infrastructure.engine import DataPoint
def run_test_against_ground_truth(
test_target_item_name: str, test_target_item: Any, ground_truth_dict: Dict[str, Any]
):
"""Validates test target item attributes against ground truth values.
Args:
test_target_item_name: Name of the item being tested (for error messages)
test_target_item: Object whose attributes are being validated
ground_truth_dict: Dictionary containing expected values
Raises:
AssertionError: If any attribute doesn't match ground truth or if update timestamp is too old
"""
for key, ground_truth in ground_truth_dict.items():
if isinstance(ground_truth, dict):
for key2, ground_truth2 in ground_truth.items():
assert (
ground_truth2 == getattr(test_target_item, key)[key2]
), f"{test_target_item_name}/{key = }/{key2 = }: {ground_truth2 = } != {getattr(test_target_item, key)[key2] = }"
elif isinstance(ground_truth, list):
raise NotImplementedError("Currently not implemented for 'list'")
else:
assert ground_truth == getattr(
test_target_item, key
), f"{test_target_item_name}/{key = }: {ground_truth = } != {getattr(test_target_item, key) = }"
time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at")
assert time_delta.total_seconds() < 60, f"{ time_delta.total_seconds() = }"
class Organization(DataPoint):
id: str
name: str
members: Optional[list["SocietyPerson"]]
class SocietyPerson(DataPoint):
id: str
name: str
memberships: Optional[list[Organization]]
SocietyPerson.model_rebuild()
Organization.model_rebuild()
ORGANIZATION_NAMES = [
"ChessClub",
"RowingClub",
"TheatreTroupe",
"PoliticalParty",
"Charity",
"FanClub",
"FilmClub",
"NeighborhoodGroup",
"LocalCouncil",
"Band",
]
PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"]
def create_society_person_recursive(id, name, organization_names, max_depth, depth=0):
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
if depth < max_depth:
memberships = [
create_organization_recursive(
f"{org_name}-{depth}-{id_suffix}",
org_name.lower(),
PERSON_NAMES,
max_depth,
depth + 1,
)
for org_name in organization_names
]
else:
memberships = None
return SocietyPerson(id=id, name=f"{name}{depth}", memberships=memberships)
def create_organization_recursive(id, name, member_names, max_depth, depth=0):
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
if depth < max_depth:
members = [
create_society_person_recursive(
f"{member_name}-{depth}-{id_suffix}",
member_name.lower(),
ORGANIZATION_NAMES,
max_depth,
depth + 1,
)
for member_name in member_names
]
else:
members = None
return Organization(id=id, name=f"{name}{depth}", members=members)
def count_society(obj):
if isinstance(obj, SocietyPerson):
if obj.memberships is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.memberships]
)
organization_count = sum(organization_counts)
society_person_count = sum(society_person_counts) + 1
return (organization_count, society_person_count)
else:
return (0, 1)
if isinstance(obj, Organization):
if obj.members is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.members]
)
organization_count = sum(organization_counts) + 1
society_person_count = sum(society_person_counts)
return (organization_count, society_person_count)
else:
return (1, 0)
else:
raise Exception("Not allowed")
def show_first_difference(str1, str2, str1_name, str2_name, context=30):
for i, (c1, c2) in enumerate(zip(str1, str2)):
if c1 != c2:
start = max(0, i - context)
end1 = min(len(str1), i + context + 1)
end2 = min(len(str2), i + context + 1)
if i > 0:
return f"identical: '{str1[start:i-1]}' | {str1_name}: '{str1[i-1:end1]}'... != {str2_name}: '{str2[i-1:end2]}'..."
else:
return f"{str1_name} and {str2_name} have no overlap in characters"
if len(str1) > len(str2):
return f"{str2_name} is identical up to the {i}th character, missing afterwards '{str1[i:i+context]}'..."
if len(str2) > len(str1):
return f"{str1_name} is identical up to the {i}th character, missing afterwards '{str2[i:i+context]}'..."
else:
return f"{str1_name} and {str2_name} are identical."

@@ -7,20 +7,19 @@ from pathlib import Path
from swebench.harness.utils import load_swebench_dataset
from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE
import cognee
from cognee.api.v1.search import SearchType
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.pipelines import Task, run_tasks
from cognee.modules.retrieval.brute_force_triplet_search import \
brute_force_triplet_search
from cognee.shared.data_models import SummarizedContent
# from cognee.shared.data_models import SummarizedContent
from cognee.shared.utils import render_graph
from cognee.tasks.repo_processor import (enrich_dependency_graph,
expand_dependency_graph,
get_repo_file_dependencies)
from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_code
# from cognee.tasks.summarization import summarize_code
from evals.eval_utils import download_github_repo, retrieved_edges_to_string
@@ -43,8 +42,21 @@ def check_install_package(package_name):
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
import os
import pathlib
import cognee
from cognee.infrastructure.databases.relational import create_db_and_tables
file_path = Path(__file__).parent
data_directory_path = str(pathlib.Path(os.path.join(file_path, ".data_storage/code_graph")).resolve())
cognee.config.data_root_directory(data_directory_path)
cognee_directory_path = str(pathlib.Path(os.path.join(file_path, ".cognee_system/code_graph")).resolve())
cognee.config.system_root_directory(cognee_directory_path)
await cognee.prune.prune_data()
await cognee.prune.prune_system()
await cognee.prune.prune_system(metadata = True)
await create_db_and_tables()
# repo_path = download_github_repo(instance, '../RAW_GIT_REPOS')

@@ -17,60 +17,63 @@
"metadata": {},
"outputs": [],
"source": [
"from cognee.modules.users.methods import get_default_user\n",
"import os\n",
"import pathlib\n",
"import cognee\n",
"from cognee.infrastructure.databases.relational import create_db_and_tables\n",
"\n",
"from cognee.modules.data.methods import get_datasets\n",
"from cognee.modules.data.methods.get_dataset_data import get_dataset_data\n",
"from cognee.modules.data.models import Data\n",
"notebook_path = os.path.abspath(\"\")\n",
"data_directory_path = str(pathlib.Path(os.path.join(notebook_path, \".data_storage/code_graph\")).resolve())\n",
"cognee.config.data_root_directory(data_directory_path)\n",
"cognee_directory_path = str(pathlib.Path(os.path.join(notebook_path, \".cognee_system/code_graph\")).resolve())\n",
"cognee.config.system_root_directory(cognee_directory_path)\n",
"\n",
"from cognee.modules.pipelines.tasks.Task import Task\n",
"from cognee.tasks.documents import classify_documents, check_permissions_on_documents, extract_chunks_from_documents\n",
"from cognee.tasks.graph import extract_graph_from_code\n",
"await cognee.prune.prune_data()\n",
"await cognee.prune.prune_system(metadata = True)\n",
"\n",
"await create_db_and_tables()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from os import path\n",
"from pathlib import Path\n",
"from cognee.infrastructure.files.storage import LocalStorage\n",
"import git\n",
"\n",
"notebook_path = path.abspath(\"\")\n",
"repo_clone_location = path.join(notebook_path, \"data/graphrag\")\n",
"\n",
"LocalStorage.remove_all(repo_clone_location)\n",
"\n",
"git.Repo.clone_from(\n",
" \"git@github.com:microsoft/graphrag.git\",\n",
" Path(repo_clone_location),\n",
" branch = \"main\",\n",
" single_branch = True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from cognee.tasks.repo_processor import enrich_dependency_graph, expand_dependency_graph, get_repo_file_dependencies\n",
"from cognee.tasks.storage import add_data_points\n",
"from cognee.shared.SourceCodeGraph import SourceCodeGraph\n",
"from cognee.modules.pipelines.tasks.Task import Task\n",
"\n",
"from cognee.modules.pipelines import run_tasks\n",
"\n",
"from cognee.shared.utils import render_graph\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"user = await get_default_user()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"existing_datasets = await get_datasets(user.id)\n",
"\n",
"datasets = {}\n",
"for dataset in existing_datasets:\n",
" dataset_name = dataset.name.replace(\".\", \"_\").replace(\" \", \"_\")\n",
" data_documents: list[Data] = await get_dataset_data(dataset_id = dataset.id)\n",
" datasets[dataset_name] = data_documents\n",
"print(datasets.keys())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tasks = [\n",
" Task(classify_documents),\n",
" Task(check_permissions_on_documents, user = user, permissions = [\"write\"]),\n",
" Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n",
" Task(add_data_points, task_config = { \"batch_size\": 10 }),\n",
" Task(extract_graph_from_code, graph_model = SourceCodeGraph, task_config = { \"batch_size\": 10 }), # Generate knowledge graphs from the document chunks.\n",
" Task(get_repo_file_dependencies),\n",
" Task(add_data_points, task_config = { \"batch_size\": 50 }),\n",
" Task(enrich_dependency_graph, task_config = { \"batch_size\": 50 }),\n",
" Task(expand_dependency_graph, task_config = { \"batch_size\": 50 }),\n",
" Task(add_data_points, task_config = { \"batch_size\": 50 }),\n",
"]"
]
},
@@ -80,21 +83,15 @@
"metadata": {},
"outputs": [],
"source": [
"async def run_codegraph_pipeline(tasks, data_documents):\n",
" pipeline = run_tasks(tasks, data_documents, \"code_graph_pipeline\")\n",
" results = []\n",
" async for result in pipeline:\n",
" results.append(result)\n",
" return(results)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"results = await run_codegraph_pipeline(tasks, datasets[\"main_dataset\"])"
"from cognee.modules.pipelines import run_tasks\n",
"\n",
"notebook_path = os.path.abspath(\"\")\n",
"repo_clone_location = os.path.join(notebook_path, \"data/graphrag\")\n",
"\n",
"pipeline = run_tasks(tasks, repo_clone_location, \"code_graph_pipeline\")\n",
"\n",
"async for result in pipeline:\n",
" print(result)"
]
},
{
@@ -103,6 +100,7 @@
"metadata": {},
"outputs": [],
"source": [
"from cognee.shared.utils import render_graph\n",
"await render_graph(None, include_nodes = True, include_labels = True)"
]
},
@@ -116,7 +114,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "cognee",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -130,7 +128,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
"version": "3.11.8"
}
},
"nbformat": 4,

poetry.lock (generated, 873 changed lines)

File diff suppressed because it is too large.

@@ -30,7 +30,7 @@ typing_extensions = "4.12.2"
nest_asyncio = "1.6.0"
numpy = "1.26.4"
datasets = "3.1.0"
falkordb = "1.0.9"
falkordb = {version = "1.0.9", optional = true}
boto3 = "^1.26.125"
botocore="^1.35.54"
gunicorn = "^20.1.0"
@@ -43,59 +43,65 @@ filetype = "^1.2.0"
nltk = "^3.8.1"
dlt = {extras = ["sqlalchemy"], version = "^1.4.1"}
aiofiles = "^23.2.1"
qdrant-client = "^1.9.0"
qdrant-client = {version = "^1.9.0", optional = true}
graphistry = "^0.33.5"
tenacity = "^8.4.1"
weaviate-client = "4.6.7"
weaviate-client = {version = "4.6.7", optional = true}
scikit-learn = "^1.5.0"
pypdf = "^4.1.0"
neo4j = "^5.20.0"
neo4j = {version = "^5.20.0", optional = true}
jinja2 = "^3.1.3"
matplotlib = "^3.8.3"
tiktoken = "0.7.0"
langchain_text_splitters = "0.3.2"
langsmith = "0.1.139"
langchain_text_splitters = {version = "0.3.2", optional = true}
langsmith = {version = "0.1.139", optional = true}
langdetect = "1.0.9"
posthog = "^3.5.0"
posthog = {version = "^3.5.0", optional = true}
lancedb = "0.15.0"
litellm = "1.49.1"
groq = "0.8.0"
langfuse = "^2.32.0"
groq = {version = "0.8.0", optional = true}
langfuse = {version = "^2.32.0", optional = true}
pydantic-settings = "^2.2.1"
anthropic = "^0.26.1"
sentry-sdk = {extras = ["fastapi"], version = "^2.9.0"}
fastapi-users = {version = "*", extras = ["sqlalchemy"]}
alembic = "^1.13.3"
asyncpg = "0.30.0"
pgvector = "^0.3.5"
asyncpg = {version = "0.30.0", optional = true}
pgvector = {version = "^0.3.5", optional = true}
psycopg2 = {version = "^2.9.10", optional = true}
llama-index-core = {version = "^0.11.22", optional = true}
deepeval = {version = "^2.0.1", optional = true}
transformers = "^4.46.3"
pymilvus = {version = "^2.5.0", optional = true}
[tool.poetry.extras]
filesystem = ["s3fs", "botocore"]
cli = ["pipdeptree", "cron-descriptor"]
weaviate = ["weaviate-client"]
qdrant = ["qdrant-client"]
neo4j = ["neo4j"]
postgres = ["psycopg2", "pgvector", "asyncpg"]
notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
notebook = ["notebook", "ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
langchain = ["langsmith", "langchain_text_splitters"]
llama-index = ["llama-index-core"]
deepeval = ["deepeval"]
posthog = ["posthog"]
falkordb = ["falkordb"]
groq = ["groq"]
langfuse = ["langfuse"]
milvus = ["pymilvus"]
[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"
coverage = "^7.3.2"
mypy = "^1.7.1"
notebook = "^7.1.1"
notebook = {version = "^7.1.1", optional = true}
deptry = "^0.20.0"
debugpy = "1.8.2"
pylint = "^3.0.3"
ruff = "^0.2.2"
tweepy = "4.14.0"
gitpython = "^3.1.43"
[tool.poetry.group.docs.dependencies]
mkdocs-material = "^9.5.42"