fix: refactor get_graph_from_model to return nodes and edges correctly (#257)
* fix: handle rate limit error coming from llm model
* fix: fixes lost edges and nodes in get_graph_from_model
* fix: fixes database pruning issue in pgvector (#261)
* fix: cognee_demo notebook pipeline is not saving summaries

---------

Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Parent: 351ce92001
Commit: 348610e73c
22 changed files with 242 additions and 160 deletions
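The headline change is the get_graph_from_model calling convention: callers now construct the added_nodes / added_edges / visited_properties dictionaries themselves and share them across calls, so nodes and edges are de-duplicated once instead of being lost between calls. A minimal usage sketch modeled on the add_data_points, benchmark, and test call sites in this diff (collect_graph is a hypothetical wrapper, not part of cognee):

    import asyncio
    from cognee.modules.graph.utils import get_graph_from_model

    async def collect_graph(data_points):
        # Shared across all calls so duplicate nodes/edges are emitted only once.
        added_nodes = {}
        added_edges = {}
        visited_properties = {}

        nodes, edges = [], []
        results = await asyncio.gather(*[
            get_graph_from_model(
                data_point,
                added_nodes = added_nodes,
                added_edges = added_edges,
                visited_properties = visited_properties,
            ) for data_point in data_points
        ])
        for result_nodes, result_edges in results:
            nodes.extend(result_nodes)
            edges.extend(result_edges)
        return nodes, edges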
@@ -81,13 +81,13 @@ async def run_cognify_pipeline(dataset: Dataset, user: User):
    Task(classify_documents),
    Task(check_permissions_on_documents, user = user, permissions = ["write"]),
    Task(extract_chunks_from_documents), # Extract text chunks based on the document type.
    Task(add_data_points, task_config = { "batch_size": 10 }),
    Task(extract_graph_from_data, graph_model = KnowledgeGraph, task_config = { "batch_size": 10 }), # Generate knowledge graphs from the document chunks.
    Task(
        summarize_text,
        summarization_model = cognee_config.summarization_model,
        task_config = { "batch_size": 10 }
    ),
    Task(add_data_points, task_config = { "batch_size": 10 }),
]

pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")
@@ -29,7 +29,14 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
        self.model = model
        self.dimensions = dimensions

    MAX_RETRIES = 5
    retry_count = 0

    async def embed_text(self, text: List[str]) -> List[List[float]]:
        async def exponential_backoff(attempt):
            wait_time = min(10 * (2 ** attempt), 60) # Max 60 seconds
            await asyncio.sleep(wait_time)

        try:
            response = await litellm.aembedding(
                self.model,

@@ -38,11 +45,18 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
                api_base = self.endpoint,
                api_version = self.api_version
            )

            self.retry_count = 0

            return [data["embedding"] for data in response.data]

        except litellm.exceptions.ContextWindowExceededError as error:
            if isinstance(text, list):
                if len(text) == 1:
                    parts = [text]
                else:
                    parts = [text[0:math.ceil(len(text)/2)], text[math.ceil(len(text)/2):]]

                parts_futures = [self.embed_text(part) for part in parts]
                embeddings = await asyncio.gather(*parts_futures)

@@ -50,11 +64,21 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
                for embeddings_part in embeddings:
                    all_embeddings.extend(embeddings_part)

                return [data["embedding"] for data in all_embeddings]
                return all_embeddings

            logger.error("Context window exceeded for embedding text: %s", str(error))
            raise error

        except litellm.exceptions.RateLimitError:
            if self.retry_count >= self.MAX_RETRIES:
                raise Exception(f"Rate limit exceeded and no more retries left.")

            await exponential_backoff(self.retry_count)

            self.retry_count += 1

            return await self.embed_text(text)

        except Exception as error:
            logger.error("Error embedding text: %s", str(error))
            raise error
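The rate-limit handling added above combines a per-engine retry counter with capped exponential backoff. Stripped of the litellm specifics, the control flow reduces to this standalone sketch (call_embedding_api and RateLimitError are illustrative placeholders for litellm.aembedding and litellm.exceptions.RateLimitError):

    import asyncio

    MAX_RETRIES = 5

    class RateLimitError(Exception):
        """Stand-in for litellm.exceptions.RateLimitError."""

    async def embed_with_backoff(texts, call_embedding_api, retry_count = 0):
        # Wait 10s, 20s, 40s, ... capped at 60s, mirroring the engine's backoff.
        async def exponential_backoff(attempt):
            await asyncio.sleep(min(10 * (2 ** attempt), 60))

        try:
            return await call_embedding_api(texts)
        except RateLimitError:
            if retry_count >= MAX_RETRIES:
                raise Exception("Rate limit exceeded and no more retries left.")

            await exponential_backoff(retry_count)
            return await embed_with_backoff(texts, call_embedding_api, retry_count + 1)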
@@ -35,6 +35,7 @@ class TextChunker():
    is_part_of = self.document,
    chunk_index = self.chunk_index,
    cut_type = chunk_data["cut_type"],
    contains = [],
    _metadata = {
        "index_fields": ["text"],
        "metadata_id": self.document.metadata_id

@@ -52,6 +53,7 @@ class TextChunker():
    is_part_of = self.document,
    chunk_index = self.chunk_index,
    cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
    contains = [],
    _metadata = {
        "index_fields": ["text"],
        "metadata_id": self.document.metadata_id

@@ -73,6 +75,7 @@ class TextChunker():
    is_part_of = self.document,
    chunk_index = self.chunk_index,
    cut_type = paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
    contains = [],
    _metadata = {
        "index_fields": ["text"],
        "metadata_id": self.document.metadata_id
@@ -1,6 +1,7 @@
from typing import Optional
from typing import List, Optional
from cognee.infrastructure.engine import DataPoint
from cognee.modules.data.processing.document_types import Document
from cognee.modules.engine.models import Entity

class DocumentChunk(DataPoint):
    __tablename__ = "document_chunk"

@@ -9,6 +10,7 @@ class DocumentChunk(DataPoint):
    chunk_index: int
    cut_type: str
    is_part_of: Document
    contains: List[Entity] = None

    _metadata: Optional[dict] = {
        "index_fields": ["text"],
cognee/modules/chunking/models/__init__.py (new file, 1 addition)
@@ -0,0 +1 @@
from .DocumentChunk import DocumentChunk
@@ -1,5 +1,4 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.engine.models.EntityType import EntityType


@@ -8,7 +7,6 @@ class Entity(DataPoint):
    name: str
    is_a: EntityType
    description: str
    mentioned_in: DocumentChunk

    _metadata: dict = {
        "index_fields": ["name"],

@@ -1,13 +1,10 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk


class EntityType(DataPoint):
    __tablename__ = "entity_type"
    name: str
    type: str
    description: str
    exists_in: DocumentChunk

    _metadata: dict = {
        "index_fields": ["name"],
@@ -1,6 +1,6 @@
from typing import Optional

from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models import DocumentChunk
from cognee.modules.engine.models import Entity, EntityType
from cognee.modules.engine.utils import (
    generate_edge_name,

@@ -11,7 +11,8 @@ from cognee.shared.data_models import KnowledgeGraph


def expand_with_nodes_and_edges(
    graph_node_index: list[tuple[DataPoint, KnowledgeGraph]],
    data_chunks: list[DocumentChunk],
    chunk_graphs: list[KnowledgeGraph],
    existing_edges_map: Optional[dict[str, bool]] = None,
):
    if existing_edges_map is None:

@@ -19,9 +20,10 @@ def expand_with_nodes_and_edges(

    added_nodes_map = {}
    relationships = []
    data_points = []

    for graph_source, graph in graph_node_index:
    for index, data_chunk in enumerate(data_chunks):
        graph = chunk_graphs[index]

        if graph is None:
            continue

@@ -38,7 +40,6 @@ def expand_with_nodes_and_edges(
    name = type_node_name,
    type = type_node_name,
    description = type_node_name,
    exists_in = graph_source,
)
added_nodes_map[f"{str(type_node_id)}_type"] = type_node
else:

@@ -50,9 +51,13 @@ def expand_with_nodes_and_edges(
    name = node_name,
    is_a = type_node,
    description = node.description,
    mentioned_in = graph_source,
)
data_points.append(entity_node)

if data_chunk.contains is None:
    data_chunk.contains = []

data_chunk.contains.append(entity_node)

added_nodes_map[f"{str(node_id)}_entity"] = entity_node

# Add relationship that came from graphs.

@@ -80,4 +85,4 @@ def expand_with_nodes_and_edges(
    )
    existing_edges_map[edge_key] = True

    return (data_points, relationships)
    return (data_chunks, relationships)
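The behavioral shift in expand_with_nodes_and_edges is that extracted entities are now attached to each chunk's contains list and the chunks themselves are returned together with the relationships, rather than a separate data_points list. A rough sketch of the call site, mirroring extract_graph_from_data later in this diff (comments describe the apparent contract, not guaranteed signatures):

    # Hypothetical call site mirroring extract_graph_from_data in this diff:
    graph_nodes, graph_edges = expand_with_nodes_and_edges(
        data_chunks,        # list[DocumentChunk]
        chunk_graphs,       # list[KnowledgeGraph], one per chunk (entries may be None)
        existing_edges_map, # edges already present in the graph store, keyed by edge key
    )
    # graph_nodes are the DocumentChunks themselves, now carrying their entities in
    # chunk.contains; graph_edges are (source_id, target_id, relationship_name, properties)
    # tuples for the relationships extracted from the chunk graphs.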
@@ -1,154 +1,115 @@
from datetime import datetime, timezone

from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model

async def get_graph_from_model(
    data_point: DataPoint,
    added_nodes: dict,
    added_edges: dict,
    visited_properties: dict = None,
    include_root = True,
    added_nodes = None,
    added_edges = None,
    visited_properties = None,
):
    if str(data_point.id) in added_nodes:
        return [], []

    nodes = []
    edges = []
    added_nodes = added_nodes or {}
    added_edges = added_edges or {}
    visited_properties = visited_properties or {}

    data_point_properties = {}
    excluded_properties = set()

    if str(data_point.id) in added_nodes:
        return nodes, edges
    properties_to_visit = set()

    for field_name, field_value in data_point:
        if field_name == "_metadata":
            continue

        if field_value is None:
            excluded_properties.add(field_name)
            continue

        if isinstance(field_value, DataPoint):
            excluded_properties.add(field_name)

            property_key = f"{str(data_point.id)}{field_name}{str(field_value.id)}"
            property_key = str(data_point.id) + field_name + str(field_value.id)

            if property_key in visited_properties:
                continue

            visited_properties[property_key] = True

            nodes, edges = await add_nodes_and_edges(
                data_point,
                field_name,
                field_value,
                nodes,
                edges,
                added_nodes,
                added_edges,
                visited_properties,
            )
            properties_to_visit.add(field_name)

            continue

        if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
            excluded_properties.add(field_name)

            for field_value_item in field_value:
                property_key = f"{str(data_point.id)}{field_name}{str(field_value_item.id)}"
            for index, item in enumerate(field_value):
                property_key = str(data_point.id) + field_name + str(item.id)

                if property_key in visited_properties:
                    continue

                visited_properties[property_key] = True

                nodes, edges = await add_nodes_and_edges(
                    data_point,
                    field_name,
                    field_value_item,
                    nodes,
                    edges,
                    added_nodes,
                    added_edges,
                    visited_properties,
                )
                properties_to_visit.add(f"{field_name}.{index}")

            continue

        data_point_properties[field_name] = field_value

    if include_root:

    if include_root and str(data_point.id) not in added_nodes:
        SimpleDataPointModel = copy_model(
            type(data_point),
            include_fields = {
                "_metadata": (dict, data_point._metadata),
                "__tablename__": data_point.__tablename__,
                "__tablename__": (str, data_point.__tablename__),
            },
            exclude_fields = excluded_properties,
            exclude_fields = list(excluded_properties),
        )
        nodes.append(SimpleDataPointModel(**data_point_properties))
        added_nodes[str(data_point.id)] = True

    return nodes, edges
    for field_name in properties_to_visit:
        index = None

        if "." in field_name:
            field_name, index = field_name.split(".")

        field_value = getattr(data_point, field_name)

        if index is not None:
            field_value = field_value[int(index)]

        edge_key = str(data_point.id) + str(field_value.id) + field_name

        if str(edge_key) not in added_edges:
            edges.append((data_point.id, field_value.id, field_name, {
                "source_node_id": data_point.id,
                "target_node_id": field_value.id,
                "relationship_name": field_name,
                "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
            }))
            added_edges[str(edge_key)] = True

        if str(field_value.id) in added_nodes:
            continue

async def add_nodes_and_edges(
    data_point,
    field_name,
    field_value,
    nodes,
    edges,
    added_nodes,
    added_edges,
    visited_properties,
):
    property_nodes, property_edges = await get_graph_from_model(
        field_value,
        True,
        added_nodes,
        added_edges,
        visited_properties,
        include_root = True,
        added_nodes = added_nodes,
        added_edges = added_edges,
        visited_properties = visited_properties,
    )

    for node in property_nodes:
        if str(node.id) not in added_nodes:
            nodes.append(node)
            added_nodes[str(node.id)] = True

    for edge in property_edges:
        edge_key = str(edge[0]) + str(edge[1]) + edge[2]

        if str(edge_key) not in added_edges:
            edges.append(edge)
            added_edges[str(edge_key)] = True

    for property_node in get_own_properties(property_nodes, property_edges):
        edge_key = str(data_point.id) + str(property_node.id) + field_name
        property_key = str(data_point.id) + field_name + str(field_value.id)
        visited_properties[property_key] = True

        if str(edge_key) not in added_edges:
            edges.append(
                (
                    data_point.id,
                    property_node.id,
                    field_name,
                    {
                        "source_node_id": data_point.id,
                        "target_node_id": property_node.id,
                        "relationship_name": field_name,
                        "updated_at": datetime.now(timezone.utc).strftime(
                            "%Y-%m-%d %H:%M:%S"
                        ),
                    },
                )
            )
            added_edges[str(edge_key)] = True

    return (nodes, edges)
    return nodes, edges


def get_own_properties(property_nodes, property_edges):
def get_own_property_nodes(property_nodes, property_edges):
    own_properties = []

    destination_nodes = [str(property_edge[1]) for property_edge in property_edges]
@@ -5,7 +5,8 @@ from cognee.shared.data_models import KnowledgeGraph


async def retrieve_existing_edges(
    graph_node_index: list[tuple[DataPoint, KnowledgeGraph]],
    data_chunks: list[DataPoint],
    chunk_graphs: list[KnowledgeGraph],
    graph_engine: GraphDBInterface,
) -> dict[str, bool]:
    processed_nodes = {}

@@ -13,23 +14,25 @@ async def retrieve_existing_edges(
    entity_node_edges = []
    type_entity_edges = []

    for graph_source, graph in graph_node_index:
    for index, data_chunk in enumerate(data_chunks):
        graph = chunk_graphs[index]

        for node in graph.nodes:
            type_node_id = generate_node_id(node.type)
            entity_node_id = generate_node_id(node.id)

            if str(type_node_id) not in processed_nodes:
                type_node_edges.append(
                    (str(graph_source), str(type_node_id), "exists_in")
                    (data_chunk.id, type_node_id, "exists_in")
                )
                processed_nodes[str(type_node_id)] = True

            if str(entity_node_id) not in processed_nodes:
                entity_node_edges.append(
                    (str(graph_source), entity_node_id, "mentioned_in")
                    (data_chunk.id, entity_node_id, "mentioned_in")
                )
                type_entity_edges.append(
                    (str(entity_node_id), str(type_node_id), "is_a")
                    (entity_node_id, type_node_id, "is_a")
                )
                processed_nodes[str(entity_node_id)] = True

@@ -7,7 +7,7 @@ class Repository(DataPoint):
    type: Optional[str] = "Repository"

class CodeFile(DataPoint):
    __tablename__ = "CodeFile"
    __tablename__ = "codefile"
    extracted_id: str # actually file path
    type: Optional[str] = "CodeFile"
    source_code: Optional[str] = None

@@ -21,7 +21,7 @@ class CodeFile(DataPoint):
    }

class CodePart(DataPoint):
    __tablename__ = "CodePart"
    __tablename__ = "codepart"
    # part_of: Optional[CodeFile]
    source_code: str
    type: Optional[str] = "CodePart"
@@ -20,16 +20,16 @@ async def extract_graph_from_data(
        *[extract_content_graph(chunk.text, graph_model) for chunk in data_chunks]
    )
    graph_engine = await get_graph_engine()
    chunk_and_chunk_graphs = [
        (chunk, chunk_graph) for chunk, chunk_graph in zip(data_chunks, chunk_graphs)
    ]

    existing_edges_map = await retrieve_existing_edges(
        chunk_and_chunk_graphs,
        data_chunks,
        chunk_graphs,
        graph_engine,
    )

    graph_nodes, graph_edges = expand_with_nodes_and_edges(
        chunk_and_chunk_graphs,
        data_chunks,
        chunk_graphs,
        existing_edges_map,
    )

@@ -70,7 +70,7 @@ async def node_enrich_and_connect(
    if desc_id in data_points_map:
        desc = data_points_map[desc_id]
    else:
        node_data = await graph_engine.extract_node(desc_id)
        node_data = await graph_engine.extract_node(str(desc_id))
        try:
            desc = convert_node_to_data_point(node_data)
        except Exception:

@@ -87,9 +87,17 @@ async def enrich_dependency_graph(data_points: list[DataPoint]) -> AsyncGenerato
    """Enriches the graph with topological ranks and 'depends_on' edges."""
    nodes = []
    edges = []
    added_nodes = {}
    added_edges = {}
    visited_properties = {}

    for data_point in data_points:
        graph_nodes, graph_edges = await get_graph_from_model(data_point)
        graph_nodes, graph_edges = await get_graph_from_model(
            data_point,
            added_nodes = added_nodes,
            added_edges = added_edges,
            visited_properties = visited_properties,
        )
        nodes.extend(graph_nodes)
        edges.extend(graph_edges)

@@ -11,12 +11,14 @@ async def add_data_points(data_points: list[DataPoint]):

    added_nodes = {}
    added_edges = {}
    visited_properties = {}

    results = await asyncio.gather(*[
        get_graph_from_model(
            data_point,
            added_nodes = added_nodes,
            added_edges = added_edges,
            visited_properties = visited_properties,
        ) for data_point in data_points
    ])

@@ -1,6 +1,5 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.data.processing.document_types import Document
from cognee.modules.chunking.models import DocumentChunk
from cognee.shared.CodeGraphEntities import CodeFile


@@ -4,7 +4,6 @@ from uuid import uuid5
from pydantic import BaseModel
from cognee.modules.data.extraction.extract_summary import extract_summary
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.tasks.storage import add_data_points
from .models import TextSummary

async def summarize_text(data_chunks: list[DocumentChunk], summarization_model: Type[BaseModel]):

@@ -23,6 +22,4 @@ async def summarize_text(data_chunks: list[DocumentChunk], summarization_model:
    ) for (chunk_index, chunk) in enumerate(data_chunks)
    ]

    await add_data_points(summaries)

    return data_chunks
    return summaries
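This is the core of the "notebook pipeline is not saving summaries" fix: summarize_text no longer persists its results internally, it returns the TextSummary data points, and the pipeline persists them with an explicit add_data_points task placed right after it (as in the run_cognify_pipeline tasks list above). A minimal sketch of the intended ordering, assuming the same Task/run_tasks API used elsewhere in this diff:

    tasks = [
        # ... chunking and graph extraction tasks ...
        Task(
            summarize_text,
            summarization_model = cognee_config.summarization_model,
            task_config = { "batch_size": 10 }
        ),
        # summarize_text now returns the summaries, so a storage task must follow
        # it for them to be written out.
        Task(add_data_points, task_config = { "batch_size": 10 }),
    ]
    pipeline = run_tasks(tasks, data_documents, "cognify_pipeline")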
@@ -73,10 +73,13 @@ async def test_circular_reference_extraction():
    nodes = []
    edges = []

    added_nodes = {}
    added_edges = {}

    start = time.perf_counter_ns()

    results = await asyncio.gather(*[
        get_graph_from_model(code_file) for code_file in code_files
        get_graph_from_model(code_file, added_nodes = added_nodes, added_edges = added_edges) for code_file in code_files
    ])

    time_to_run = time.perf_counter_ns() - start

@@ -87,12 +90,6 @@ async def test_circular_reference_extraction():
    nodes.extend(result_nodes)
    edges.extend(result_edges)

    # for code_file in code_files:
    # model_nodes, model_edges = get_graph_from_model(code_file)

    # nodes.extend(model_nodes)
    # edges.extend(model_edges)

    assert len(nodes) == 1501
    assert len(edges) == 1501 * 20 + 1500 * 5

@@ -0,0 +1,69 @@
import asyncio
import random
from typing import List
from uuid import uuid5, NAMESPACE_OID

from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import get_graph_from_model

class Document(DataPoint):
    path: str

class DocumentChunk(DataPoint):
    part_of: Document
    text: str
    contains: List["Entity"] = None

class EntityType(DataPoint):
    name: str

class Entity(DataPoint):
    name: str
    is_type: EntityType

DocumentChunk.model_rebuild()


async def get_graph_from_model_test():
    document = Document(path = "file_path")

    document_chunks = [DocumentChunk(
        id = uuid5(NAMESPACE_OID, f"file{file_index}"),
        text = "some text",
        part_of = document,
        contains = [],
    ) for file_index in range(1)]

    for document_chunk in document_chunks:
        document_chunk.contains.append(Entity(
            name = f"Entity",
            is_type = EntityType(
                name = "Type 1",
            ),
        ))

    nodes = []
    edges = []

    added_nodes = {}
    added_edges = {}
    visited_properties = {}

    results = await asyncio.gather(*[
        get_graph_from_model(
            document_chunk,
            added_nodes = added_nodes,
            added_edges = added_edges,
            visited_properties = visited_properties,
        ) for document_chunk in document_chunks
    ])

    for result_nodes, result_edges in results:
        nodes.extend(result_nodes)
        edges.extend(result_edges)

    assert len(nodes) == 4
    assert len(edges) == 3

if __name__ == "__main__":
    asyncio.run(get_graph_from_model_test())
@@ -64,7 +64,6 @@ async def generate_patch_with_cognee(instance, llm_client, search_type=SearchTyp

    tasks = [
        Task(get_repo_file_dependencies),
        Task(add_data_points, task_config = { "batch_size": 50 }),
        Task(enrich_dependency_graph, task_config = { "batch_size": 50 }),
        Task(expand_dependency_graph, task_config = { "batch_size": 50 }),
        Task(add_data_points, task_config = { "batch_size": 50 }),
@@ -265,7 +265,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "df16431d0f48b006",
"metadata": {
"ExecuteTime": {

@@ -304,7 +304,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "9086abf3af077ab4",
"metadata": {
"ExecuteTime": {

@@ -349,7 +349,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "a9de0cc07f798b7f",
"metadata": {
"ExecuteTime": {

@@ -393,7 +393,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "185ff1c102d06111",
"metadata": {
"ExecuteTime": {

@@ -437,7 +437,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "d55ce4c58f8efb67",
"metadata": {
"ExecuteTime": {

@@ -479,7 +479,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "ca4ecc32721ad332",
"metadata": {
"ExecuteTime": {

@@ -529,7 +529,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "bce39dc6",
"metadata": {},
"outputs": [],

@@ -622,7 +622,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "7c431fdef4921ae0",
"metadata": {
"ExecuteTime": {

@@ -654,13 +654,13 @@
"    Task(classify_documents),\n",
"    Task(check_permissions_on_documents, user = user, permissions = [\"write\"]),\n",
"    Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n",
"    Task(add_data_points, task_config = { \"batch_size\": 10 }),\n",
"    Task(extract_graph_from_data, graph_model = KnowledgeGraph, task_config = { \"batch_size\": 10 }), # Generate knowledge graphs from the document chunks.\n",
"    Task(\n",
"        summarize_text,\n",
"        summarization_model = cognee_config.summarization_model,\n",
"        task_config = { \"batch_size\": 10 }\n",
"    ),\n",
"    Task(add_data_points, task_config = { \"batch_size\": 10 }),\n",
"    ]\n",
"\n",
"    pipeline = run_tasks(tasks, data_documents)\n",

@@ -883,7 +883,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
"version": "3.11.8"
}
},
"nbformat": 4,
@@ -28,10 +28,27 @@ if __name__ == "__main__":
    society = create_organization_recursive(
        "society", "Society", PERSON_NAMES, args.recursive_depth
    )
    nodes, edges = asyncio.run(get_graph_from_model(society))
    added_nodes = {}
    added_edges = {}
    visited_properties = {}
    nodes, edges = asyncio.run(get_graph_from_model(
        society,
        added_nodes = added_nodes,
        added_edges = added_edges,
        visited_properties = visited_properties,
    ))

    def get_graph_from_model_sync(model):
        return asyncio.run(get_graph_from_model(model))
        added_nodes = {}
        added_edges = {}
        visited_properties = {}

        return asyncio.run(get_graph_from_model(
            model,
            added_nodes = added_nodes,
            added_edges = added_edges,
            visited_properties = visited_properties,
        ))

    results = benchmark_function(get_graph_from_model_sync, society, num_runs=args.runs)
    print("\nBenchmark Results:")