fix: refactor get_graph_from_model to return nodes and edges correctly (#257)

* fix: handle rate limit error coming from llm model

* fix: fixes lost edges and nodes in get_graph_from_model

* fix: fixes database pruning issue in pgvector (#261)

* fix: cognee_demo notebook pipeline is not saving summaries

---------

Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Authored by Boris on 2024-12-06 12:52:01 +01:00; committed via GitHub.
commit 348610e73c (parent 351ce92001)
22 changed files with 242 additions and 160 deletions


@@ -81,13 +81,13 @@ async def run_cognify_pipeline(dataset: Dataset, user: User):
Task(classify_documents),
Task(check_permissions_on_documents, user = user, permissions = ["write"]),
Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
Task(add_data_points, task_config = { "batch_size": 10 }),
Task(extract_graph_from_data, graph_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
Task(
summarize_text,
summarization_model = cognee_config.summarization_model,
task_config = { "batch_size": 10 }
),
Task(add_data_points, task_config = { "batch_size": 10 }),
]
pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")
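
For context on the reordering above: add_data_points persists whatever DataPoint objects exist when it runs, so placing it after extract_graph_from_data and summarize_text lets the stored batch include the entities and summaries those tasks produce; together with summarize_text now returning its summaries (see the summarize_text hunk below), this is what lets the demo pipeline save summaries again. A toy sketch of the idea, with hypothetical names rather than the real cognee Task API:

import asyncio

async def extract_chunks(docs):
    return [{"kind": "chunk", "text": d} for d in docs]

async def summarize(chunks):
    # Later pipeline steps create new objects alongside the chunks.
    return chunks + [{"kind": "summary", "text": c["text"][:16]} for c in chunks]

async def store(points):
    # Persistence only sees objects that already exist at this point,
    # which is why it has to run after the producer tasks.
    print(f"persisting {len(points)} data points")
    return points

async def main():
    chunks = await extract_chunks(["first document", "second document"])
    await store(await summarize(chunks))

asyncio.run(main())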


@@ -29,7 +29,14 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.model = model
self.dimensions = dimensions
MAX_RETRIES = 5
retry_count = 0
async def embed_text(self, text: List[str]) -> List[List[float]]:
async def exponential_backoff(attempt):
wait_time = min(10 * (2 ** attempt), 60) # Max 60 seconds
await asyncio.sleep(wait_time)
try:
response = await litellm.aembedding(
self.model,
@@ -38,11 +45,18 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_base = self.endpoint,
api_version = self.api_version
)
self.retry_count = 0
return [data["embedding"] for data in response.data]
except litellm.exceptions.ContextWindowExceededError as error:
if isinstance(text, list):
if len(text) == 1:
parts = [text]
else:
parts = [text[0:math.ceil(len(text)/2)], text[math.ceil(len(text)/2):]]
parts_futures = [self.embed_text(part) for part in parts]
embeddings = await asyncio.gather(*parts_futures)
@@ -50,11 +64,21 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
for embeddings_part in embeddings:
all_embeddings.extend(embeddings_part)
return [data["embedding"] for data in all_embeddings]
return all_embeddings
logger.error("Context window exceeded for embedding text: %s", str(error))
raise error
except litellm.exceptions.RateLimitError:
if self.retry_count >= self.MAX_RETRIES:
raise Exception(f"Rate limit exceeded and no more retries left.")
await exponential_backoff(self.retry_count)
self.retry_count += 1
return await self.embed_text(text)
except Exception as error:
logger.error("Error embedding text: %s", str(error))
raise error
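
The retry path added here is a capped exponential backoff: on litellm.exceptions.RateLimitError the engine waits 10s, 20s, 40s, ... (capped at 60 seconds), retries up to MAX_RETRIES times while tracking retry_count on the instance, and resets the counter after a successful call. A self-contained sketch of the same pattern with generic names, not the LiteLLMEmbeddingEngine itself:

import asyncio
import random

MAX_RETRIES = 5

class RateLimitError(Exception):
    """Stand-in for a provider rate-limit error such as litellm.exceptions.RateLimitError."""

async def call_with_backoff(make_request, attempt: int = 0):
    try:
        return await make_request()
    except RateLimitError:
        if attempt >= MAX_RETRIES:
            raise Exception("Rate limit exceeded and no more retries left.")
        # Wait 10s, 20s, 40s, ... capped at 60 seconds, then retry.
        await asyncio.sleep(min(10 * (2 ** attempt), 60))
        return await call_with_backoff(make_request, attempt + 1)

async def flaky_request():
    # Fails roughly half the time to exercise the retry path.
    if random.random() < 0.5:
        raise RateLimitError()
    return "ok"

if __name__ == "__main__":
    print(asyncio.run(call_with_backoff(flaky_request)))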


@@ -35,6 +35,7 @@ class TextChunker():
is_part_of = self.document,
chunk_index = self.chunk_index,
cut_type = chunk_data["cut_type"],
contains = [],
_metadata = {
"index_fields": ["text"],
"metadata_id": self.document.metadata_id
@@ -52,6 +53,7 @@ class TextChunker():
is_part_of = self.document,
chunk_index = self.chunk_index,
cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
contains = [],
_metadata = {
"index_fields": ["text"],
"metadata_id": self.document.metadata_id
@@ -73,6 +75,7 @@ class TextChunker():
is_part_of = self.document,
chunk_index = self.chunk_index,
cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
contains = [],
_metadata = {
"index_fields": ["text"],
"metadata_id": self.document.metadata_id


@@ -1,6 +1,7 @@
from typing import Optional
from typing import List, Optional
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity
class DocumentChunk(DataPoint):
__tablename__ = "document_chunk"
@@ -9,6 +10,7 @@ class DocumentChunk(DataPoint):
chunk_index: int
cut_type: str
is_part_of: Document
contains: List[Entity] = None
_metadata: Optional[dict] = {
"index_fields": ["text"],


@@ -0,0 +1 @@
from .DocumentChunk import DocumentChunk


@@ -1,5 +1,4 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.engine.models.EntityType import EntityType
@@ -8,7 +7,6 @@ class Entity(DataPoint):
name: str
is_a: EntityType
description: str
mentioned_in: DocumentChunk
_metadata: dict = {
"index_fields": ["name"],


@@ -1,13 +1,10 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
class EntityType(DataPoint):
__tablename__ = "entity_type"
name: str
type: str
description: str
exists_in: DocumentChunk
_metadata: dict = {
"index_fields": ["name"],


@@ -1,6 +1,6 @@
from typing import Optional
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.engine.utils import (
generate_edge_name,
@@ -11,7 +11,8 @@ from cognee.shared.data_models import KnowledgeGraph
def expand_with_nodes_and_edges(
graph_node_index: list[tuple[DataPoint, KnowledgeGraph]],
data_chunks: list[DocumentChunk],
chunk_graphs: list[KnowledgeGraph],
existing_edges_map: Optional[dict[str, bool]] = None,
):
if existing_edges_map is None:
@@ -19,9 +20,10 @@
added_nodes_map = {}
relationships = []
data_points = []
for graph_source, graph in graph_node_index:
for index, data_chunk in enumerate(data_chunks):
graph = chunk_graphs[index]
if graph is None:
continue
@@ -38,7 +40,6 @@
name = type_node_name,
type = type_node_name,
description = type_node_name,
exists_in = graph_source,
)
added_nodes_map[f"{str(type_node_id)}_type"] = type_node
else:
@@ -50,9 +51,13 @@
name = node_name,
is_a = type_node,
description = node.description,
mentioned_in = graph_source,
)
data_points.append(entity_node)
if data_chunk.contains is None:
data_chunk.contains = []
data_chunk.contains.append(entity_node)
added_nodes_map[f"{str(node_id)}_entity"] = entity_node
# Add relationship that came from graphs.
@@ -80,4 +85,4 @@
)
existing_edges_map[edge_key] = True
return (data_points, relationships)
return (data_chunks, relationships)
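
The net effect of this hunk: the chunk-to-entity link changes direction. Instead of Entity.mentioned_in and EntityType.exists_in pointing back at the chunk, the DocumentChunk now carries its entities in a contains list, and the function returns the updated data_chunks (rather than a separate data_points list) so get_graph_from_model can reach every entity and its type by walking chunk -> entity -> type. A minimal sketch of the resulting shape, using simplified stand-in models rather than the full cognee ones:

from typing import List, Optional
from pydantic import BaseModel

class EntityType(BaseModel):
    name: str

class Entity(BaseModel):
    name: str
    is_a: EntityType

class DocumentChunk(BaseModel):
    text: str
    contains: Optional[List[Entity]] = None

chunk = DocumentChunk(text="Alan Turing worked at Bletchley Park.")
entity = Entity(name="Alan Turing", is_a=EntityType(name="Person"))

# The chunk owns the link, so returning the chunks is enough for the graph
# builder to discover every entity and its type by traversal.
if chunk.contains is None:
    chunk.contains = []
chunk.contains.append(entity)

print(chunk.contains[0].is_a.name)  # "Person"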


@@ -1,154 +1,115 @@
from datetime import datetime, timezone
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model
async def get_graph_from_model(
data_point: DataPoint,
added_nodes: dict,
added_edges: dict,
visited_properties: dict = None,
include_root = True,
added_nodes = None,
added_edges = None,
visited_properties = None,
):
if str(data_point.id) in added_nodes:
return [], []
nodes = []
edges = []
added_nodes = added_nodes or {}
added_edges = added_edges or {}
visited_properties = visited_properties or {}
data_point_properties = {}
excluded_properties = set()
if str(data_point.id) in added_nodes:
return nodes, edges
properties_to_visit = set()
for field_name, field_value in data_point:
if field_name == "_metadata":
continue
if field_value is None:
excluded_properties.add(field_name)
continue
if isinstance(field_value, DataPoint):
excluded_properties.add(field_name)
property_key = f"{str(data_point.id)}{field_name}{str(field_value.id)}"
property_key = str(data_point.id) + field_name + str(field_value.id)
if property_key in visited_properties:
continue
visited_properties[property_key] = True
nodes, edges = await add_nodes_and_edges(
data_point,
field_name,
field_value,
nodes,
edges,
added_nodes,
added_edges,
visited_properties,
)
properties_to_visit.add(field_name)
continue
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
excluded_properties.add(field_name)
for field_value_item in field_value:
property_key = f"{str(data_point.id)}{field_name}{str(field_value_item.id)}"
for index, item in enumerate(field_value):
property_key = str(data_point.id) + field_name + str(item.id)
if property_key in visited_properties:
continue
visited_properties[property_key] = True
nodes, edges = await add_nodes_and_edges(
data_point,
field_name,
field_value_item,
nodes,
edges,
added_nodes,
added_edges,
visited_properties,
)
properties_to_visit.add(f"{field_name}.{index}")
continue
data_point_properties[field_name] = field_value
if include_root:
if include_root and str(data_point.id) not in added_nodes:
SimpleDataPointModel = copy_model(
type(data_point),
include_fields = {
"_metadata": (dict, data_point._metadata),
"__tablename__": data_point.__tablename__,
"__tablename__": (str, data_point.__tablename__),
},
exclude_fields = excluded_properties,
exclude_fields = list(excluded_properties),
)
nodes.append(SimpleDataPointModel(**data_point_properties))
added_nodes[str(data_point.id)] = True
return nodes, edges
for field_name in properties_to_visit:
index = None
if "." in field_name:
field_name, index = field_name.split(".")
field_value = getattr(data_point, field_name)
if index is not None:
field_value = field_value[int(index)]
edge_key = str(data_point.id) + str(field_value.id) + field_name
if str(edge_key) not in added_edges:
edges.append((data_point.id, field_value.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": field_value.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}))
added_edges[str(edge_key)] = True
if str(field_value.id) in added_nodes:
continue
async def add_nodes_and_edges(
data_point,
field_name,
field_value,
nodes,
edges,
added_nodes,
added_edges,
visited_properties,
):
property_nodes, property_edges = await get_graph_from_model(
field_value,
True,
added_nodes,
added_edges,
visited_properties,
include_root = True,
added_nodes = added_nodes,
added_edges = added_edges,
visited_properties = visited_properties,
)
for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True
for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True
for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name
property_key = str(data_point.id) + field_name + str(field_value.id)
visited_properties[property_key] = True
if str(edge_key) not in added_edges:
edges.append(
(
data_point.id,
property_node.id,
field_name,
{
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime(
"%Y-%m-%d %H:%M:%S"
),
},
)
)
added_edges[str(edge_key)] = True
return (nodes, edges)
return nodes, edges
def get_own_properties(property_nodes, property_edges):
def get_own_property_nodes(property_nodes, property_edges):
own_properties = []
destination_nodes = [str(property_edge[1]) for property_edge in property_edges]
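
The key point of the rewrite above is shared mutable state: callers create added_nodes, added_edges, and visited_properties once and pass the same dictionaries into every get_graph_from_model call and every recursive step, so a DataPoint reachable from several roots is emitted as a single node and each relationship edge only once, instead of being duplicated or silently dropped. A rough sketch of that dedup-by-shared-state pattern with toy types, not the actual implementation:

import asyncio
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class Node:
    id: str
    refs: List["Node"] = field(default_factory=list)

async def graph_from(node: Node, added_nodes: Dict[str, bool], added_edges: Dict[str, bool]):
    nodes, edges = [], []
    if node.id in added_nodes:
        return nodes, edges
    nodes.append(node)
    added_nodes[node.id] = True
    for ref in node.refs:
        edge_key = node.id + ref.id
        if edge_key not in added_edges:
            edges.append((node.id, ref.id, "refs"))
            added_edges[edge_key] = True
        # The same maps are passed down, so already-visited nodes are skipped.
        child_nodes, child_edges = await graph_from(ref, added_nodes, added_edges)
        nodes.extend(child_nodes)
        edges.extend(child_edges)
    return nodes, edges

async def main():
    shared = Node("shared")
    roots = [Node("a", [shared]), Node("b", [shared])]
    added_nodes: Dict[str, bool] = {}
    added_edges: Dict[str, bool] = {}
    results = await asyncio.gather(*[graph_from(root, added_nodes, added_edges) for root in roots])
    nodes = [n for result_nodes, _ in results for n in result_nodes]
    edges = [e for _, result_edges in results for e in result_edges]
    assert len(nodes) == 3  # "shared" is emitted once, not once per root
    assert len(edges) == 2

asyncio.run(main())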


@@ -5,7 +5,8 @@ from cognee.shared.data_models import KnowledgeGraph
async def retrieve_existing_edges(
graph_node_index: list[tuple[DataPoint, KnowledgeGraph]],
data_chunks: list[DataPoint],
chunk_graphs: list[KnowledgeGraph],
graph_engine: GraphDBInterface,
) -> dict[str, bool]:
processed_nodes = {}
@@ -13,23 +14,25 @@
entity_node_edges = []
type_entity_edges = []
for graph_source, graph in graph_node_index:
for index, data_chunk in enumerate(data_chunks):
graph = chunk_graphs[index]
for node in graph.nodes:
type_node_id = generate_node_id(node.type)
entity_node_id = generate_node_id(node.id)
if str(type_node_id) not in processed_nodes:
type_node_edges.append(
(str(graph_source), str(type_node_id), "exists_in")
(data_chunk.id, type_node_id, "exists_in")
)
processed_nodes[str(type_node_id)] = True
if str(entity_node_id) not in processed_nodes:
entity_node_edges.append(
(str(graph_source), entity_node_id, "mentioned_in")
(data_chunk.id, entity_node_id, "mentioned_in")
)
type_entity_edges.append(
(str(entity_node_id), str(type_node_id), "is_a")
(entity_node_id, type_node_id, "is_a")
)
processed_nodes[str(entity_node_id)] = True


@@ -7,7 +7,7 @@ class Repository(DataPoint):
type: Optional[str] = "Repository"
class CodeFile(DataPoint):
__tablename__ = "CodeFile"
__tablename__ = "codefile"
extracted_id: str # actually file path
type: Optional[str] = "CodeFile"
source_code: Optional[str] = None
@@ -21,7 +21,7 @@ class CodeFile(DataPoint):
}
class CodePart(DataPoint):
__tablename__ = "CodePart"
__tablename__ = "codepart"
# part_of: Optional[CodeFile]
source_code: str
type: Optional[str] = "CodePart"


@@ -20,16 +20,16 @@ async def extract_graph_from_data(
*[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
)
graph_engine = await get_graph_engine()
chunk_and_chunk_graphs = [
(chunk, chunk_graph) for chunk, chunk_graph in zip(data_chunks, chunk_graphs)
]
existing_edges_map = await retrieve_existing_edges(
chunk_and_chunk_graphs,
data_chunks,
chunk_graphs,
graph_engine,
)
graph_nodes, graph_edges = expand_with_nodes_and_edges(
chunk_and_chunk_graphs,
data_chunks,
chunk_graphs,
existing_edges_map,
)


@@ -70,7 +70,7 @@ async def node_enrich_and_connect(
if desc_id in data_points_map:
desc = data_points_map[desc_id]
else:
node_data = await graph_engine.extract_node(desc_id)
node_data = await graph_engine.extract_node(str(desc_id))
try:
desc = convert_node_to_data_point(node_data)
except Exception:
@@ -87,9 +87,17 @@ async def enrich_dependency_graph(data_points: list[DataPoint]) -> AsyncGenerato
"""Enriches the graph with topological ranks and 'depends_on' edges."""
nodes = []
edges = []
added_nodes = {}
added_edges = {}
visited_properties = {}
for data_point in data_points:
graph_nodes, graph_edges = await get_graph_from_model(data_point)
graph_nodes, graph_edges = await get_graph_from_model(
data_point,
added_nodes = added_nodes,
added_edges = added_edges,
visited_properties = visited_properties,
)
nodes.extend(graph_nodes)
edges.extend(graph_edges)


@@ -11,12 +11,14 @@ async def add_data_points(data_points: list[DataPoint]):
added_nodes = {}
added_edges = {}
visited_properties = {}
results = await asyncio.gather(*[
get_graph_from_model(
data_point,
added_nodes = added_nodes,
added_edges = added_edges,
visited_properties = visited_properties,
) for data_point in data_points
])


@@ -1,6 +1,5 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.data.processing.document_types import Document
from cognee.modules.chunking.models import DocumentChunk
from cognee.shared.CodeGraphEntities import CodeFile


@@ -4,7 +4,6 @@ from uuid import uuid5
from pydantic import BaseModel
from cognee.modules.data.extraction.extract_summary import extract_summary
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.tasks.storage import add_data_points
from .models import TextSummary
async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]):
@@ -23,6 +22,4 @@ async def summarize_text(data_chunks: list[DocumentChunk], summarization_model:
) for (chunk_index, chunk) in enumerate(data_chunks)
]
await add_data_points(summaries)
return data_chunks
return summaries


@@ -73,10 +73,13 @@ async def test_circular_reference_extraction():
nodes = []
edges = []
added_nodes = {}
added_edges = {}
start = time.perf_counter_ns()
results = await asyncio.gather(*[
get_graph_from_model(code_file) for code_file in code_files
get_graph_from_model(code_file, added_nodes = added_nodes, added_edges = added_edges) for code_file in code_files
])
time_to_run = time.perf_counter_ns() - start
@@ -87,12 +90,6 @@ async def test_circular_reference_extraction():
nodes.extend(result_nodes)
edges.extend(result_edges)
# for code_file in code_files:
# model_nodes, model_edges = get_graph_from_model(code_file)
# nodes.extend(model_nodes)
# edges.extend(model_edges)
assert len(nodes) == 1501
assert len(edges) == 1501 * 20 + 1500 * 5


@@ -0,0 +1,69 @@
import asyncio
import random
from typing import List
from uuid import uuid5, NAMESPACE_OID
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model
class Document(DataPoint):
path: str
class DocumentChunk(DataPoint):
part_of: Document
text: str
contains: List["Entity"] = None
class EntityType(DataPoint):
name: str
class Entity(DataPoint):
name: str
is_type: EntityType
DocumentChunk.model_rebuild()
async def get_graph_from_model_test():
document = Document(path = "file_path")
document_chunks = [DocumentChunk(
id = uuid5(NAMESPACE_OID, f"file{file_index}"),
text = "some text",
part_of = document,
contains = [],
) for file_index in range(1)]
for document_chunk in document_chunks:
document_chunk.contains.append(Entity(
name = f"Entity",
is_type = EntityType(
name = "Type 1",
),
))
nodes = []
edges = []
added_nodes = {}
added_edges = {}
visited_properties = {}
results = await asyncio.gather(*[
get_graph_from_model(
document_chunk,
added_nodes = added_nodes,
added_edges = added_edges,
visited_properties = visited_properties,
) for document_chunk in document_chunks
])
for result_nodes, result_edges in results:
nodes.extend(result_nodes)
edges.extend(result_edges)
assert len(nodes) == 4
assert len(edges) == 3
if __name__ == "__main__":
asyncio.run(get_graph_from_model_test())


@@ -64,7 +64,6 @@ async def generate_patch_with_cognee(instance, llm_client, search_type=SearchTyp
tasks = [
Task(get_repo_file_dependencies),
Task(add_data_points, task_config = { "batch_size": 50 }),
Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
Task(add_data_points, task_config = { "batch_size": 50 }),


@@ -265,7 +265,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "df16431d0f48b006",
"metadata": {
"ExecuteTime": {
@@ -304,7 +304,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "9086abf3af077ab4",
"metadata": {
"ExecuteTime": {
@@ -349,7 +349,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "a9de0cc07f798b7f",
"metadata": {
"ExecuteTime": {
@@ -393,7 +393,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "185ff1c102d06111",
"metadata": {
"ExecuteTime": {
@@ -437,7 +437,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "d55ce4c58f8efb67",
"metadata": {
"ExecuteTime": {
@@ -479,7 +479,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "ca4ecc32721ad332",
"metadata": {
"ExecuteTime": {
@@ -529,7 +529,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "bce39dc6",
"metadata": {},
"outputs": [],
@@ -622,7 +622,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "7c431fdef4921ae0",
"metadata": {
"ExecuteTime": {
@@ -654,13 +654,13 @@
" Task(classify_documents),\n",
" Task(check_permissions_on_documents, user = user, permissions = [\"write\"]),\n",
" Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n",
" Task(add_data_points, task_config = { \"batch_size\": 10 }),\n",
" Task(extract_graph_from_data, graph_model = KnowledgeGraph, task_config = { \"batch_size\": 10 }), # Generate knowledge graphs from the document chunks.\n",
" Task(\n",
" summarize_text,\n",
" summarization_model = cognee_config.summarization_model,\n",
" task_config = { \"batch_size\": 10 }\n",
" ),\n",
" Task(add_data_points, task_config = { \"batch_size\": 10 }),\n",
" ]\n",
"\n",
" pipeline = run_tasks(tasks, data_documents)\n",
@@ -883,7 +883,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
"version": "3.11.8"
}
},
"nbformat": 4,


@@ -28,10 +28,27 @@ if __name__ == "__main__":
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, args.recursive_depth
)
nodes, edges = asyncio.run(get_graph_from_model(society))
added_nodes = {}
added_edges = {}
visited_properties = {}
nodes, edges = asyncio.run(get_graph_from_model(
society,
added_nodes = added_nodes,
added_edges = added_edges,
visited_properties = visited_properties,
))
def get_graph_from_model_sync(model):
return asyncio.run(get_graph_from_model(model))
added_nodes = {}
added_edges = {}
visited_properties = {}
return asyncio.run(get_graph_from_model(
model,
added_nodes = added_nodes,
added_edges = added_edges,
visited_properties = visited_properties,
))
results = benchmark_function(get_graph_from_model_sync, society, num_runs=args.runs)
print("\nBenchmark Results:")