Merge remote-tracking branch 'origin/main'

This commit is contained in:
Boris Arzentar 2024-12-03 21:58:19 +01:00
commit dd423ebc3d
82 changed files with 3248 additions and 1597 deletions

6
.gitignore vendored
View file

@ -4,6 +4,8 @@
.prod.env .prod.env
cognee/.data/ cognee/.data/
code_pipeline_output*/
*.lance/ *.lance/
.DS_Store .DS_Store
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
@ -12,7 +14,6 @@ __pycache__/
*$py.class *$py.class
full_run.ipynb full_run.ipynb
evals/
# C extensions # C extensions
*.so *.so
@ -181,3 +182,6 @@ cognee/cache/
.anon_id .anon_id
node_modules/ node_modules/
# Evals
SWE-bench_testsample/

View file

@ -1,9 +1,14 @@
from .api.v1.config.config import config
from .api.v1.add import add from .api.v1.add import add
from .api.v1.cognify import cognify from .api.v1.cognify import cognify
from .api.v1.config.config import config
from .api.v1.datasets.datasets import datasets from .api.v1.datasets.datasets import datasets
from .api.v1.search import search, SearchType, get_search_history
from .api.v1.prune import prune from .api.v1.prune import prune
from .api.v1.search import SearchType, get_search_history, search
# Pipelines # Pipelines
from .modules import pipelines from .modules import pipelines
try:
import dotenv
dotenv.load_dotenv()
except ImportError:
pass

View file

@ -2,7 +2,7 @@ from typing import Union, BinaryIO
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
from cognee.modules.pipelines import run_tasks, Task from cognee.modules.pipelines import run_tasks, Task
from cognee.tasks.ingestion import save_data_to_storage, ingest_data from cognee.tasks.ingestion import ingest_data_with_metadata
from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables from cognee.infrastructure.databases.relational import create_db_and_tables as create_relational_db_and_tables
from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables from cognee.infrastructure.databases.vector.pgvector import create_db_and_tables as create_pgvector_db_and_tables
@ -14,8 +14,7 @@ async def add(data: Union[BinaryIO, list[BinaryIO], str, list[str]], dataset_nam
user = await get_default_user() user = await get_default_user()
tasks = [ tasks = [
Task(save_data_to_storage, dataset_name), Task(ingest_data_with_metadata, dataset_name, user)
Task(ingest_data, dataset_name, user)
] ]
pipeline = run_tasks(tasks, data, "add_pipeline") pipeline = run_tasks(tasks, data, "add_pipeline")

View file

@ -38,7 +38,7 @@ def create_graph_engine() -> GraphDBInterface:
) )
elif config.graph_database_provider == "falkordb": elif config.graph_database_provider == "falkordb":
if not (config.graph_database_url and config.graph_database_username and config.graph_database_password): if not (config.graph_database_url and config.graph_database_port):
raise EnvironmentError("Missing required FalkorDB credentials.") raise EnvironmentError("Missing required FalkorDB credentials.")
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine

View file

@ -1,7 +1,8 @@
import asyncio import asyncio
from textwrap import dedent # from datetime import datetime
from typing import Any import json
from uuid import UUID from uuid import UUID
from textwrap import dedent
from falkordb import FalkorDB from falkordb import FalkorDB
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
@ -44,30 +45,39 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
async def embed_data(self, data: list[str]) -> list[list[float]]: async def embed_data(self, data: list[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data) return await self.embedding_engine.embed_text(data)
async def stringify_properties(self, properties: dict, vectorize_fields = []) -> str: async def stringify_properties(self, properties: dict) -> str:
async def get_value(key, value): def parse_value(value):
return f"'{value}'" if key not in vectorize_fields else await self.get_vectorized_value(value) if type(value) is UUID:
return f"'{str(value)}'"
if type(value) is int or type(value) is float:
return value
if type(value) is list and type(value[0]) is float and len(value) == self.embedding_engine.get_vector_size():
return f"'vecf32({value})'"
# if type(value) is datetime:
# return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%f%z")
if type(value) is dict:
return f"'{json.dumps(value)}'"
return f"'{value}'"
return ",".join([f"{key}:{await get_value(key, value)}" for key, value in properties.items()]) return ",".join([f"{key}:{parse_value(value)}" for key, value in properties.items()])
async def get_vectorized_value(self, value: Any) -> str: async def create_data_point_query(self, data_point: DataPoint, vectorized_values: dict):
vector = (await self.embed_data([value]))[0] node_label = type(data_point).__tablename__
return f"vecf32({vector})" property_names = DataPoint.get_embeddable_property_names(data_point)
async def create_data_point_query(self, data_point: DataPoint): node_properties = await self.stringify_properties({
node_label = type(data_point).__name__ **data_point.model_dump(),
node_properties = await self.stringify_properties( **({
data_point.model_dump(), property_names[index]: (vectorized_values[index] \
data_point._metadata["index_fields"], if index < len(vectorized_values) else getattr(data_point, property_name, None)) \
# data_point._metadata["index_fields"] if hasattr(data_point, "_metadata") else [], for index, property_name in enumerate(property_names)
) }),
})
return dedent(f""" return dedent(f"""
MERGE (node:{node_label} {{id: '{str(data_point.id)}'}}) MERGE (node:{node_label} {{id: '{str(data_point.id)}'}})
ON CREATE SET node += ({{{node_properties}}}) ON CREATE SET node += ({{{node_properties}}}), node.updated_at = timestamp()
ON CREATE SET node.updated_at = timestamp() ON MATCH SET node += ({{{node_properties}}}), node.updated_at = timestamp()
ON MATCH SET node += ({{{node_properties}}})
ON MATCH SET node.updated_at = timestamp()
""").strip() """).strip()
async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str: async def create_edge_query(self, edge: tuple[str, str, str, dict]) -> str:
@ -91,7 +101,37 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
return collection_name in collections return collection_name in collections
async def create_data_points(self, data_points: list[DataPoint]): async def create_data_points(self, data_points: list[DataPoint]):
queries = [await self.create_data_point_query(data_point) for data_point in data_points] embeddable_values = []
vector_map = {}
for data_point in data_points:
property_names = DataPoint.get_embeddable_property_names(data_point)
key = str(data_point.id)
vector_map[key] = {}
for property_name in property_names:
property_value = getattr(data_point, property_name, None)
if property_value is not None:
vector_map[key][property_name] = len(embeddable_values)
embeddable_values.append(property_value)
else:
vector_map[key][property_name] = None
vectorized_values = await self.embed_data(embeddable_values)
queries = [
await self.create_data_point_query(
data_point,
[
vectorized_values[vector_map[str(data_point.id)][property_name]] \
if vector_map[str(data_point.id)][property_name] is not None \
else None \
for property_name in DataPoint.get_embeddable_property_names(data_point)
],
) for data_point in data_points
]
for query in queries: for query in queries:
self.query(query) self.query(query)
@ -149,18 +189,21 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
return [result["edge_exists"] for result in results] return [result["edge_exists"] for result in results]
async def retrieve(self, data_point_ids: list[str]): async def retrieve(self, data_point_ids: list[UUID]):
return self.query( result = self.query(
f"MATCH (node) WHERE node.id IN $node_ids RETURN node", f"MATCH (node) WHERE node.id IN $node_ids RETURN node",
{ {
"node_ids": data_point_ids, "node_ids": [str(data_point) for data_point in data_point_ids],
}, },
) )
return result.result_set
async def extract_node(self, data_point_id: str): async def extract_node(self, data_point_id: UUID):
return await self.retrieve([data_point_id]) result = await self.retrieve([data_point_id])
result = result[0][0] if len(result[0]) > 0 else None
return result.properties if result else None
async def extract_nodes(self, data_point_ids: list[str]): async def extract_nodes(self, data_point_ids: list[UUID]):
return await self.retrieve(data_point_ids) return await self.retrieve(data_point_ids)
async def get_connections(self, node_id: UUID) -> list: async def get_connections(self, node_id: UUID) -> list:
@ -206,10 +249,12 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
if query_text and not query_vector: if query_text and not query_vector:
query_vector = (await self.embed_data([query_text]))[0] query_vector = (await self.embed_data([query_text]))[0]
[label, attribute_name] = collection_name.split(".")
query = dedent(f""" query = dedent(f"""
CALL db.idx.vector.queryNodes( CALL db.idx.vector.queryNodes(
{collection_name}, '{label}',
'text', '{attribute_name}',
{limit}, {limit},
vecf32({query_vector}) vecf32({query_vector})
) YIELD node, score ) YIELD node, score
@ -217,7 +262,7 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
result = self.query(query) result = self.query(query)
return result return result.result_set
async def batch_search( async def batch_search(
self, self,
@ -237,11 +282,35 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
) for query_vector in query_vectors] ) for query_vector in query_vectors]
) )
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def get_graph_data(self):
query = "MATCH (n) RETURN ID(n) AS id, labels(n) AS labels, properties(n) AS properties"
result = self.query(query)
nodes = [(
record[2]["id"],
record[2],
) for record in result.result_set]
query = """
MATCH (n)-[r]->(m)
RETURN ID(n) AS source, ID(m) AS target, TYPE(r) AS type, properties(r) AS properties
"""
result = self.query(query)
edges = [(
record[3]["source_node_id"],
record[3]["target_node_id"],
record[2],
record[3],
) for record in result.result_set]
return (nodes, edges)
async def delete_data_points(self, collection_name: str, data_point_ids: list[UUID]):
return self.query( return self.query(
f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node", f"MATCH (node) WHERE node.id IN $node_ids DETACH DELETE node",
{ {
"node_ids": data_point_ids, "node_ids": [str(data_point) for data_point in data_point_ids],
}, },
) )
@ -265,4 +334,4 @@ class FalkorDBAdapter(VectorDBInterface, GraphDBInterface):
print(f"Error deleting graph: {e}") print(f"Error deleting graph: {e}")
async def prune(self): async def prune(self):
self.delete_graph() await self.delete_graph()

View file

@ -58,7 +58,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
) )
elif config["vector_db_provider"] == "falkordb": elif config["vector_db_provider"] == "falkordb":
if not (config["vector_db_url"] and config["vector_db_key"]): if not (config["vector_db_url"] and config["vector_db_port"]):
raise EnvironmentError("Missing requred FalkorDB credentials!") raise EnvironmentError("Missing requred FalkorDB credentials!")
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter

View file

@ -1,9 +1,10 @@
import asyncio import logging
from typing import List, Optional from typing import List, Optional
import litellm import litellm
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
litellm.set_verbose = False litellm.set_verbose = False
logger = logging.getLogger("LiteLLMEmbeddingEngine")
class LiteLLMEmbeddingEngine(EmbeddingEngine): class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_key: str api_key: str
@ -27,20 +28,19 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.dimensions = dimensions self.dimensions = dimensions
async def embed_text(self, text: List[str]) -> List[List[float]]: async def embed_text(self, text: List[str]) -> List[List[float]]:
async def get_embedding(text_): try:
response = await litellm.aembedding( response = await litellm.aembedding(
self.model, self.model,
input = text_, input = text,
api_key = self.api_key, api_key = self.api_key,
api_base = self.endpoint, api_base = self.endpoint,
api_version = self.api_version api_version = self.api_version
) )
except litellm.exceptions.BadRequestError as error:
logger.error("Error embedding text: %s", str(error))
raise error
return response.data[0]["embedding"] return [data["embedding"] for data in response.data]
tasks = [get_embedding(text_) for text_ in text]
result = await asyncio.gather(*tasks)
return result
def get_vector_size(self) -> int: def get_vector_size(self) -> int:
return self.dimensions return self.dimensions

View file

@ -1,4 +1,3 @@
import inspect
from typing import List, Optional, get_type_hints, Generic, TypeVar from typing import List, Optional, get_type_hints, Generic, TypeVar
import asyncio import asyncio
from uuid import UUID from uuid import UUID
@ -88,7 +87,7 @@ class LanceDBAdapter(VectorDBInterface):
collection = await connection.open_table(collection_name) collection = await connection.open_table(collection_name)
data_vectors = await self.embed_data( data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points] [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
) )
IdType = TypeVar("IdType") IdType = TypeVar("IdType")
@ -115,19 +114,11 @@ class LanceDBAdapter(VectorDBInterface):
for (data_point_index, data_point) in enumerate(data_points) for (data_point_index, data_point) in enumerate(data_points)
] ]
# TODO: This enables us to work with pydantic version but shouldn't await collection.merge_insert("id") \
# stay like this, existing rows should be updated .when_matched_update_all() \
.when_not_matched_insert_all() \
.execute(lance_data_points)
await collection.delete("id IS NOT NULL")
original_size = await collection.count_rows()
await collection.add(lance_data_points)
new_size = await collection.count_rows()
if new_size <= original_size:
raise InvalidValueError(message=
"LanceDB create_datapoints error: data points did not get added.")
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: list[str]):
connection = await self.get_connection() connection = await self.get_connection()
@ -145,10 +136,10 @@ class LanceDBAdapter(VectorDBInterface):
) for result in results.to_dict("index").values()] ) for result in results.to_dict("index").values()]
async def get_distance_from_collection_elements( async def get_distance_from_collection_elements(
self, self,
collection_name: str, collection_name: str,
query_text: str = None, query_text: str = None,
query_vector: List[float] = None query_vector: List[float] = None
): ):
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise InvalidValueError(message="One of query_text or query_vector must be provided!") raise InvalidValueError(message="One of query_text or query_vector must be provided!")

View file

@ -102,7 +102,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
) )
data_vectors = await self.embed_data( data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points] [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
) )
vector_size = self.embedding_engine.get_vector_size() vector_size = self.embedding_engine.get_vector_size()
@ -143,7 +143,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await self.create_data_points(f"{index_name}_{index_property_name}", [ await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema( IndexSchema(
id = data_point.id, id = data_point.id,
text = data_point.get_embeddable_data(), text = DataPoint.get_embeddable_data(data_point),
) for data_point in data_points ) for data_point in data_points
]) ])

View file

@ -102,7 +102,9 @@ class QDrantAdapter(VectorDBInterface):
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
client = self.get_qdrant_client() client = self.get_qdrant_client()
data_vectors = await self.embed_data([data_point.get_embeddable_data() for data_point in data_points]) data_vectors = await self.embed_data([
DataPoint.get_embeddable_data(data_point) for data_point in data_points
])
def convert_to_qdrant_point(data_point: DataPoint): def convert_to_qdrant_point(data_point: DataPoint):
return models.PointStruct( return models.PointStruct(

View file

@ -1,8 +1,6 @@
from typing import List from typing import List
def normalize_distances(result_values: List[dict]) -> List[float]: def normalize_distances(result_values: List[dict]) -> List[float]:
min_value = min(result["_distance"] for result in result_values) min_value = min(result["_distance"] for result in result_values)
max_value = max(result["_distance"] for result in result_values) max_value = max(result["_distance"] for result in result_values)
@ -13,4 +11,4 @@ def normalize_distances(result_values: List[dict]) -> List[float]:
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in
result_values] result_values]
return normalized_values return normalized_values

View file

@ -83,7 +83,7 @@ class WeaviateAdapter(VectorDBInterface):
from weaviate.classes.data import DataObject from weaviate.classes.data import DataObject
data_vectors = await self.embed_data( data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points] [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
) )
def convert_to_weaviate_data_points(data_point: DataPoint): def convert_to_weaviate_data_points(data_point: DataPoint):
@ -116,12 +116,20 @@ class WeaviateAdapter(VectorDBInterface):
) )
else: else:
data_point: DataObject = data_points[0] data_point: DataObject = data_points[0]
return collection.data.update( if collection.data.exists(data_point.uuid):
uuid = data_point.uuid, return collection.data.update(
vector = data_point.vector, uuid = data_point.uuid,
properties = data_point.properties, vector = data_point.vector,
references = data_point.references, properties = data_point.properties,
) references = data_point.references,
)
else:
return collection.data.insert(
uuid = data_point.uuid,
vector = data_point.vector,
properties = data_point.properties,
references = data_point.references,
)
except Exception as error: except Exception as error:
logger.error("Error creating data points: %s", str(error)) logger.error("Error creating data points: %s", str(error))
raise error raise error
@ -133,7 +141,7 @@ class WeaviateAdapter(VectorDBInterface):
await self.create_data_points(f"{index_name}_{index_property_name}", [ await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema( IndexSchema(
id = data_point.id, id = data_point.id,
text = data_point.get_embeddable_data(), text = DataPoint.get_embeddable_data(data_point),
) for data_point in data_points ) for data_point in data_points
]) ])

View file

@ -11,6 +11,7 @@ class DataPoint(BaseModel):
__tablename__ = "data_point" __tablename__ = "data_point"
id: UUID = Field(default_factory = uuid4) id: UUID = Field(default_factory = uuid4)
updated_at: Optional[datetime] = datetime.now(timezone.utc) updated_at: Optional[datetime] = datetime.now(timezone.utc)
topological_rank: Optional[int] = 0
_metadata: Optional[MetaData] = { _metadata: Optional[MetaData] = {
"index_fields": [] "index_fields": []
} }
@ -18,11 +19,24 @@ class DataPoint(BaseModel):
# class Config: # class Config:
# underscore_attrs_are_private = True # underscore_attrs_are_private = True
def get_embeddable_data(self): @classmethod
if self._metadata and len(self._metadata["index_fields"]) > 0 \ def get_embeddable_data(self, data_point):
and hasattr(self, self._metadata["index_fields"][0]): if data_point._metadata and len(data_point._metadata["index_fields"]) > 0 \
attribute = getattr(self, self._metadata["index_fields"][0]) and hasattr(data_point, data_point._metadata["index_fields"][0]):
attribute = getattr(data_point, data_point._metadata["index_fields"][0])
if isinstance(attribute, str): if isinstance(attribute, str):
return(attribute.strip()) return attribute.strip()
else: else:
return (attribute) return attribute
@classmethod
def get_embeddable_properties(self, data_point):
if data_point._metadata and len(data_point._metadata["index_fields"]) > 0:
return [getattr(data_point, field, None) for field in data_point._metadata["index_fields"]]
return []
@classmethod
def get_embeddable_property_names(self, data_point):
return data_point._metadata["index_fields"] or []

View file

@ -4,6 +4,7 @@ from .guess_file_type import guess_file_type
class FileMetadata(TypedDict): class FileMetadata(TypedDict):
name: str name: str
file_path: str
mime_type: str mime_type: str
extension: str extension: str

View file

@ -1,7 +1,6 @@
from typing import Type from typing import Type
from pydantic import BaseModel from pydantic import BaseModel
import instructor import instructor
from tenacity import retry, stop_after_attempt
import anthropic import anthropic
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError

View file

@ -3,7 +3,6 @@ import asyncio
from typing import List, Type from typing import List, Type
from pydantic import BaseModel from pydantic import BaseModel
import instructor import instructor
from tenacity import retry, stop_after_attempt
import openai import openai
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
@ -55,60 +54,6 @@ class GenericAPIAdapter(LLMInterface):
mode = instructor.Mode.JSON, mode = instructor.Mode.JSON,
) )
@retry(stop = stop_after_attempt(5))
def completions_with_backoff(self, **kwargs):
"""Wrapper around ChatCompletion.create w/ backoff"""
# Local model
return openai.chat.completions.create(**kwargs)
@retry(stop = stop_after_attempt(5))
async def acompletions_with_backoff(self, **kwargs):
"""Wrapper around ChatCompletion.acreate w/ backoff"""
return await openai.chat.completions.acreate(**kwargs)
@retry(stop = stop_after_attempt(5))
async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
"""Wrapper around Embedding.acreate w/ backoff"""
return await self.aclient.embeddings.create(input = input, model = model)
async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
"""To get text embeddings, import/call this function
It specifies defaults + handles rate-limiting + is async"""
text = text.replace("\n", " ")
response = await self.aclient.embeddings.create(input = text, model = model)
embedding = response.data[0].embedding
return embedding
@retry(stop = stop_after_attempt(5))
def create_embedding_with_backoff(self, **kwargs):
"""Wrapper around Embedding.create w/ backoff"""
return openai.embeddings.create(**kwargs)
def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
"""To get text embeddings, import/call this function
It specifies defaults + handles rate-limiting
:param text: str
:param model: str
"""
text = text.replace("\n", " ")
response = self.create_embedding_with_backoff(input=[text], model=model)
embedding = response.data[0].embedding
return embedding
async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
"""To get multiple text embeddings in parallel, import/call this function
It specifies defaults + handles rate-limiting + is async"""
# Collect all coroutines
coroutines = (self.async_get_embedding_with_backoff(text, model)
for text, model in zip(texts, models))
# Run the coroutines in parallel and gather the results
embeddings = await asyncio.gather(*coroutines)
return embeddings
@retry(stop = stop_after_attempt(5))
async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel: async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
"""Generate a response from a user query.""" """Generate a response from a user query."""

View file

@ -3,6 +3,8 @@
from typing import Type, Protocol from typing import Type, Protocol
from abc import abstractmethod from abc import abstractmethod
from pydantic import BaseModel from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import read_query_prompt
class LLMInterface(Protocol): class LLMInterface(Protocol):
""" LLM Interface """ """ LLM Interface """
@ -14,7 +16,14 @@ class LLMInterface(Protocol):
"""To get structured output, import/call this function""" """To get structured output, import/call this function"""
raise NotImplementedError raise NotImplementedError
@abstractmethod
def show_prompt(self, text_input: str, system_prompt: str) -> str: def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""To get structured output, import/call this function""" """Format and display the prompt for a user query."""
raise NotImplementedError if not text_input:
text_input = "No user input provided."
if not system_prompt:
raise ValueError("No system prompt path provided.")
system_prompt = read_query_prompt(system_prompt)
formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
return formatted_prompt

View file

@ -0,0 +1,2 @@
Answer the question using the provided context. Be as brief as possible.
Each entry in the context is a paragraph, which is represented as a list with two elements [title, sentences] and sentences is a list of strings.

View file

@ -0,0 +1,2 @@
Answer the question using the provided context. Be as brief as possible.
Each entry in the context is tuple of length 3, representing an edge of a knowledge graph with its two nodes.

View file

@ -0,0 +1,2 @@
The question is: `{{ question }}`
And here is the context: `{{ context }}`

View file

@ -1,14 +0,0 @@
You are tasked with analyzing `{{ data_type }}` files, especially in a multilayer network context for tasks such as analysis, categorization, and feature extraction. Various layers can be incorporated to capture the depth and breadth of information contained within the {{ data_type }}.
These layers can help in understanding the content, context, and characteristics of the `{{ data_type }}`.
Your objective is to extract meaningful layers of information that will contribute to constructing a detailed multilayer network or knowledge graph.
Approach this task by considering the unique characteristics and inherent properties of the data at hand.
VERY IMPORTANT: The context you are working in is `{{ category_name }}` and the specific domain you are extracting data on is `{{ category_name }}`.
Guidelines for Layer Extraction:
Take into account: The content type, in this case, is: `{{ category_name }}`, should play a major role in how you decompose into layers.
Based on your analysis, define and describe the layers you've identified, explaining their relevance and contribution to understanding the dataset. Your independent identification of layers will enable a nuanced and multifaceted representation of the data, enhancing applications in knowledge discovery, content analysis, and information retrieval.

View file

@ -1,3 +1,2 @@
I need you to solve this issue by looking at the provided knowledge graph and I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.
generating a single patch file that I can apply directly to this repository using git apply.
Please respond with a single patch file in the following format. Please respond with a single patch file in the following format.

View file

@ -0,0 +1,3 @@
I need you to solve this issue by looking at the provided edges retrieved from a knowledge graph and
generate a single patch file that I can apply directly to this repository using git apply.
Please respond with a single patch file in the following format.

View file

@ -35,6 +35,10 @@ class TextChunker():
is_part_of = self.document, is_part_of = self.document,
chunk_index = self.chunk_index, chunk_index = self.chunk_index,
cut_type = chunk_data["cut_type"], cut_type = chunk_data["cut_type"],
_metadata = {
"index_fields": ["text"],
"metadata_id": self.document.metadata_id
}
) )
paragraph_chunks = [] paragraph_chunks = []
self.chunk_size = 0 self.chunk_size = 0
@ -48,6 +52,10 @@ class TextChunker():
is_part_of = self.document, is_part_of = self.document,
chunk_index = self.chunk_index, chunk_index = self.chunk_index,
cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"], cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
_metadata = {
"index_fields": ["text"],
"metadata_id": self.document.metadata_id
}
) )
except Exception as e: except Exception as e:
print(e) print(e)
@ -65,6 +73,10 @@ class TextChunker():
is_part_of = self.document, is_part_of = self.document,
chunk_index = self.chunk_index, chunk_index = self.chunk_index,
cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"], cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
_metadata = {
"index_fields": ["text"],
"metadata_id": self.document.metadata_id
}
) )
except Exception as e: except Exception as e:
print(e) print(e)

View file

@ -1,11 +0,0 @@
from typing import Type, Dict
from pydantic import BaseModel
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm.get_llm_client import get_llm_client
async def extract_cognitive_layers(content: str, category: Dict, response_model: Type[BaseModel]):
llm_client = get_llm_client()
system_prompt = render_prompt("generate_cog_layers.txt", category)
return await llm_client.acreate_structured_output(content, system_prompt, response_model)

View file

@ -1,31 +1,46 @@
from uuid import uuid4
from typing import List
from datetime import datetime, timezone from datetime import datetime, timezone
from sqlalchemy.orm import relationship, Mapped from typing import List
from sqlalchemy import Column, String, DateTime, UUID from uuid import uuid4
from sqlalchemy import UUID, Column, DateTime, String
from sqlalchemy.orm import Mapped, relationship
from cognee.infrastructure.databases.relational import Base from cognee.infrastructure.databases.relational import Base
from .DatasetData import DatasetData from .DatasetData import DatasetData
from .Metadata import Metadata
class Data(Base): class Data(Base):
__tablename__ = "data" __tablename__ = "data"
id = Column(UUID, primary_key = True, default = uuid4) id = Column(UUID, primary_key=True, default=uuid4)
name = Column(String) name = Column(String)
extension = Column(String) extension = Column(String)
mime_type = Column(String) mime_type = Column(String)
raw_data_location = Column(String) raw_data_location = Column(String)
created_at = Column(
created_at = Column(DateTime(timezone = True), default = lambda: datetime.now(timezone.utc)) DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
updated_at = Column(DateTime(timezone = True), onupdate = lambda: datetime.now(timezone.utc))
datasets: Mapped[List["Dataset"]] = relationship(
"Dataset",
secondary = DatasetData.__tablename__,
back_populates = "data",
lazy = "noload",
cascade="all, delete"
) )
updated_at = Column(
DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)
)
datasets = relationship(
"Dataset",
secondary=DatasetData.__tablename__,
back_populates="data",
lazy="noload",
cascade="all, delete",
)
metadata_relationship = relationship(
"Metadata",
back_populates="data",
lazy="noload",
cascade="all, delete",
)
def to_json(self) -> dict: def to_json(self) -> dict:
return { return {

View file

@ -0,0 +1,26 @@
from datetime import datetime, timezone
from uuid import uuid4
from sqlalchemy import UUID, Column, DateTime, String, ForeignKey
from sqlalchemy.orm import relationship
from cognee.infrastructure.databases.relational import Base
class Metadata(Base):
    """Relational row holding the serialized ingestion metadata for one Data row."""

    # "metadata" is reserved on SQLAlchemy declarative bases, hence the suffix.
    __tablename__ = "metadata_table"

    id = Column(UUID, primary_key=True, default=uuid4)
    # JSON-serialized metadata dict (see write_metadata).
    metadata_repr = Column(String)
    # Fully-qualified type name of the original data item (see parse_type).
    metadata_source = Column(String)

    created_at = Column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )
    updated_at = Column(
        DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)
    )

    # Owning Data row; rows are cascade-deleted with their Data row.
    # (Dropped the redundant `primary_key=False` — that is already the default.)
    data_id = Column(UUID, ForeignKey("data.id", ondelete="CASCADE"))
    data = relationship("Data", back_populates="metadata_relationship")

View file

@ -0,0 +1,19 @@
import warnings
from uuid import UUID
from sqlalchemy import select
from cognee.infrastructure.databases.relational import get_relational_engine
from ..models.Metadata import Metadata
async def delete_metadata(metadata_id: UUID):
    """Delete the Metadata row identified by *metadata_id*.

    If no row matches, emits a warning and returns without touching the
    database.  (Previously execution fell through and called
    ``session.delete(None)``, raising instead of warning.)

    Args:
        metadata_id: Primary key of the Metadata row to remove.
    """
    db_engine = get_relational_engine()

    async with db_engine.get_async_session() as session:
        metadata = await session.get(Metadata, metadata_id)

        if metadata is None:
            warnings.warn(f"metadata for metadata_id: {metadata_id} not found")
            return

        # AsyncSession.delete and .commit are coroutines and must be awaited;
        # the original called them without await, so nothing was persisted.
        await session.delete(metadata)
        await session.commit()

View file

@ -0,0 +1,19 @@
import json
from uuid import UUID
from sqlalchemy import select
from cognee.infrastructure.databases.relational import get_relational_engine
from ..models.Metadata import Metadata
async def get_metadata(metadata_id: UUID) -> Metadata:
    """Fetch the Metadata row identified by *metadata_id*.

    Returns:
        The Metadata row, or None when no row matches — despite the
        annotated return type, callers must handle the missing case.
    """
    db_engine = get_relational_engine()

    async with db_engine.get_async_session() as session:
        metadata = await session.get(Metadata, metadata_id)
        return metadata

View file

@ -0,0 +1,52 @@
import inspect
import json
import re
import warnings
from typing import Any
from uuid import UUID
from typing import Any, BinaryIO, Union
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.infrastructure.files.utils.get_file_metadata import FileMetadata
from ..models.Metadata import Metadata
async def write_metadata(data_item: Union[BinaryIO, str, Any], data_id: UUID, file_metadata: "FileMetadata") -> UUID:
    """Persist a Metadata row describing *data_item* and return its id.

    The Metadata row shares its primary key with the owning Data row
    (*data_id* is used both as ``id`` and as the ``data_id`` foreign key).

    Args:
        data_item: Original ingested object (string, file-like, or document).
        data_id: Id of the Data row this metadata belongs to.
        file_metadata: File-level metadata gathered during ingestion.

    Returns:
        The id of the written Metadata row (equal to *data_id*).  The
        original was annotated ``-> UUID`` but returned None.
    """
    metadata_dict = get_metadata_dict(data_item, file_metadata)
    db_engine = get_relational_engine()

    async with db_engine.get_async_session() as session:
        metadata = Metadata(
            id=data_id,
            metadata_repr=json.dumps(metadata_dict),
            metadata_source=parse_type(type(data_item)),
            data_id=data_id,
        )
        session.add(metadata)
        await session.commit()

    return data_id
def parse_type(type_: Any) -> str:
    """Extract the dotted class name from a type's repr.

    Example: ``<class 'llama_index.core.Document'>`` -> ``llama_index.core.Document``.

    Raises:
        Exception: when the repr does not contain a quoted class name.
    """
    match = re.search(r".+'([\w_\.]+)'", str(type_))

    if match is None:
        raise Exception(f"type: {type_} could not be parsed")

    return match.group(1)
def get_metadata_dict(data_item: "Union[BinaryIO, str, Any]", file_metadata: "FileMetadata") -> "dict[str, Any]":
    """Build the dict that is JSON-serialized into ``Metadata.metadata_repr``.

    Rules:
        - plain strings and file-like objects: the file metadata alone;
        - objects exposing a bound ``.dict()`` method (pydantic / llama-index
          documents): file metadata merged with the object's own fields;
        - anything else: file metadata plus the item's string representation
          (with a warning).

    Raises:
        Exception: when the fallback stringification itself fails.
    """
    if isinstance(data_item, str):
        return file_metadata

    # BUG FIX: `isinstance(x, typing.BinaryIO)` never matches concrete file
    # objects (typing.BinaryIO is not a runtime ABC for io classes), so binary
    # streams used to fall through to the stringify fallback.  Detect
    # file-like objects structurally instead.
    if hasattr(data_item, "read"):
        return file_metadata

    if hasattr(data_item, "dict") and inspect.ismethod(getattr(data_item, "dict")):
        return {**file_metadata, **data_item.dict()}

    warnings.warn(
        f"metadata of type {type(data_item)}: {str(data_item)[:20]}... does not have dict method. Defaulting to string method"
    )
    try:
        return {**dict(file_metadata), "content": str(data_item)}
    except Exception as e:
        raise Exception(f"Could not cast metadata to string: {e}")

View file

@ -1,9 +1,11 @@
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from uuid import UUID
class Document(DataPoint): class Document(DataPoint):
type: str type: str
name: str name: str
raw_data_location: str raw_data_location: str
metadata_id: UUID
def read(self, chunk_size: int) -> str: def read(self, chunk_size: int) -> str:
pass pass

View file

@ -1,6 +1,7 @@
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from .EntityType import EntityType from cognee.modules.engine.models.EntityType import EntityType
class Entity(DataPoint): class Entity(DataPoint):
__tablename__ = "entity" __tablename__ = "entity"
@ -8,6 +9,7 @@ class Entity(DataPoint):
is_a: EntityType is_a: EntityType
description: str description: str
mentioned_in: DocumentChunk mentioned_in: DocumentChunk
_metadata: dict = { _metadata: dict = {
"index_fields": ["name"], "index_fields": ["name"],
} }

View file

@ -1,12 +1,14 @@
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
class EntityType(DataPoint): class EntityType(DataPoint):
__tablename__ = "entity_type" __tablename__ = "entity_type"
name: str name: str
type: str type: str
description: str description: str
exists_in: DocumentChunk exists_in: DocumentChunk
_metadata: dict = { _metadata: dict = {
"index_fields": ["name"], "index_fields": ["name"],
} }

View file

@ -55,14 +55,16 @@ class CogneeGraph(CogneeAbstractGraph):
def get_edges(self)-> List[Edge]: def get_edges(self)-> List[Edge]:
return self.edges return self.edges
async def project_graph_from_db(self, async def project_graph_from_db(
adapter: Union[GraphDBInterface], self,
node_properties_to_project: List[str], adapter: Union[GraphDBInterface],
edge_properties_to_project: List[str], node_properties_to_project: List[str],
directed = True, edge_properties_to_project: List[str],
node_dimension = 1, directed = True,
edge_dimension = 1, node_dimension = 1,
memory_fragment_filter = []) -> None: edge_dimension = 1,
memory_fragment_filter = [],
) -> None:
if node_dimension < 1 or edge_dimension < 1: if node_dimension < 1 or edge_dimension < 1:
raise InvalidValueError(message="Dimensions must be positive integers") raise InvalidValueError(message="Dimensions must be positive integers")

View file

@ -1,2 +1,6 @@
from .expand_with_nodes_and_edges import expand_with_nodes_and_edges
from .get_graph_from_model import get_graph_from_model from .get_graph_from_model import get_graph_from_model
from .get_model_instance_from_graph import get_model_instance_from_graph from .get_model_instance_from_graph import get_model_instance_from_graph
from .retrieve_existing_edges import retrieve_existing_edges
from .convert_node_to_data_point import convert_node_to_data_point
from .deduplicate_nodes_and_edges import deduplicate_nodes_and_edges

View file

@ -0,0 +1,23 @@
from cognee.infrastructure.engine import DataPoint
def convert_node_to_data_point(node_data: dict) -> DataPoint:
    """Rehydrate a stored graph node dict into its concrete DataPoint subclass.

    The subclass is looked up by ``node_data["type"]`` among all currently
    imported DataPoint subclasses, then constructed from the remaining fields.

    Raises:
        ValueError: when no imported subclass matches the node's type
            (previously this surfaced as a confusing
            ``TypeError: 'NoneType' object is not callable``).
    """
    subclass = find_subclass_by_name(DataPoint, node_data["type"])

    if subclass is None:
        raise ValueError(f"No DataPoint subclass found for type: {node_data['type']}")

    return subclass(**node_data)
def get_all_subclasses(cls):
    """Return every direct and transitive subclass of *cls*, depth-first."""
    return [
        descendant
        for child in cls.__subclasses__()
        for descendant in (child, *get_all_subclasses(child))
    ]


def find_subclass_by_name(cls, name):
    """Return the subclass of *cls* whose __name__ equals *name*, or None."""
    return next(
        (subclass for subclass in get_all_subclasses(cls) if subclass.__name__ == name),
        None,
    )

View file

@ -0,0 +1,19 @@
from cognee.infrastructure.engine import DataPoint
def deduplicate_nodes_and_edges(nodes: "list[DataPoint]", edges: "list[tuple]"):
    """Drop duplicate nodes (by id) and duplicate edges (by endpoints + name).

    Keeps the first occurrence of each node/edge, preserving input order.
    Edges are ``(source_id, target_id, relationship_name, ...)`` tuples — the
    original annotation of ``list[dict]`` did not match how edges are indexed
    (``edge[0]``, ``edge[2]``, ``edge[1]``), so it has been corrected.

    Returns:
        (unique_nodes, unique_edges) as two lists.
    """
    # One shared key space, as in the original: node ids and composite edge
    # keys live in the same map.
    seen_keys = {}
    unique_nodes = []
    unique_edges = []

    for node in nodes:
        node_key = str(node.id)
        if node_key not in seen_keys:
            unique_nodes.append(node)
            seen_keys[node_key] = True

    for edge in edges:
        edge_key = f"{edge[0]}{edge[2]}{edge[1]}"
        if edge_key not in seen_keys:
            unique_edges.append(edge)
            seen_keys[edge_key] = True

    return unique_nodes, unique_edges

View file

@ -0,0 +1,83 @@
from typing import Optional
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.engine.utils import (
generate_edge_name,
generate_node_id,
generate_node_name,
)
from cognee.shared.data_models import KnowledgeGraph
def expand_with_nodes_and_edges(
    graph_node_index: list[tuple[DataPoint, KnowledgeGraph]],
    existing_edges_map: Optional[dict[str, bool]] = None,
):
    """Materialize extracted knowledge graphs into Entity data points and edge tuples.

    Args:
        graph_node_index: Pairs of (source data point, extracted KnowledgeGraph);
            a graph may be None and is then skipped.
        existing_edges_map: Keys of form "{source_id}{target_id}{relationship_name}"
            for edges already present in the graph store; matching edges are not
            re-emitted.  NOTE: this dict is mutated in place as new edges are seen.

    Returns:
        (data_points, relationships): newly created Entity data points
        (EntityType nodes are reachable through each entity's `is_a` field,
        not returned directly) and edge tuples of the form
        (source_id, target_id, relationship_name, properties_dict).
    """
    if existing_edges_map is None:
        existing_edges_map = {}

    # Dedup caches keyed "<id>_type" -> EntityType and "<id>_entity" -> Entity,
    # shared across all graphs in the batch.
    added_nodes_map = {}
    relationships = []
    data_points = []

    for graph_source, graph in graph_node_index:
        if graph is None:
            continue

        for node in graph.nodes:
            node_id = generate_node_id(node.id)
            node_name = generate_node_name(node.name)

            type_node_id = generate_node_id(node.type)
            type_node_name = generate_node_name(node.type)

            # Reuse a single EntityType node per distinct type.
            if f"{str(type_node_id)}_type" not in added_nodes_map:
                type_node = EntityType(
                    id = type_node_id,
                    name = type_node_name,
                    type = type_node_name,
                    description = type_node_name,
                    exists_in = graph_source,
                )
                added_nodes_map[f"{str(type_node_id)}_type"] = type_node
            else:
                type_node = added_nodes_map[f"{str(type_node_id)}_type"]

            # First mention of an entity wins; later mentions are ignored.
            if f"{str(node_id)}_entity" not in added_nodes_map:
                entity_node = Entity(
                    id = node_id,
                    name = node_name,
                    is_a = type_node,
                    description = node.description,
                    mentioned_in = graph_source,
                )
                data_points.append(entity_node)
                added_nodes_map[f"{str(node_id)}_entity"] = entity_node

        # Add relationship that came from graphs.
        for edge in graph.edges:
            source_node_id = generate_node_id(edge.source_node_id)
            target_node_id = generate_node_id(edge.target_node_id)
            relationship_name = generate_edge_name(edge.relationship_name)

            edge_key = str(source_node_id) + str(target_node_id) + relationship_name

            if edge_key not in existing_edges_map:
                relationships.append(
                    (
                        source_node_id,
                        target_node_id,
                        edge.relationship_name,
                        dict(
                            relationship_name = generate_edge_name(
                                edge.relationship_name
                            ),
                            source_node_id = source_node_id,
                            target_node_id = target_node_id,
                        ),
                    )
                )
                existing_edges_map[edge_key] = True

    return (data_points, relationships)

View file

@ -3,26 +3,44 @@ from datetime import datetime, timezone
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model from cognee.modules.storage.utils import copy_model
async def get_graph_from_model(
def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None): data_point: DataPoint,
include_root = True,
if not added_nodes: added_nodes = None,
added_nodes = {} added_edges = None,
if not added_edges: visited_properties = None,
added_edges = {} ):
nodes = [] nodes = []
edges = [] edges = []
added_nodes = added_nodes or {}
added_edges = added_edges or {}
visited_properties = visited_properties or {}
data_point_properties = {} data_point_properties = {}
excluded_properties = set() excluded_properties = set()
if str(data_point.id) in added_nodes:
return nodes, edges
for field_name, field_value in data_point: for field_name, field_value in data_point:
if field_name == "_metadata": if field_name == "_metadata":
continue continue
elif isinstance(field_value, DataPoint):
if field_value is None:
excluded_properties.add(field_name) excluded_properties.add(field_name)
nodes, edges, added_nodes, added_edges = add_nodes_and_edges( continue
if isinstance(field_value, DataPoint):
excluded_properties.add(field_name)
property_key = f"{str(data_point.id)}{field_name}{str(field_value.id)}"
if property_key in visited_properties:
continue
visited_properties[property_key] = True
nodes, edges = await add_nodes_and_edges(
data_point, data_point,
field_name, field_name,
field_value, field_value,
@ -30,46 +48,68 @@ def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=No
edges, edges,
added_nodes, added_nodes,
added_edges, added_edges,
visited_properties,
) )
elif ( continue
isinstance(field_value, list)
and len(field_value) > 0 if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
and isinstance(field_value[0], DataPoint)
):
excluded_properties.add(field_name) excluded_properties.add(field_name)
for item in field_value: for field_value_item in field_value:
n_edges_before = len(edges) property_key = f"{str(data_point.id)}{field_name}{str(field_value_item.id)}"
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
data_point, field_name, item, nodes, edges, added_nodes, added_edges if property_key in visited_properties:
continue
visited_properties[property_key] = True
nodes, edges = await add_nodes_and_edges(
data_point,
field_name,
field_value_item,
nodes,
edges,
added_nodes,
added_edges,
visited_properties,
) )
edges = edges[:n_edges_before] + [
(*edge[:3], {**edge[3], "metadata": {"type": "list"}})
for edge in edges[n_edges_before:]
]
else:
data_point_properties[field_name] = field_value
SimpleDataPointModel = copy_model( continue
type(data_point),
include_fields={
"_metadata": (dict, data_point._metadata),
},
exclude_fields=excluded_properties,
)
nodes.append(SimpleDataPointModel(**data_point_properties)) data_point_properties[field_name] = field_value
if include_root:
SimpleDataPointModel = copy_model(
type(data_point),
include_fields = {
"_metadata": (dict, data_point._metadata),
"__tablename__": data_point.__tablename__,
},
exclude_fields = excluded_properties,
)
nodes.append(SimpleDataPointModel(**data_point_properties))
added_nodes[str(data_point.id)] = True
return nodes, edges return nodes, edges
def add_nodes_and_edges( async def add_nodes_and_edges(
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges data_point,
field_name,
field_value,
nodes,
edges,
added_nodes,
added_edges,
visited_properties,
): ):
property_nodes, property_edges = await get_graph_from_model(
property_nodes, property_edges = get_graph_from_model( field_value,
field_value, dict(added_nodes), dict(added_edges) True,
added_nodes,
added_edges,
visited_properties,
) )
for node in property_nodes: for node in property_nodes:
@ -105,7 +145,7 @@ def add_nodes_and_edges(
) )
added_edges[str(edge_key)] = True added_edges[str(edge_key)] = True
return (nodes, edges, added_nodes, added_edges) return (nodes, edges)
def get_own_properties(property_nodes, property_edges): def get_own_properties(property_nodes, property_edges):

View file

@ -0,0 +1,55 @@
from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.utils import generate_node_id
from cognee.shared.data_models import KnowledgeGraph
async def retrieve_existing_edges(
    graph_node_index: list[tuple[DataPoint, KnowledgeGraph]],
    graph_engine: GraphDBInterface,
) -> dict[str, bool]:
    """Ask the graph store which structural and semantic edges already exist.

    Collects candidate edges for every (source, graph) pair — exists_in /
    mentioned_in / is_a structural edges plus the graphs' own relationship
    edges — and queries the engine once for all of them.

    Args:
        graph_node_index: Pairs of (source data point, extracted KnowledgeGraph).
        graph_engine: Graph database adapter exposing ``has_edges``.

    Returns:
        Map keyed "{source}{target}{relationship_name}" with True for each
        edge that already exists.
    """
    processed_nodes = {}
    type_node_edges = []
    entity_node_edges = []
    type_entity_edges = []
    graph_node_edges = []

    for graph_source, graph in graph_node_index:
        for node in graph.nodes:
            type_node_id = generate_node_id(node.type)
            entity_node_id = generate_node_id(node.id)

            if str(type_node_id) not in processed_nodes:
                type_node_edges.append(
                    (str(graph_source), str(type_node_id), "exists_in")
                )
                processed_nodes[str(type_node_id)] = True

            if str(entity_node_id) not in processed_nodes:
                entity_node_edges.append(
                    (str(graph_source), entity_node_id, "mentioned_in")
                )
                type_entity_edges.append(
                    (str(entity_node_id), str(type_node_id), "is_a")
                )
                processed_nodes[str(entity_node_id)] = True

        # BUG FIX: this list was previously rebuilt (reassigned) on every loop
        # iteration, so only the LAST graph's own edges were ever checked for
        # existence.  Accumulate across all graphs instead.
        graph_node_edges.extend(
            (edge.target_node_id, edge.source_node_id, edge.relationship_name)
            for edge in graph.edges
        )

    existing_edges = await graph_engine.has_edges(
        [
            *type_node_edges,
            *entity_node_edges,
            *type_entity_edges,
            *graph_node_edges,
        ]
    )

    existing_edges_map = {}

    for edge in existing_edges:
        existing_edges_map[edge[0] + edge[1] + edge[2]] = True

    return existing_edges_map

View file

@ -1,13 +1,15 @@
import asyncio import asyncio
import logging import logging
from typing import List from typing import List
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
def format_triplets(edges): def format_triplets(edges):
print("\n\n\n") print("\n\n\n")
def filter_attributes(obj, attributes): def filter_attributes(obj, attributes):
@ -48,16 +50,14 @@ def format_triplets(edges):
return "".join(triplets) return "".join(triplets)
async def brute_force_triplet_search(query: str, user: User = None, top_k = 5) -> list: async def brute_force_triplet_search(query: str, user: User = None, top_k = 5, collections = None) -> list:
if user is None: if user is None:
user = await get_default_user() user = await get_default_user()
if user is None: if user is None:
raise PermissionError("No user found in the system. Please create a user.") raise PermissionError("No user found in the system. Please create a user.")
retrieved_results = await brute_force_search(query, user, top_k) retrieved_results = await brute_force_search(query, user, top_k, collections=collections)
return retrieved_results return retrieved_results

View file

@ -0,0 +1,40 @@
from typing import List, Optional
from cognee.infrastructure.engine import DataPoint
class Repository(DataPoint):
    """Graph node representing a code repository root."""

    # Node label / collection name in the graph store.
    __tablename__ = "Repository"
    # Filesystem path of the repository root.
    path: str
    type: Optional[str] = "Repository"
class CodeFile(DataPoint):
    """Graph node for one source file and its dependency links."""

    __tablename__ = "CodeFile"
    extracted_id: str  # actually file path
    type: Optional[str] = "CodeFile"
    source_code: Optional[str] = None
    # Repository this file belongs to.
    part_of: Optional[Repository] = None
    # NOTE(review): presumably depends_on is the transitive closure while
    # depends_directly_on holds direct imports — confirm against the
    # enrich_dependency_graph task that populates them.
    depends_on: Optional[List["CodeFile"]] = None
    depends_directly_on: Optional[List["CodeFile"]] = None
    contains: Optional[List["CodePart"]] = None

    # Vector index is built over the raw source code text.
    _metadata: dict = {
        "index_fields": ["source_code"]
    }
class CodePart(DataPoint):
    """Graph node for a fragment of a source file (e.g. one top-level unit)."""

    __tablename__ = "CodePart"
    # part_of: Optional[CodeFile]
    source_code: str
    type: Optional[str] = "CodePart"

    # Vector index is built over the fragment's source code text.
    _metadata: dict = {
        "index_fields": ["source_code"]
    }
class CodeRelationship(DataPoint):
    """Typed dependency edge between two code files, kept as a data point."""

    source_id: str
    target_id: str
    type: str  # between files
    relation: str  # depends on or depends directly


# Resolve the forward references ("CodeFile", "CodePart") used above.
CodeFile.model_rebuild()
CodePart.model_rebuild()

View file

@ -0,0 +1,27 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph
def main():
    """CLI entry point: build and enrich a repo dependency graph, then dump it."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("repo_path", help="Path to the repository")
    repo_path = arg_parser.parse_args().repo_path

    if not os.path.exists(repo_path):
        print(f"Error: The provided repository path does not exist: {repo_path}")
        return

    dependency_graph = asyncio.run(get_repo_file_dependencies(repo_path))
    dependency_graph = asyncio.run(enrich_dependency_graph(dependency_graph))

    # Print every node followed by its outgoing edges.
    for graph_node in dependency_graph.nodes:
        print(f"Node: {graph_node}")

        for _, edge_target, edge_data in dependency_graph.out_edges(graph_node, data=True):
            print(f"  Edge to {edge_target}, data: {edge_data}")


if __name__ == "__main__":
    main()

View file

@ -0,0 +1,29 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph
from cognee.tasks.repo_processor.expand_dependency_graph import expand_dependency_graph
def main():
    """CLI entry point: build, enrich, and expand a repo dependency graph, then dump it."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("repo_path", help="Path to the repository")
    repo_path = arg_parser.parse_args().repo_path

    if not os.path.exists(repo_path):
        print(f"Error: The provided repository path does not exist: {repo_path}")
        return

    dependency_graph = asyncio.run(get_repo_file_dependencies(repo_path))
    dependency_graph = asyncio.run(enrich_dependency_graph(dependency_graph))
    # expand_dependency_graph is synchronous, unlike the two steps above.
    dependency_graph = expand_dependency_graph(dependency_graph)

    # Print every node followed by its outgoing edges.
    for graph_node in dependency_graph.nodes:
        print(f"Node: {graph_node}")

        for _, edge_target, edge_data in dependency_graph.out_edges(graph_node, data=True):
            print(f"  Edge to {edge_target}, data: {edge_data}")


if __name__ == "__main__":
    main()

View file

@ -0,0 +1,27 @@
import os
import asyncio
import argparse
from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
def main():
    """CLI entry point: build a repo dependency graph and print nodes with relations."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("repo_path", help="Path to the repository")
    repo_path = arg_parser.parse_args().repo_path

    if not os.path.exists(repo_path):
        print(f"Error: The provided repository path does not exist: {repo_path}")
        return

    dependency_graph = asyncio.run(get_repo_file_dependencies(repo_path))

    # Print every node followed by the relation label of each incident edge.
    for graph_node in dependency_graph.nodes:
        print(f"Node: {graph_node}")

        node_edges = dependency_graph.edges(graph_node, data=True)
        for _, edge_target, edge_data in node_edges:
            print(f"  Edge to {edge_target}, Relation: {edge_data.get('relation')}")


if __name__ == "__main__":
    main()

View file

@ -6,6 +6,7 @@ from cognee.modules.data.processing.document_types import (
ImageDocument, ImageDocument,
TextDocument, TextDocument,
) )
from cognee.modules.data.operations.get_metadata import get_metadata
EXTENSION_TO_DOCUMENT_CLASS = { EXTENSION_TO_DOCUMENT_CLASS = {
"pdf": PdfDocument, # Text documents "pdf": PdfDocument, # Text documents
@ -38,14 +39,17 @@ EXTENSION_TO_DOCUMENT_CLASS = {
} }
def classify_documents(data_documents: list[Data]) -> list[Document]: async def classify_documents(data_documents: list[Data]) -> list[Document]:
documents = [ documents = []
EXTENSION_TO_DOCUMENT_CLASS[data_item.extension]( for data_item in data_documents:
metadata = await get_metadata(data_item.id)
document = EXTENSION_TO_DOCUMENT_CLASS[data_item.extension](
id=data_item.id, id=data_item.id,
title=f"{data_item.name}.{data_item.extension}", title=f"{data_item.name}.{data_item.extension}",
raw_data_location=data_item.raw_data_location, raw_data_location=data_item.raw_data_location,
name=data_item.name, name=data_item.name,
metadata_id=metadata.id
) )
for data_item in data_documents documents.append(document)
]
return documents return documents

View file

@ -1,119 +1,40 @@
import asyncio import asyncio
from typing import Type from typing import Type
from pydantic import BaseModel from pydantic import BaseModel
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.data.extraction.knowledge_graph import extract_content_graph
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.engine.models import EntityType, Entity from cognee.modules.data.extraction.knowledge_graph import extract_content_graph
from cognee.modules.engine.utils import generate_edge_name, generate_node_id, generate_node_name from cognee.modules.graph.utils import (
expand_with_nodes_and_edges,
retrieve_existing_edges,
)
from cognee.tasks.storage import add_data_points from cognee.tasks.storage import add_data_points
async def extract_graph_from_data(data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]):
async def extract_graph_from_data(
data_chunks: list[DocumentChunk], graph_model: Type[BaseModel]
):
chunk_graphs = await asyncio.gather( chunk_graphs = await asyncio.gather(
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks] *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
) )
processed_nodes = {}
type_node_edges = []
entity_node_edges = []
type_entity_edges = []
for (chunk_index, chunk) in enumerate(data_chunks):
chunk_graph = chunk_graphs[chunk_index]
for node in chunk_graph.nodes:
type_node_id = generate_node_id(node.type)
entity_node_id = generate_node_id(node.id)
if str(type_node_id) not in processed_nodes:
type_node_edges.append((str(chunk.id), str(type_node_id), "exists_in"))
processed_nodes[str(type_node_id)] = True
if str(entity_node_id) not in processed_nodes:
entity_node_edges.append((str(chunk.id), entity_node_id, "mentioned_in"))
type_entity_edges.append((str(entity_node_id), str(type_node_id), "is_a"))
processed_nodes[str(entity_node_id)] = True
graph_node_edges = [
(edge.target_node_id, edge.source_node_id, edge.relationship_name) \
for edge in chunk_graph.edges
]
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
chunk_and_chunk_graphs = [
(chunk, chunk_graph) for chunk, chunk_graph in zip(data_chunks, chunk_graphs)
]
existing_edges_map = await retrieve_existing_edges(
chunk_and_chunk_graphs,
graph_engine,
)
existing_edges = await graph_engine.has_edges([ graph_nodes, graph_edges = expand_with_nodes_and_edges(
*type_node_edges, chunk_and_chunk_graphs,
*entity_node_edges, existing_edges_map,
*type_entity_edges, )
*graph_node_edges,
])
existing_edges_map = {} if len(graph_nodes) > 0:
await add_data_points(graph_nodes)
for edge in existing_edges:
existing_edges_map[edge[0] + edge[1] + edge[2]] = True
added_nodes_map = {}
graph_edges = []
data_points = []
for (chunk_index, chunk) in enumerate(data_chunks):
graph = chunk_graphs[chunk_index]
if graph is None:
continue
for node in graph.nodes:
node_id = generate_node_id(node.id)
node_name = generate_node_name(node.name)
type_node_id = generate_node_id(node.type)
type_node_name = generate_node_name(node.type)
if f"{str(type_node_id)}_type" not in added_nodes_map:
type_node = EntityType(
id = type_node_id,
name = type_node_name,
type = type_node_name,
description = type_node_name,
exists_in = chunk,
)
added_nodes_map[f"{str(type_node_id)}_type"] = type_node
else:
type_node = added_nodes_map[f"{str(type_node_id)}_type"]
if f"{str(node_id)}_entity" not in added_nodes_map:
entity_node = Entity(
id = node_id,
name = node_name,
is_a = type_node,
description = node.description,
mentioned_in = chunk,
)
data_points.append(entity_node)
added_nodes_map[f"{str(node_id)}_entity"] = entity_node
# Add relationship that came from graphs.
for edge in graph.edges:
source_node_id = generate_node_id(edge.source_node_id)
target_node_id = generate_node_id(edge.target_node_id)
relationship_name = generate_edge_name(edge.relationship_name)
edge_key = str(source_node_id) + str(target_node_id) + relationship_name
if edge_key not in existing_edges_map:
graph_edges.append((
source_node_id,
target_node_id,
edge.relationship_name,
dict(
relationship_name = generate_edge_name(edge.relationship_name),
source_node_id = source_node_id,
target_node_id = target_node_id,
),
))
existing_edges_map[edge_key] = True
if len(data_points) > 0:
await add_data_points(data_points)
if len(graph_edges) > 0: if len(graph_edges) > 0:
await graph_engine.add_edges(graph_edges) await graph_engine.add_edges(graph_edges)

View file

@ -2,3 +2,4 @@ from .ingest_data import ingest_data
from .save_data_to_storage import save_data_to_storage from .save_data_to_storage import save_data_to_storage
from .save_data_item_to_storage import save_data_item_to_storage from .save_data_item_to_storage import save_data_item_to_storage
from .save_data_item_with_metadata_to_storage import save_data_item_with_metadata_to_storage from .save_data_item_with_metadata_to_storage import save_data_item_with_metadata_to_storage
from .ingest_data_with_metadata import ingest_data_with_metadata

View file

@ -1,13 +1,20 @@
from typing import Any
import dlt import dlt
import cognee.modules.ingestion as ingestion import cognee.modules.ingestion as ingestion
from typing import Any
from cognee.shared.utils import send_telemetry
from cognee.modules.users.models import User
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.methods import create_dataset from cognee.modules.data.methods import create_dataset
from cognee.modules.data.operations.delete_metadata import delete_metadata
from cognee.modules.users.models import User
from cognee.modules.users.permissions.methods import give_permission_on_document from cognee.modules.users.permissions.methods import give_permission_on_document
from cognee.shared.utils import send_telemetry
from cognee.modules.data.operations.write_metadata import write_metadata
from .get_dlt_destination import get_dlt_destination from .get_dlt_destination import get_dlt_destination
from .save_data_item_with_metadata_to_storage import save_data_item_with_metadata_to_storage from .save_data_item_with_metadata_to_storage import (
save_data_item_with_metadata_to_storage,
)
async def ingest_data_with_metadata(data: Any, dataset_name: str, user: User): async def ingest_data_with_metadata(data: Any, dataset_name: str, user: User):
destination = get_dlt_destination() destination = get_dlt_destination()
@ -25,8 +32,9 @@ async def ingest_data_with_metadata(data: Any, dataset_name: str, user: User):
# Process data # Process data
for data_item in data: for data_item in data:
file_path = await save_data_item_with_metadata_to_storage(
file_path = save_data_item_with_metadata_to_storage(data_item, dataset_name) data_item, dataset_name
)
# Ingest data and add metadata # Ingest data and add metadata
with open(file_path.replace("file://", ""), mode = "rb") as file: with open(file_path.replace("file://", ""), mode = "rb") as file:
@ -37,6 +45,7 @@ async def ingest_data_with_metadata(data: Any, dataset_name: str, user: User):
file_metadata = classified_data.get_metadata() file_metadata = classified_data.get_metadata()
from sqlalchemy import select from sqlalchemy import select
from cognee.modules.data.models import Data from cognee.modules.data.models import Data
db_engine = get_relational_engine() db_engine = get_relational_engine()
@ -44,29 +53,30 @@ async def ingest_data_with_metadata(data: Any, dataset_name: str, user: User):
async with db_engine.get_async_session() as session: async with db_engine.get_async_session() as session:
dataset = await create_dataset(dataset_name, user.id, session) dataset = await create_dataset(dataset_name, user.id, session)
data_point = (await session.execute( data_point = (
select(Data).filter(Data.id == data_id) await session.execute(select(Data).filter(Data.id == data_id))
)).scalar_one_or_none() ).scalar_one_or_none()
if data_point is not None: if data_point is not None:
await delete_metadata(data_point.metadata_id)
data_point.name = file_metadata["name"] data_point.name = file_metadata["name"]
data_point.raw_data_location = file_metadata["file_path"] data_point.raw_data_location = file_metadata["file_path"]
data_point.extension = file_metadata["extension"] data_point.extension = file_metadata["extension"]
data_point.mime_type = file_metadata["mime_type"] data_point.mime_type = file_metadata["mime_type"]
await session.merge(data_point) await session.merge(data_point)
await session.commit()
else: else:
data_point = Data( data_point = Data(
id = data_id, id = data_id,
name = file_metadata["name"], name = file_metadata["name"],
raw_data_location = file_metadata["file_path"], raw_data_location = file_metadata["file_path"],
extension = file_metadata["extension"], extension = file_metadata["extension"],
mime_type = file_metadata["mime_type"], mime_type = file_metadata["mime_type"]
) )
dataset.data.append(data_point) dataset.data.append(data_point)
await session.commit() await session.commit()
await write_metadata(data_item, data_point.id, file_metadata)
yield { yield {
"id": data_id, "id": data_id,
@ -79,14 +89,13 @@ async def ingest_data_with_metadata(data: Any, dataset_name: str, user: User):
await give_permission_on_document(user, data_id, "read") await give_permission_on_document(user, data_id, "read")
await give_permission_on_document(user, data_id, "write") await give_permission_on_document(user, data_id, "write")
send_telemetry("cognee.add EXECUTION STARTED", user_id=user.id)
send_telemetry("cognee.add EXECUTION STARTED", user_id = user.id)
run_info = pipeline.run( run_info = pipeline.run(
data_resources(data, user), data_resources(data, user),
table_name = "file_metadata", table_name = "file_metadata",
dataset_name = dataset_name, dataset_name = dataset_name,
write_disposition = "merge", write_disposition = "merge",
) )
send_telemetry("cognee.add EXECUTION COMPLETED", user_id = user.id) send_telemetry("cognee.add EXECUTION COMPLETED", user_id=user.id)
return run_info return run_info

View file

@ -3,19 +3,23 @@ from typing import Union, BinaryIO, Any
from cognee.modules.ingestion.exceptions import IngestionError from cognee.modules.ingestion.exceptions import IngestionError
from cognee.modules.ingestion import save_data_to_file from cognee.modules.ingestion import save_data_to_file
def save_data_item_with_metadata_to_storage(data_item: Union[BinaryIO, str, Any], dataset_name: str) -> str:
# Dynamic import is used because the llama_index module is optional.
# For the same reason Any is accepted as a data item
from llama_index.core import Document
from .transform_data import get_data_from_llama_index
async def save_data_item_with_metadata_to_storage(
data_item: Union[BinaryIO, str, Any], dataset_name: str
) -> str:
# Dynamic import is used because the llama_index module is optional.
# For the same reason Any is accepted as a data item
# Check if data is of type Document or any of it's subclasses # Check if data is of type Document or any of it's subclasses
if isinstance(data_item, Document): if str(type(data_item)).startswith("llama_index"):
from .transform_data import get_data_from_llama_index
file_path = get_data_from_llama_index(data_item, dataset_name) file_path = get_data_from_llama_index(data_item, dataset_name)
# data is a file object coming from upload. # data is a file object coming from upload.
elif hasattr(data_item, "file"): elif hasattr(data_item, "file"):
file_path = save_data_to_file(data_item.file, dataset_name, filename=data_item.filename) file_path = save_data_to_file(
data_item.file, dataset_name, filename=data_item.filename
)
elif isinstance(data_item, str): elif isinstance(data_item, str):
# data is a file path # data is a file path
@ -27,4 +31,4 @@ def save_data_item_with_metadata_to_storage(data_item: Union[BinaryIO, str, Any]
else: else:
raise IngestionError(message=f"Data type not supported: {type(data_item)}") raise IngestionError(message=f"Data type not supported: {type(data_item)}")
return file_path return file_path

View file

@ -1,3 +1,7 @@
import logging import logging
logger = logging.getLogger("task:repo_processor") logger = logging.getLogger("task:repo_processor")
from .enrich_dependency_graph import enrich_dependency_graph
from .expand_dependency_graph import expand_dependency_graph
from .get_repo_file_dependencies import get_repo_file_dependencies

View file

@ -0,0 +1,129 @@
import networkx as nx
from typing import AsyncGenerator, Dict, List
from tqdm.asyncio import tqdm
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile
from cognee.modules.graph.utils import get_graph_from_model, convert_node_to_data_point
from cognee.infrastructure.databases.graph import get_graph_engine
def topologically_sort_subgraph(subgraph_node_to_indegree: Dict[str, int], graph: nx.DiGraph) -> List[str]:
    """Order the subgraph's nodes by repeatedly selecting the node with the
    smallest remaining count and discounting each of its successors.

    Ties are broken by dict insertion order (min returns the first minimum).
    """
    pending = dict(subgraph_node_to_indegree)
    ordered: List[str] = []
    while pending:
        current = min(pending, key=pending.get)
        ordered.append(current)
        for succ in graph.successors(current):
            if succ in pending:
                pending[succ] -= 1
        del pending[current]
    return ordered
def topologically_sort(graph: nx.DiGraph) -> List[str]:
    """Topologically order the whole graph, one weakly connected component
    at a time, concatenating per-component orders."""
    ordering: List[str] = []
    for component in nx.weakly_connected_components(graph):
        subgraph = graph.subgraph(component).copy()
        # Per-node count handed to the subgraph sort: number of successors.
        successor_counts = {
            node: len(list(subgraph.successors(node)))
            for node in subgraph.nodes
        }
        ordering.extend(topologically_sort_subgraph(successor_counts, subgraph))
    return ordering
async def node_enrich_and_connect(
    graph: nx.MultiDiGraph,
    topological_order: List[str],
    node: CodeFile,
    data_points_map: Dict[str, DataPoint],
) -> None:
    """Adds 'depends_on' edges to the graph based on topological order.

    Sets the node's topological rank and extends its `depends_directly_on`
    list with every graph descendant that appears at or before the node in
    the topological order, resolving each descendant from `data_points_map`
    when possible, otherwise from the graph database.
    """
    topological_rank = topological_order.index(node.id)
    node.topological_rank = topological_rank

    node_descendants = nx.descendants(graph, node.id)
    # A self-loop makes the node its own descendant as well.
    if graph.has_edge(node.id, node.id):
        node_descendants.add(node.id)

    # Hoisted out of the loop: membership is tested once per descendant, and a
    # set makes each test O(1) instead of re-scanning a list slice (O(n)).
    eligible_ids = set(topological_order[:topological_rank + 1])

    new_connections = []
    graph_engine = await get_graph_engine()

    for desc_id in node_descendants:
        if desc_id not in eligible_ids:
            continue

        desc = None
        if desc_id in data_points_map:
            desc = data_points_map[desc_id]
        else:
            node_data = await graph_engine.extract_node(desc_id)
            try:
                desc = convert_node_to_data_point(node_data)
            except Exception:
                # Best effort: descendants that cannot be converted are skipped.
                pass

        if desc is not None:
            new_connections.append(desc)

    node.depends_directly_on = node.depends_directly_on or []
    node.depends_directly_on.extend(new_connections)
async def enrich_dependency_graph(data_points: list[DataPoint]) -> AsyncGenerator[list[DataPoint], None]:
    """Enriches the graph with topological ranks and 'depends_on' edges.

    Builds an in-memory MultiDiGraph from the data points' models, computes a
    topological order, then yields the data points back: CodeFile instances
    are first enriched with rank and dependency connections, and data points
    whose id is not part of the sorted graph are skipped entirely.
    """
    nodes = []
    edges = []

    for data_point in data_points:
        graph_nodes, graph_edges = await get_graph_from_model(data_point)
        nodes.extend(graph_nodes)
        edges.extend(graph_edges)

    graph = nx.MultiDiGraph()
    graph.add_nodes_from([(node.id, node.model_dump()) for node in nodes])
    graph.add_edges_from(edges)

    topological_order = topologically_sort(graph)
    # Used only as a membership index of ids present in the sorted graph.
    node_rank_map = {node: idx for idx, node in enumerate(topological_order)}

    data_points_map = {data_point.id: data_point for data_point in data_points}

    for data_point in tqdm(data_points, desc = "Enriching dependency graph", unit = "data_point"):
        if data_point.id not in node_rank_map:
            continue

        if isinstance(data_point, CodeFile):
            await node_enrich_and_connect(graph, topological_order, data_point, data_points_map)

        yield data_point

View file

@ -0,0 +1,65 @@
from typing import AsyncGenerator
from uuid import NAMESPACE_OID, uuid5
# from tqdm import tqdm
from cognee.infrastructure.engine import DataPoint
from cognee.shared.CodeGraphEntities import CodeFile, CodePart
from cognee.tasks.repo_processor.extract_code_parts import extract_code_parts
from cognee.tasks.repo_processor import logger
def _add_code_parts_nodes_and_edges(code_file: CodeFile, part_type, code_parts) -> None:
    """Attach a CodePart child to the file for every non-empty code part."""
    if not code_parts:
        logger.debug(f"No code parts to add for node {code_file.id} and part_type {part_type}.")
        return

    part_nodes = []
    for idx, code_part in enumerate(code_parts):
        if not code_part.strip():
            logger.warning(f"Empty code part in node {code_file.id} and part_type {part_type}.")
            continue
        # Deterministic id derived from the parent file, part type and position.
        part_nodes.append(CodePart(
            id = uuid5(NAMESPACE_OID, f"{code_file.id}_{part_type}_{idx}"),
            type = part_type,
            source_code = code_part,
        ))

    code_file.contains = code_file.contains or []
    code_file.contains.extend(part_nodes)
def _process_single_node(code_file: CodeFile) -> None:
    """Process a single Python file node."""
    node_id = code_file.id
    source = code_file.source_code

    # Nothing to extract from a file without source code.
    if not source.strip():
        logger.warning(f"Node {node_id} has no or empty 'source_code'. Skipping.")
        return

    try:
        parts_by_type = extract_code_parts(source)
    except Exception as e:
        logger.error(f"Error processing node {node_id}: {e}")
        return

    for part_type, parts in parts_by_type.items():
        _add_code_parts_nodes_and_edges(code_file, part_type, parts)
async def expand_dependency_graph(data_points: list[DataPoint]) -> AsyncGenerator[list[DataPoint], None]:
    """Yield every data point, first expanding CodeFile nodes into
    their contained code-part children."""
    for item in data_points:
        if isinstance(item, CodeFile):
            _process_single_node(item)
        yield item

View file

@ -0,0 +1,59 @@
from typing import Dict, List
import parso
from cognee.tasks.repo_processor import logger
def _extract_parts_from_module(module, parts_dict: Dict[str, List[str]]) -> Dict[str, List[str]]:
"""Extract code parts from a parsed module."""
current_top_level_code = []
child_to_code_type = {
'classdef': "classes",
'funcdef': "functions",
'import_name': "imports",
'import_from': "imports",
}
for child in module.children:
if child.type == 'simple_stmt':
current_top_level_code.append(child.get_code())
continue
if current_top_level_code:
parts_dict["top_level_code"].append('\n'.join(current_top_level_code))
current_top_level_code = []
if child.type in child_to_code_type:
code_type = child_to_code_type[child.type]
parts_dict[code_type].append(child.get_code())
if current_top_level_code:
parts_dict["top_level_code"].append('\n'.join(current_top_level_code))
if parts_dict["imports"]:
parts_dict["imports"] = ['\n'.join(parts_dict["imports"])]
return parts_dict
def extract_code_parts(source_code: str) -> Dict[str, List[str]]:
    """Extract high-level parts of the source code."""
    parts: Dict[str, List[str]] = {
        "classes": [], "functions": [], "imports": [], "top_level_code": [],
    }

    # Guard: nothing to parse.
    if not source_code.strip():
        logger.warning("Empty source_code provided.")
        return parts

    try:
        module = parso.parse(source_code)
    except Exception as e:
        logger.error(f"Error parsing source code: {e}")
        return parts

    if not module.children:
        logger.warning("Parsed module has no children (empty or invalid source code).")
        return parts

    return _extract_parts_from_module(module, parts)

View file

@ -0,0 +1,99 @@
import os
from typing import AsyncGenerator
from uuid import NAMESPACE_OID, uuid5
import aiofiles
from concurrent.futures import ProcessPoolExecutor
import asyncio
from cognee.shared.CodeGraphEntities import CodeFile, Repository
from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
async def get_py_path_and_source(file_path):
    """Read a file's text; return (path, contents) or (path, None) on failure."""
    try:
        async with aiofiles.open(file_path, "r", encoding="utf-8") as file_handle:
            return file_path, await file_handle.read()
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return file_path, None
async def get_py_files_dict(repo_path):
    """Get .py files and their source code"""
    if not os.path.exists(repo_path):
        return {}

    py_files_dict = {}
    for root, _, files in os.walk(repo_path):
        for file_name in files:
            if not file_name.endswith(".py"):
                continue
            absolute_path = os.path.abspath(os.path.join(root, file_name))
            # The reader returns (path, source) — source is None on read errors.
            path, source_code = await get_py_path_and_source(absolute_path)
            py_files_dict[path] = {"source_code": source_code}
    return py_files_dict
def get_edge(file_path: str, dependency: str, repo_path: str, relative_paths: bool = False) -> tuple:
    """Build a (source, target, attributes) dependency edge, optionally repo-relative."""
    source, target = file_path, dependency
    if relative_paths:
        source = os.path.relpath(source, repo_path)
        target = os.path.relpath(target, repo_path)
    return (source, target, {"relation": "depends_directly_on"})
def run_coroutine(coroutine_func, *args, **kwargs):
    """Run an async callable to completion on a fresh event loop and return its result.

    Intended for worker processes where no loop is running. The loop is closed
    in a finally block — the previous version leaked the loop whenever the
    coroutine raised, because close() was only reached on success.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coroutine_func(*args, **kwargs))
    finally:
        loop.close()
async def get_repo_file_dependencies(repo_path: str, max_workers: int = 12) -> AsyncGenerator[list, None]:
    """Generate a dependency graph for Python files in the given repository path.

    Yields the Repository node first, then one CodeFile per readable .py file,
    with its direct dependencies resolved in a process pool.

    Args:
        repo_path: Root directory of the repository to scan.
        max_workers: Size of the process pool used to resolve dependencies
            (defaults to the previously hard-coded 12).
    """
    py_files_dict = await get_py_files_dict(repo_path)

    repo = Repository(
        id = uuid5(NAMESPACE_OID, repo_path),
        path = repo_path,
    )
    yield repo

    # Fix: materialize the filtered (readable) files once. Previously tasks
    # were built from a filtered comprehension but zipped back against the
    # unfiltered dict items, so any file with source_code == None shifted
    # every following pairing of file -> dependencies.
    readable_files = [
        (file_path, metadata)
        for file_path, metadata in py_files_dict.items()
        if metadata.get("source_code") is not None
    ]

    with ProcessPoolExecutor(max_workers = max_workers) as executor:
        # get_event_loop() is deprecated inside coroutines; use the running loop.
        loop = asyncio.get_running_loop()

        tasks = [
            loop.run_in_executor(
                executor,
                run_coroutine,
                get_local_script_dependencies,
                os.path.join(repo_path, file_path),
                repo_path,
            )
            for file_path, _ in readable_files
        ]

        results = await asyncio.gather(*tasks)

    for (file_path, metadata), dependencies in zip(readable_files, results):
        source_code = metadata.get("source_code")

        yield CodeFile(
            id = uuid5(NAMESPACE_OID, file_path),
            source_code = source_code,
            extracted_id = file_path,
            part_of = repo,
            depends_on = [
                CodeFile(
                    id = uuid5(NAMESPACE_OID, dependency),
                    extracted_id = dependency,
                    part_of = repo,
                    source_code = py_files_dict.get(dependency, {}).get("source_code"),
                ) for dependency in dependencies
            ] if dependencies else None,
        )

View file

@ -1,6 +1,7 @@
import asyncio
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.graph.utils import get_graph_from_model from cognee.modules.graph.utils import deduplicate_nodes_and_edges, get_graph_from_model
from .index_data_points import index_data_points from .index_data_points import index_data_points
@ -8,15 +9,26 @@ async def add_data_points(data_points: list[DataPoint]):
nodes = [] nodes = []
edges = [] edges = []
for data_point in data_points: added_nodes = {}
property_nodes, property_edges = get_graph_from_model(data_point) added_edges = {}
nodes.extend(property_nodes) results = await asyncio.gather(*[
edges.extend(property_edges) get_graph_from_model(
data_point,
added_nodes = added_nodes,
added_edges = added_edges,
) for data_point in data_points
])
for result_nodes, result_edges in results:
nodes.extend(result_nodes)
edges.extend(result_edges)
nodes, edges = deduplicate_nodes_and_edges(nodes, edges)
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
await index_data_points(data_points) await index_data_points(nodes)
await graph_engine.add_nodes(nodes) await graph_engine.add_nodes(nodes)
await graph_engine.add_edges(edges) await graph_engine.add_edges(edges)

View file

@ -7,15 +7,13 @@ async def index_data_points(data_points: list[DataPoint]):
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
flat_data_points: list[DataPoint] = []
for data_point in data_points: for data_point in data_points:
flat_data_points.extend(get_data_points_from_model(data_point))
for data_point in flat_data_points:
data_point_type = type(data_point) data_point_type = type(data_point)
for field_name in data_point._metadata["index_fields"]: for field_name in data_point._metadata["index_fields"]:
if getattr(data_point, field_name, None) is None:
continue
index_name = f"{data_point_type.__tablename__}.{field_name}" index_name = f"{data_point_type.__tablename__}.{field_name}"
if index_name not in created_indexes: if index_name not in created_indexes:
@ -35,12 +33,21 @@ async def index_data_points(data_points: list[DataPoint]):
return data_points return data_points
def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) -> list[DataPoint]: async def get_data_points_from_model(data_point: DataPoint, added_data_points = None, visited_properties = None) -> list[DataPoint]:
data_points = [] data_points = []
added_data_points = added_data_points or {}
visited_properties = visited_properties or {}
for field_name, field_value in data_point: for field_name, field_value in data_point:
if isinstance(field_value, DataPoint): if isinstance(field_value, DataPoint):
new_data_points = get_data_points_from_model(field_value, added_data_points) property_key = f"{str(data_point.id)}{field_name}{str(field_value.id)}"
if property_key in visited_properties:
return []
visited_properties[property_key] = True
new_data_points = await get_data_points_from_model(field_value, added_data_points, visited_properties)
for new_point in new_data_points: for new_point in new_data_points:
if str(new_point.id) not in added_data_points: if str(new_point.id) not in added_data_points:
@ -49,7 +56,14 @@ def get_data_points_from_model(data_point: DataPoint, added_data_points = {}) ->
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint): if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
for field_value_item in field_value: for field_value_item in field_value:
new_data_points = get_data_points_from_model(field_value_item, added_data_points) property_key = f"{str(data_point.id)}{field_name}{str(field_value_item.id)}"
if property_key in visited_properties:
return []
visited_properties[property_key] = True
new_data_points = await get_data_points_from_model(field_value_item, added_data_points, visited_properties)
for new_point in new_data_points: for new_point in new_data_points:
if str(new_point.id) not in added_data_points: if str(new_point.id) not in added_data_points:
@ -79,4 +93,3 @@ if __name__ == "__main__":
data_points = get_data_points_from_model(person) data_points = get_data_points_from_model(person)
print(data_points) print(data_points)

View file

@ -1,2 +1,3 @@
from .summarize_text import summarize_text
from .query_summaries import query_summaries from .query_summaries import query_summaries
from .summarize_code import summarize_code
from .summarize_text import summarize_text

View file

@ -1,6 +1,8 @@
from cognee.infrastructure.engine import DataPoint from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.data.processing.document_types import Document from cognee.modules.data.processing.document_types import Document
from cognee.shared.CodeGraphEntities import CodeFile
class TextSummary(DataPoint): class TextSummary(DataPoint):
__tablename__ = "text_summary" __tablename__ = "text_summary"
@ -10,3 +12,12 @@ class TextSummary(DataPoint):
_metadata: dict = { _metadata: dict = {
"index_fields": ["text"], "index_fields": ["text"],
} }
class CodeSummary(DataPoint):
    # Natural-language summary text; indexed for vector search (see _metadata).
    text: str
    # The CodeFile this summary was generated from.
    made_from: CodeFile
    # NOTE(review): unlike TextSummary above, no __tablename__ is set here —
    # confirm the default table/collection name is intended.
    _metadata: dict = {
        "index_fields": ["text"],
    }

View file

@ -0,0 +1,39 @@
import asyncio
from typing import Type
from uuid import uuid5
from pydantic import BaseModel
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.extraction.extract_summary import extract_summary
from cognee.shared.CodeGraphEntities import CodeFile
from cognee.tasks.storage import add_data_points
from .models import CodeSummary
async def summarize_code(
    code_files: list[DataPoint],
    summarization_model: Type[BaseModel],
) -> list[DataPoint]:
    """Create and persist a CodeSummary for every CodeFile, returning the
    input list unchanged (pipeline pass-through)."""
    if len(code_files) == 0:
        return code_files

    files_to_summarize = [item for item in code_files if isinstance(item, CodeFile)]

    # Run all LLM summary extractions concurrently.
    extraction_tasks = [
        extract_summary(code_file.source_code, summarization_model)
        for code_file in files_to_summarize
    ]
    extractions = await asyncio.gather(*extraction_tasks)

    summaries = [
        CodeSummary(
            id = uuid5(code_file.id, "CodeSummary"),
            made_from = code_file,
            text = extraction.summary,
        )
        for code_file, extraction in zip(files_to_summarize, extractions)
    ]

    await add_data_points(summaries)

    return code_files

View file

@ -27,7 +27,7 @@ TEST_TEXT = """
def test_AudioDocument(): def test_AudioDocument():
document = AudioDocument( document = AudioDocument(
id=uuid.uuid4(), name="audio-dummy-test", raw_data_location="" id=uuid.uuid4(), name="audio-dummy-test", raw_data_location="", metadata_id=uuid.uuid4()
) )
with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT): with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(

View file

@ -16,7 +16,7 @@ The commotion has attracted an audience: a murder of crows has gathered in the l
def test_ImageDocument(): def test_ImageDocument():
document = ImageDocument( document = ImageDocument(
id=uuid.uuid4(), name="image-dummy-test", raw_data_location="" id=uuid.uuid4(), name="image-dummy-test", raw_data_location="", metadata_id=uuid.uuid4()
) )
with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT): with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):

View file

@ -17,7 +17,7 @@ def test_PdfDocument():
"artificial-intelligence.pdf", "artificial-intelligence.pdf",
) )
document = PdfDocument( document = PdfDocument(
id=uuid.uuid4(), name="Test document.pdf", raw_data_location=test_file_path id=uuid.uuid4(), name="Test document.pdf", raw_data_location=test_file_path, metadata_id=uuid.uuid4()
) )
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(

View file

@ -29,7 +29,7 @@ def test_TextDocument(input_file, chunk_size):
input_file, input_file,
) )
document = TextDocument( document = TextDocument(
id=uuid.uuid4(), name=input_file, raw_data_location=test_file_path id=uuid.uuid4(), name=input_file, raw_data_location=test_file_path, metadata_id=uuid.uuid4()
) )
for ground_truth, paragraph_data in zip( for ground_truth, paragraph_data in zip(

View file

@ -0,0 +1,15 @@
import asyncio
from cognee.shared.data_models import SummarizedContent
from cognee.tasks.summarization import summarize_code
from cognee.tests.tasks.graph.code_graph_test_data_generation import (
code_graph_test_data_generation,
)
def test_summarize_code():
    """summarize_code must pass its input data points through unchanged."""
    input_nodes, _ = code_graph_test_data_generation()
    output_nodes = asyncio.run(summarize_code(input_nodes, SummarizedContent))
    for node_in, node_out in zip(input_nodes, output_nodes):
        assert node_in == node_out, f"{node_in = } != {node_out = }"

83
cognee/tests/test_falkordb.py Executable file
View file

@ -0,0 +1,83 @@
import os
import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.shared.utils import render_graph
logging.basicConfig(level = logging.DEBUG)
async def main():
    """End-to-end smoke test of the cognee pipeline (add -> cognify -> search -> prune)."""
    # Isolate this test's data and system state in dedicated directories so
    # repeated runs start from a known location.
    data_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_falkordb")).resolve())
    cognee.config.data_root_directory(data_directory_path)
    cognee_directory_path = str(pathlib.Path(os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_falkordb")).resolve())
    cognee.config.system_root_directory(cognee_directory_path)
    # Start from a clean slate.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata = True)
    dataset_name = "artificial_intelligence"
    # Ingest one PDF and one raw-text document into the same dataset.
    ai_text_file_path = os.path.join(pathlib.Path(__file__).parent, "test_data/artificial-intelligence.pdf")
    await cognee.add([ai_text_file_path], dataset_name)
    text = """A large language model (LLM) is a language model notable for its ability to achieve general-purpose language generation and other natural language processing tasks such as classification. LLMs acquire these abilities by learning statistical relationships from text documents during a computationally intensive self-supervised and semi-supervised training process. LLMs can be used for text generation, a form of generative AI, by taking an input text and repeatedly predicting the next token or word.
LLMs are artificial neural networks. The largest and most capable, as of March 2024, are built with a decoder-only transformer-based architecture while some recent implementations are based on other architectures, such as recurrent neural network variants and Mamba (a state space model).
Up to 2020, fine tuning was the only way a model could be adapted to be able to accomplish specific tasks. Larger sized models, such as GPT-3, however, can be prompt-engineered to achieve similar results.[6] They are thought to acquire knowledge about syntax, semantics and "ontology" inherent in human language corpora, but also inaccuracies and biases present in the corpora.
Some notable LLMs are OpenAI's GPT series of models (e.g., GPT-3.5 and GPT-4, used in ChatGPT and Microsoft Copilot), Google's PaLM and Gemini (the latter of which is currently used in the chatbot of the same name), xAI's Grok, Meta's LLaMA family of open-source models, Anthropic's Claude models, Mistral AI's open source models, and Databricks' open source DBRX.
"""
    await cognee.add([text], dataset_name)
    # Build the knowledge graph from the ingested dataset.
    await cognee.cognify([dataset_name])
    # await render_graph(None, include_labels = True, include_nodes = True)
    from cognee.infrastructure.databases.vector import get_vector_engine
    vector_engine = get_vector_engine()
    # Pick an arbitrary entity node to drive the three search queries below.
    random_node = (await vector_engine.search("entity.name", "AI"))[0]
    random_node_name = random_node.payload["text"]
    search_results = await cognee.search(SearchType.INSIGHTS, query_text = random_node_name)
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nExtracted sentences are:\n")
    for result in search_results:
        print(f"{result}\n")
    search_results = await cognee.search(SearchType.CHUNKS, query_text = random_node_name)
    assert len(search_results) != 0, "The search results list is empty."
    print("\n\nExtracted chunks are:\n")
    for result in search_results:
        print(f"{result}\n")
    search_results = await cognee.search(SearchType.SUMMARIES, query_text = random_node_name)
    assert len(search_results) != 0, "Query related summaries don't exist."
    print("\nExtracted summaries are:\n")
    for result in search_results:
        print(f"{result}\n")
    history = await cognee.get_search_history()
    # NOTE(review): assumes each of the three searches above records two
    # history entries (query + result) — confirm against get_search_history.
    assert len(history) == 6, "Search history is not correct."
    # Assert local data files are cleaned properly
    await cognee.prune.prune_data()
    assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
    # Assert relational, vector and graph databases have been cleaned properly
    await cognee.prune.prune_system(metadata=True)
    connection = await vector_engine.get_connection()
    collection_names = await connection.table_names()
    assert len(collection_names) == 0, "LanceDB vector database is not empty"
    from cognee.infrastructure.databases.relational import get_relational_engine
    assert not os.path.exists(get_relational_engine().db_path), "SQLite relational database is not empty"
    from cognee.infrastructure.databases.graph import get_graph_config
    graph_config = get_graph_config()
    assert not os.path.exists(graph_config.graph_file_path), "Networkx graph database is not empty"
if __name__ == "__main__":
    import asyncio
    asyncio.run(main(), debug=True)

View file

@ -0,0 +1,100 @@
import asyncio
import random
import time
from typing import List
from uuid import uuid5, NAMESPACE_OID
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model
random.seed(1500)
class Repository(DataPoint):
    # Filesystem path identifying the repository.
    path: str
class CodeFile(DataPoint):
    part_of: Repository
    # "CodePart"/"CodeFile" are forward references, resolved by the
    # model_rebuild() calls below.
    contains: List["CodePart"] = []
    depends_on: List["CodeFile"] = []
    source_code: str
class CodePart(DataPoint):
    part_of: CodeFile
    source_code: str
# Resolve the forward references used in the annotations above
# (model_rebuild is the pydantic v2 API for this).
CodeFile.model_rebuild()
CodePart.model_rebuild()
def nanoseconds_to_largest_unit(nanoseconds):
    """Convert a nanosecond count to (value, unit) using the largest unit
    whose converted value is at least 1; falls back to nanoseconds."""
    unit_factors = (
        ('weeks', 7 * 24 * 60 * 60 * 1e9),
        ('days', 24 * 60 * 60 * 1e9),
        ('hours', 60 * 60 * 1e9),
        ('minutes', 60 * 1e9),
        ('seconds', 1e9),
        ('miliseconds', 1e6),  # spelling preserved: callers may display/match it
        ('microseconds', 1e3),
    )
    # Units are ordered largest-first, so the first hit is the answer.
    for unit, factor in unit_factors:
        value = nanoseconds / factor
        if value >= 1:
            return value, unit
    # Smaller than a microsecond: report raw nanoseconds.
    return nanoseconds, 'nanoseconds'
async def test_circular_reference_extraction():
    """Stress-test get_graph_from_model on a large, self-referential model set."""
    repo = Repository(path = "repo1")
    # 1500 files; each depends on 0-5 files whose ids are drawn from the same
    # uuid5 id space, so dependency entries alias real files (possibly itself).
    code_files = [CodeFile(
        id = uuid5(NAMESPACE_OID, f"file{file_index}"),
        source_code = "source code",
        part_of = repo,
        contains = [],
        depends_on = [CodeFile(
            id = uuid5(NAMESPACE_OID, f"file{random_id}"),
            source_code = "source code",
            part_of = repo,
            depends_on = [],
        ) for random_id in [random.randint(0, 1499) for _ in range(random.randint(0, 5))]],
    ) for file_index in range(1500)]
    # Each file also contains 1-20 CodePart children.
    for code_file in code_files:
        code_file.contains.extend([CodePart(
            part_of = code_file,
            source_code = f"Part {part_index}",
        ) for part_index in range(random.randint(1, 20))])
    nodes = []
    edges = []
    # Time the concurrent extraction of all file graphs.
    start = time.perf_counter_ns()
    results = await asyncio.gather(*[
        get_graph_from_model(code_file) for code_file in code_files
    ])
    time_to_run = time.perf_counter_ns() - start
    print(nanoseconds_to_largest_unit(time_to_run))
    for result_nodes, result_edges in results:
        nodes.extend(result_nodes)
        edges.extend(result_edges)
    # NOTE(review): these exact counts assume the module-level random.seed(1500)
    # makes generation deterministic and that extraction deduplicates nodes by
    # id (1500 files + 1 repository) — confirm against get_graph_from_model.
    assert len(nodes) == 1501
    assert len(edges) == 1501 * 20 + 1500 * 5
if __name__ == "__main__":
    asyncio.run(test_circular_reference_extraction())

View file

@ -11,7 +11,7 @@ from cognee.tests.unit.interfaces.graph.util import (
@pytest.mark.parametrize("recursive_depth", [1, 2, 3]) @pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth): async def test_society_nodes_and_edges(recursive_depth):
import sys import sys
if sys.version_info[0] == 3 and sys.version_info[1] >= 11: if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
@ -22,7 +22,7 @@ def test_society_nodes_and_edges(recursive_depth):
n_organizations, n_persons = count_society(society) n_organizations, n_persons = count_society(society)
society_counts_total = n_organizations + n_persons society_counts_total = n_organizations + n_persons
nodes, edges = get_graph_from_model(society) nodes, edges = await get_graph_from_model(society)
assert ( assert (
len(nodes) == society_counts_total len(nodes) == society_counts_total

View file

@ -48,29 +48,29 @@ PERSON_GROUND_TRUTH = {
} }
def test_extracted_car_type(boris): async def test_extracted_car_type(boris):
nodes, _ = get_graph_from_model(boris) nodes, _ = await get_graph_from_model(boris)
assert len(nodes) == 3 assert len(nodes) == 3
car_type = nodes[0] car_type = nodes[0]
run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH) run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH)
def test_extracted_car(boris): async def test_extracted_car(boris):
nodes, _ = get_graph_from_model(boris) nodes, _ = await get_graph_from_model(boris)
assert len(nodes) == 3 assert len(nodes) == 3
car = nodes[1] car = nodes[1]
run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH) run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
def test_extracted_person(boris): async def test_extracted_person(boris):
nodes, _ = get_graph_from_model(boris) nodes, _ = await get_graph_from_model(boris)
assert len(nodes) == 3 assert len(nodes) == 3
person = nodes[2] person = nodes[2]
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH) run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
def test_extracted_car_sedan_edge(boris): async def test_extracted_car_sedan_edge(boris):
_, edges = get_graph_from_model(boris) _, edges = await get_graph_from_model(boris)
edge = edges[0] edge = edges[0]
assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }" assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }"
@ -78,8 +78,8 @@ def test_extracted_car_sedan_edge(boris):
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }" assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
def test_extracted_boris_car_edge(boris): async def test_extracted_boris_car_edge(boris):
_, edges = get_graph_from_model(boris) _, edges = await get_graph_from_model(boris)
edge = edges[1] edge = edges[1]
assert ( assert (

View file

@ -14,14 +14,14 @@ from cognee.tests.unit.interfaces.graph.util import (
@pytest.mark.parametrize("recursive_depth", [1, 2, 3]) @pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth): async def test_society_nodes_and_edges(recursive_depth):
import sys import sys
if sys.version_info[0] == 3 and sys.version_info[1] >= 11: if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
society = create_organization_recursive( society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth "society", "Society", PERSON_NAMES, recursive_depth
) )
nodes, edges = get_graph_from_model(society) nodes, edges = await get_graph_from_model(society)
parsed_society = get_model_instance_from_graph(nodes, edges, "society") parsed_society = get_model_instance_from_graph(nodes, edges, "society")
assert str(society) == (str(parsed_society)), show_first_difference( assert str(society) == (str(parsed_society)), show_first_difference(

View file

@ -25,8 +25,8 @@ CAR_GROUND_TRUTH = {
} }
def test_parsed_person(boris): async def test_parsed_person(boris):
nodes, edges = get_graph_from_model(boris) nodes, edges = await get_graph_from_model(boris)
parsed_person = get_model_instance_from_graph(nodes, edges, "boris") parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
run_test_against_ground_truth( run_test_against_ground_truth(

64
evals/EC2_README.md Normal file
View file

@ -0,0 +1,64 @@
## Creating the EC2 Instance
Create an EC2 Instance with the
`Ubuntu Image`
Many instance types will work, we used:
`m7a.2xlarge` # more than 8 parallel processes doesn't seem to speed up overall process. Maybe to do with docker parallelism?
DON'T FORGET TO ADD
`500 GB storage`
Or the evaluation run will run out of space
Add a key pair login where you have access to the corresponding key file (*.pem)
## Accessing your instance and setup
To ssh into the instance, you have to save your key pair file (*.pem) to an appropriate location, such as ~/.aws. After launching the instance, you can access the Instance Summary, and retrieve "Public IPv4 DNS" address. Then run
`ssh -i PATH_TO_KEY ubuntu@IPv4ADDRESS`
to gain command line access to the instance.
To copy your current state of cognee, go to the folder that contains "cognee" on your local machine, zip it to cognee.zip and run:
`zip -r cognee.zip cognee`
`scp -i PATH_TO_KEY cognee.zip ubuntu@IPv4ADDRESS:cognee.zip`
And unzip cognee.zip in your SSH session:
`sudo apt install unzip`
`unzip cognee.zip`
Then run:
`cd cognee`
`source evals/cloud/setup_ubuntu_instance.sh`
`sudo usermod -aG docker $USER`
disconnect, and reconnect.
Confirm that `ubuntu` has been added to the docker user group with
`groups | grep docker`
## Running SWE-bench
Then enter a `screen` and activate the virtual env
`screen`
`source venv/bin/activate`
then, from cognee, you can run swe_bench:
`cd cognee`
`python evals/eval_swe_bench.py --cognee_off --max_workers=N_CPUS`
Building the environment images should take roughly 17 minutes
If the virtual env wasn't set up correctly for some reason, just run the last few lines of `setup_ubuntu_instance.sh` manually

View file

@ -0,0 +1,33 @@
sudo apt-get update -y
sudo apt-get install -y ca-certificates curl
sudo install -m 0755 -d /etc/apt/keyrings
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
sudo chmod a+r /etc/apt/keyrings/docker.asc
# Add the repository to Apt sources:
echo \
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
$(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \
sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
sudo apt-get update -y
sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
sudo docker run hello-world
sudo apt install -y unzip
sudo apt-get install -y python3-virtualenv
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt update -y
sudo apt install -y python3.11
virtualenv venv --python=python3.11
source venv/bin/activate
pip install poetry
poetry install
pip install swebench transformers sentencepiece datasets tiktoken protobuf

14
evals/deepeval_metrics.py Normal file
View file

@ -0,0 +1,14 @@
from deepeval.metrics import GEval
from deepeval.test_case import LLMTestCaseParams
correctness_metric = GEval(
name="Correctness",
model="gpt-4o-mini",
evaluation_params=[
LLMTestCaseParams.ACTUAL_OUTPUT,
LLMTestCaseParams.EXPECTED_OUTPUT
],
evaluation_steps=[
"Determine whether the actual output is factually correct based on the expected output."
]
)

View file

@ -1,69 +1,113 @@
import argparse import argparse
import json import json
import subprocess import subprocess
import sys
from pathlib import Path from pathlib import Path
from datasets import Dataset
from swebench.harness.utils import load_swebench_dataset from swebench.harness.utils import load_swebench_dataset
from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE from swebench.inference.make_datasets.create_instance import PATCH_EXAMPLE
import cognee import cognee
from cognee.api.v1.cognify.code_graph_pipeline import code_graph_pipeline
from cognee.api.v1.search import SearchType from cognee.api.v1.search import SearchType
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm.get_llm_client import get_llm_client from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt from cognee.infrastructure.llm.prompts import read_query_prompt
from evals.eval_utils import download_instances from cognee.modules.pipelines import Task, run_tasks
from cognee.modules.retrieval.brute_force_triplet_search import \
brute_force_triplet_search
from cognee.shared.data_models import SummarizedContent
from cognee.shared.utils import render_graph
from cognee.tasks.repo_processor import (enrich_dependency_graph,
expand_dependency_graph,
get_repo_file_dependencies)
from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_code
from evals.eval_utils import download_github_repo, retrieved_edges_to_string
async def generate_patch_with_cognee(instance, search_type=SearchType.CHUNKS): def check_install_package(package_name):
"""
Check if a pip package is installed and install it if not.
Returns True if package is/was installed successfully, False otherwise.
"""
try:
__import__(package_name)
return True
except ImportError:
try:
subprocess.check_call(
[sys.executable, "-m", "pip", "install", package_name]
)
return True
except subprocess.CalledProcessError:
return False
async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
await cognee.prune.prune_data() await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system()
dataset_name = "SWE_test_data" # repo_path = download_github_repo(instance, '../RAW_GIT_REPOS')
code_text = instance["text"]
await cognee.add([code_text], dataset_name) repo_path = '/Users/borisarzentar/Projects/graphrag'
await code_graph_pipeline([dataset_name])
graph_engine = await get_graph_engine() tasks = [
with open(graph_engine.filename, "r") as f: Task(get_repo_file_dependencies),
graph_str = f.read() Task(add_data_points, task_config = { "batch_size": 50 }),
Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
Task(add_data_points, task_config = { "batch_size": 50 }),
# Task(summarize_code, summarization_model = SummarizedContent),
]
pipeline = run_tasks(tasks, repo_path, "cognify_code_pipeline")
async for result in pipeline:
print(result)
print('Here we have the repo under the repo_path')
await render_graph(None, include_labels = True, include_nodes = True)
problem_statement = instance['problem_statement'] problem_statement = instance['problem_statement']
instructions = read_query_prompt("patch_gen_instructions.txt") instructions = read_query_prompt("patch_gen_kg_instructions.txt")
retrieved_edges = await brute_force_triplet_search(problem_statement, top_k = 3, collections = ["data_point_source_code", "data_point_text"])
retrieved_edges_str = retrieved_edges_to_string(retrieved_edges)
prompt = "\n".join([ prompt = "\n".join([
instructions, problem_statement,
"<patch>", "<patch>",
PATCH_EXAMPLE, PATCH_EXAMPLE,
"</patch>", "</patch>",
"This is the knowledge graph:", "These are the retrieved edges:",
graph_str retrieved_edges_str
]) ])
llm_client = get_llm_client() llm_client = get_llm_client()
answer_prediction = await llm_client.acreate_structured_output( answer_prediction = await llm_client.acreate_structured_output(
text_input=problem_statement, text_input=prompt,
system_prompt=prompt, system_prompt=instructions,
response_model=str, response_model=str,
) )
return answer_prediction return answer_prediction
async def generate_patch_without_cognee(instance): async def generate_patch_without_cognee(instance, llm_client):
problem_statement = instance['problem_statement'] instructions = read_query_prompt("patch_gen_instructions.txt")
prompt = instance["text"]
llm_client = get_llm_client()
answer_prediction = await llm_client.acreate_structured_output( answer_prediction = await llm_client.acreate_structured_output(
text_input=problem_statement, text_input=instance["text"],
system_prompt=prompt, system_prompt=instructions,
response_model=str, response_model=str,
) )
return answer_prediction return answer_prediction
async def get_preds(dataset, with_cognee=True): async def get_preds(dataset, with_cognee=True):
llm_client = get_llm_client()
if with_cognee: if with_cognee:
model_name = "with_cognee" model_name = "with_cognee"
pred_func = generate_patch_with_cognee pred_func = generate_patch_with_cognee
@ -71,9 +115,20 @@ async def get_preds(dataset, with_cognee=True):
model_name = "without_cognee" model_name = "without_cognee"
pred_func = generate_patch_without_cognee pred_func = generate_patch_without_cognee
preds = [{"instance_id": instance["instance_id"], futures = [
"model_patch": await pred_func(instance), (instance["instance_id"], pred_func(instance, llm_client))
"model_name_or_path": model_name} for instance in dataset] for instance in dataset
]
model_patches = await asyncio.gather(*[x[1] for x in futures])
preds = [
{
"instance_id": instance_id,
"model_patch": model_patch,
"model_name_or_path": model_name,
}
for (instance_id, _), model_patch in zip(futures, model_patches)
]
return preds return preds
@ -82,8 +137,12 @@ async def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Run LLM predictions on SWE-bench dataset") description="Run LLM predictions on SWE-bench dataset")
parser.add_argument('--cognee_off', action='store_true') parser.add_argument('--cognee_off', action='store_true')
parser.add_argument("--max_workers", type=int, required=True)
args = parser.parse_args() args = parser.parse_args()
for dependency in ["transformers", "sentencepiece", "swebench"]:
check_install_package(dependency)
if args.cognee_off: if args.cognee_off:
dataset_name = 'princeton-nlp/SWE-bench_Lite_bm25_13K' dataset_name = 'princeton-nlp/SWE-bench_Lite_bm25_13K'
dataset = load_swebench_dataset(dataset_name, split='test') dataset = load_swebench_dataset(dataset_name, split='test')
@ -96,23 +155,32 @@ async def main():
dataset_name = 'princeton-nlp/SWE-bench_Lite' dataset_name = 'princeton-nlp/SWE-bench_Lite'
swe_dataset = load_swebench_dataset( swe_dataset = load_swebench_dataset(
dataset_name, split='test')[:1] dataset_name, split='test')[:1]
filepath = Path("SWE-bench_testsample")
if filepath.exists():
dataset = Dataset.load_from_disk(filepath)
else:
dataset = download_instances(swe_dataset, filepath)
predictions_path = "preds.json" predictions_path = "preds.json"
preds = await get_preds(dataset, with_cognee=not args.cognee_off) preds = await get_preds(swe_dataset, with_cognee=not args.cognee_off)
with open(predictions_path, "w") as file: with open(predictions_path, "w") as file:
json.dump(preds, file) json.dump(preds, file)
subprocess.run(["python", "-m", "swebench.harness.run_evaluation",
"--dataset_name", dataset_name, subprocess.run(
"--split", "test", [
"--predictions_path", predictions_path, "python",
"--max_workers", "1", "-m",
"--run_id", "test_run"]) "swebench.harness.run_evaluation",
"--dataset_name",
dataset_name,
"--split",
"test",
"--predictions_path",
predictions_path,
"--max_workers",
str(args.max_workers),
"--run_id",
"test_run",
]
)
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main(), debug=True) asyncio.run(main(), debug=True)

View file

@ -1,103 +1,72 @@
import os import os
from copy import deepcopy import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from datasets import Dataset from git import Repo
from swebench.inference.make_datasets.create_instance import make_code_text
from swebench.inference.make_datasets.utils import (AutoContextManager,
ingest_directory_contents)
from tqdm.auto import tqdm
def ingest_files(filenames): def download_github_repo(instance, output_dir):
files_dict = dict() """
for filename in filenames: Downloads a GitHub repository and checks out the specified commit.
with open(filename) as f:
content = f.read()
files_dict[filename] = content
return files_dict
def ingest_repos(input_instances):
orig_dir = os.getcwd()
with TemporaryDirectory(
dir="/scratch" if os.path.exists("/scratch") else "/tmp"
) as root_dir:
for instance in tqdm(
input_instances.values(),
total=len(input_instances),
desc="Downloading repos on specific commits",
):
try:
with AutoContextManager(
instance, root_dir
) as cm:
readmes = cm.get_readme_files()
instance["readmes"] = ingest_files(readmes)
instance["file_contents"] = ingest_directory_contents(
cm.repo_path
)
finally:
# if AutoContextManager fails to exit properly future exits will return the wrong directory
os.chdir(orig_dir)
return input_instances
def extract_fields(instance):
readmes_text = make_code_text(instance["readmes"])
code_text = make_code_text(
instance["file_contents"], add_line_numbers=False)
text_inputs = "\n".join([readmes_text, code_text])
text_inputs = text_inputs.strip() + "\n\n"
# text_inputs = code_text
patch = "\n".join(["<patch>", instance["patch"], "</patch>"])
return {**instance, "text": text_inputs, "patch": patch}
def create_dataset(input_instances):
columns = [
"instance_id",
"text",
"repo",
"base_commit",
"problem_statement",
"hints_text",
"created_at",
"patch",
"test_patch",
"version",
"FAIL_TO_PASS",
"PASS_TO_PASS",
"environment_setup_commit",
]
data_table = {key: list() for key in columns}
for instance in input_instances.values():
datum = extract_fields(instance)
for key in columns:
data_table[key].append(datum[key] if key in datum else "")
dataset = Dataset.from_dict(data_table)
return dataset
def download_instances(
input_data,
path=Path("SWE-bench_testsample"),
verbose=False,
):
"""Downloads code from github.
Args: Args:
- input_data: dictionary with unprocessed input instances. instance (dict): Dictionary containing 'repo', 'base_commit', and 'instance_id'.
- verbose: set ContextManager verbose to True output_dir (str): Directory to store the downloaded repositories.
Returns:
str: Path to the downloaded repository.
""" """
input_instances = {x["instance_id"]: x for x in input_data} repo_owner_repo = instance['repo']
input_instances_copy = deepcopy(input_instances) base_commit = instance['base_commit']
input_instances_with_text = ingest_repos(input_instances_copy) instance_id = instance['instance_id']
dataset = create_dataset(input_instances_with_text)
dataset.save_to_disk(path) repo_url = f"https://github.com/{repo_owner_repo}.git"
return dataset
repo_path = os.path.abspath(os.path.join(output_dir, instance_id))
# Clone repository if it doesn't already exist
if not os.path.exists(repo_path):
print(f"Cloning {repo_url} to {repo_path}...")
Repo.clone_from(repo_url, repo_path)
else:
print(f"Repository already exists at {repo_path}.")
repo = Repo(repo_path)
repo.git.checkout(base_commit)
return repo_path
def delete_repo(repo_path):
"""
Deletes the specified repository directory.
Args:
repo_path (str): Path to the repository to delete.
Returns:
None
"""
try:
if os.path.exists(repo_path):
shutil.rmtree(repo_path)
print(f"Deleted repository at {repo_path}.")
else:
print(f"Repository path {repo_path} does not exist. Nothing to delete.")
except Exception as e:
print(f"Error deleting repository at {repo_path}: {e}")
def node_to_string(node):
text = node.attributes["text"]
type = node.attributes["type"]
return f"Node(id: {node.id}, type: {type}, description: {text})"
def retrieved_edges_to_string(retrieved_edges):
edge_strings = []
for edge in retrieved_edges:
relationship_type = edge.attributes["relationship_type"]
edge_str = f"{node_to_string(edge.node1)} {relationship_type} {node_to_string(edge.node2)}"
edge_strings.append(edge_str)
return "\n".join(edge_strings)

130
evals/llm_as_a_judge.py Normal file
View file

@ -0,0 +1,130 @@
import argparse
import asyncio
import json
import statistics
from pathlib import Path
import deepeval.metrics
import wget
from deepeval.dataset import EvaluationDataset
from deepeval.test_case import LLMTestCase
from tqdm import tqdm
import cognee
import evals.deepeval_metrics
from cognee.api.v1.search import SearchType
from cognee.base_config import get_base_config
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
async def answer_without_cognee(instance):
args = {
"question": instance["question"],
"context": instance["context"],
}
user_prompt = render_prompt("context_for_question.txt", args)
system_prompt = read_query_prompt("answer_hotpot_question.txt")
llm_client = get_llm_client()
answer_prediction = await llm_client.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
)
return answer_prediction
async def answer_with_cognee(instance):
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
for (title, sentences) in instance["context"]:
await cognee.add("\n".join(sentences), dataset_name = "HotPotQA")
await cognee.cognify("HotPotQA")
search_results = await cognee.search(
SearchType.INSIGHTS, query_text=instance["question"]
)
args = {
"question": instance["question"],
"context": search_results,
}
user_prompt = render_prompt("context_for_question.txt", args)
system_prompt = read_query_prompt("answer_hotpot_using_cognee_search.txt")
llm_client = get_llm_client()
answer_prediction = await llm_client.acreate_structured_output(
text_input=user_prompt,
system_prompt=system_prompt,
response_model=str,
)
return answer_prediction
async def eval_answers(instances, answers, eval_metric):
test_cases = []
for instance, answer in zip(instances, answers):
test_case = LLMTestCase(
input=instance["question"],
actual_output=answer,
expected_output=instance["answer"]
)
test_cases.append(test_case)
eval_set = EvaluationDataset(test_cases)
eval_results = eval_set.evaluate([eval_metric])
return eval_results
async def eval_on_hotpotQA(answer_provider, num_samples, eval_metric):
base_config = get_base_config()
data_root_dir = base_config.data_root_directory
if not Path(data_root_dir).exists():
Path(data_root_dir).mkdir()
filepath = data_root_dir / Path("hotpot_dev_fullwiki_v1.json")
if not filepath.exists():
url = 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json'
wget.download(url, out=data_root_dir)
with open(filepath, "r") as file:
dataset = json.load(file)
instances = dataset if not num_samples else dataset[:num_samples]
answers = []
for instance in tqdm(instances, desc="Getting answers"):
answer = await answer_provider(instance)
answers.append(answer)
eval_results = await eval_answers(instances, answers, eval_metric)
avg_score = statistics.mean([result.metrics_data[0].score for result in eval_results.test_results])
return avg_score
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--with_cognee", action="store_true")
parser.add_argument("--num_samples", type=int, default=500)
parser.add_argument("--metric", type=str, default="correctness_metric")
args = parser.parse_args()
try:
metric_cls = getattr(deepeval.metrics, args.metric)
metric = metric_cls()
except AttributeError:
metric = getattr(evals.deepeval_metrics, args.metric)
if args.with_cognee:
answer_provider = answer_with_cognee
else:
answer_provider = answer_without_cognee
avg_score = asyncio.run(eval_on_hotpotQA(answer_provider, args.num_samples, metric))
print(f"Average {args.metric}: {avg_score}")

2285
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,7 +1,7 @@
import argparse import argparse
import time import asyncio
from benchmark_function import benchmark_function from .benchmark_function import benchmark_function
from cognee.modules.graph.utils import get_graph_from_model from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import ( from cognee.tests.unit.interfaces.graph.util import (
@ -28,9 +28,12 @@ if __name__ == "__main__":
society = create_organization_recursive( society = create_organization_recursive(
"society", "Society", PERSON_NAMES, args.recursive_depth "society", "Society", PERSON_NAMES, args.recursive_depth
) )
nodes, edges = get_graph_from_model(society) nodes, edges = asyncio.run(get_graph_from_model(society))
results = benchmark_function(get_graph_from_model, society, num_runs=args.runs) def get_graph_from_model_sync(model):
return asyncio.run(get_graph_from_model(model))
results = benchmark_function(get_graph_from_model_sync, society, num_runs=args.runs)
print("\nBenchmark Results:") print("\nBenchmark Results:")
print( print(
f"N nodes: {len(nodes)}, N edges: {len(edges)}, Recursion depth: {args.recursive_depth}" f"N nodes: {len(nodes)}, N edges: {len(edges)}, Recursion depth: {args.recursive_depth}"

View file

@ -0,0 +1,9 @@
import numpy as np
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
class DummyEmbeddingEngine(EmbeddingEngine):
async def embed_text(self, text: list[str]) -> list[list[float]]:
return(list(list(np.random.randn(3072))))
def get_vector_size(self) -> int:
return(3072)

View file

@ -0,0 +1,65 @@
from typing import Type
from uuid import uuid4
import spacy
import textacy
from pydantic import BaseModel
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.shared.data_models import Edge, KnowledgeGraph, Node, SummarizedContent
class DummyLLMAdapter(LLMInterface):
nlp = spacy.load("en_core_web_sm")
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
if (
str(response_model)
== "<class 'cognee.shared.data_models.SummarizedContent'>"
):
return dummy_summarize_content(text_input)
elif (
str(response_model) == "<class 'cognee.shared.data_models.KnowledgeGraph'>"
):
return dummy_extract_knowledge_graph(text_input, self.nlp)
else:
raise Exception(
"Currently dummy acreate_structured_input is only implemented for SummarizedContent and KnowledgeGraph"
)
def dummy_extract_knowledge_graph(text, nlp):
doc = nlp(text)
triples = list(textacy.extract.subject_verb_object_triples(doc))
nodes = {}
edges = []
for triple in triples:
source = "_".join([str(e) for e in triple.subject])
target = "_".join([str(e) for e in triple.object])
nodes[source] = nodes.get(
source, Node(id=str(uuid4()), name=source, type="object", description="")
)
nodes[target] = nodes.get(
target, Node(id=str(uuid4()), name=target, type="object", description="")
)
edge_type = "_".join([str(e) for e in triple.verb])
edges.append(
Edge(
source_node_id=nodes[source].id,
target_node_id=nodes[target].id,
relationship_name=edge_type,
)
)
return KnowledgeGraph(nodes=list(nodes.values()), edges=edges)
def dummy_summarize_content(text):
words = [(word, len(word)) for word in set(text.split(" "))]
words = sorted(words, key=lambda x: x[1], reverse=True)
summary = " ".join([word for word, _ in words[:50]])
description = " ".join([word for word, _ in words[:10]])
return SummarizedContent(summary=summary, description=description)

View file

@ -41,7 +41,7 @@ aiosqlite = "^0.20.0"
pandas = "2.0.3" pandas = "2.0.3"
filetype = "^1.2.0" filetype = "^1.2.0"
nltk = "^3.8.1" nltk = "^3.8.1"
dlt = {extras = ["sqlalchemy"], version = "^1.3.0"} dlt = {extras = ["sqlalchemy"], version = "^1.4.1"}
aiofiles = "^23.2.1" aiofiles = "^23.2.1"
qdrant-client = "^1.9.0" qdrant-client = "^1.9.0"
graphistry = "^0.33.5" graphistry = "^0.33.5"
@ -70,6 +70,7 @@ asyncpg = "0.30.0"
pgvector = "^0.3.5" pgvector = "^0.3.5"
psycopg2 = {version = "^2.9.10", optional = true} psycopg2 = {version = "^2.9.10", optional = true}
llama-index-core = {version = "^0.11.22", optional = true} llama-index-core = {version = "^0.11.22", optional = true}
deepeval = {version = "^2.0.1", optional = true}
[tool.poetry.extras] [tool.poetry.extras]
filesystem = ["s3fs", "botocore"] filesystem = ["s3fs", "botocore"]
@ -80,6 +81,8 @@ neo4j = ["neo4j"]
postgres = ["psycopg2", "pgvector", "asyncpg"] postgres = ["psycopg2", "pgvector", "asyncpg"]
notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"] notebook = ["ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
llama-index = ["llama-index-core"] llama-index = ["llama-index-core"]
deepeval = ["deepeval"]
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]