Merge remote-tracking branch 'origin/main'

This commit is contained in:
Boris Arzentar 2024-11-23 14:07:44 +01:00
commit a2fa25fb60
15 changed files with 474 additions and 116 deletions

View file

@ -4,7 +4,7 @@
[![GitHub commits](https://badgen.net/github/commits/topoteretes/cognee)](https://GitHub.com/topoteretes/cognee/commit/)
[![Github tag](https://badgen.net/github/tag/topoteretes/cognee)](https://github.com/topoteretes/cognee/tags/)
[![Downloads](https://static.pepy.tech/badge/cognee)](https://pepy.tech/project/cognee)
[![GitHub license](https://badgen.net/github/license/topoteretes/cognee)](https://github.com/topoteretes/cognee/blob/master/LICENSE)
We build for developers who need a reliable, production-ready data layer for AI applications

View file

@ -1,8 +1,16 @@
from datetime import datetime, timezone
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model
def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes = {}, added_edges = {}):
def get_graph_from_model(data_point: DataPoint, added_nodes=None, added_edges=None):
if not added_nodes:
added_nodes = {}
if not added_edges:
added_edges = {}
nodes = []
edges = []
@ -12,87 +20,94 @@ def get_graph_from_model(data_point: DataPoint, include_root = True, added_nodes
for field_name, field_value in data_point:
if field_name == "_metadata":
continue
if isinstance(field_value, DataPoint):
elif isinstance(field_value, DataPoint):
excluded_properties.add(field_name)
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
data_point,
field_name,
field_value,
nodes,
edges,
added_nodes,
added_edges,
)
property_nodes, property_edges = get_graph_from_model(field_value, True, added_nodes, added_edges)
for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True
for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True
for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name
if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
}))
added_edges[str(edge_key)] = True
continue
if isinstance(field_value, list) and len(field_value) > 0 and isinstance(field_value[0], DataPoint):
elif (
isinstance(field_value, list)
and len(field_value) > 0
and isinstance(field_value[0], DataPoint)
):
excluded_properties.add(field_name)
for item in field_value:
property_nodes, property_edges = get_graph_from_model(item, True, added_nodes, added_edges)
for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True
for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[edge_key] = True
for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name
if str(edge_key) not in added_edges:
edges.append((data_point.id, property_node.id, field_name, {
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
"metadata": {
"type": "list"
},
}))
added_edges[edge_key] = True
continue
data_point_properties[field_name] = field_value
n_edges_before = len(edges)
nodes, edges, added_nodes, added_edges = add_nodes_and_edges(
data_point, field_name, item, nodes, edges, added_nodes, added_edges
)
edges = edges[:n_edges_before] + [
(*edge[:3], {**edge[3], "metadata": {"type": "list"}})
for edge in edges[n_edges_before:]
]
else:
data_point_properties[field_name] = field_value
SimpleDataPointModel = copy_model(
type(data_point),
include_fields = {
include_fields={
"_metadata": (dict, data_point._metadata),
},
exclude_fields = excluded_properties,
exclude_fields=excluded_properties,
)
if include_root:
nodes.append(SimpleDataPointModel(**data_point_properties))
nodes.append(SimpleDataPointModel(**data_point_properties))
return nodes, edges
def add_nodes_and_edges(
data_point, field_name, field_value, nodes, edges, added_nodes, added_edges
):
property_nodes, property_edges = get_graph_from_model(
field_value, dict(added_nodes), dict(added_edges)
)
for node in property_nodes:
if str(node.id) not in added_nodes:
nodes.append(node)
added_nodes[str(node.id)] = True
for edge in property_edges:
edge_key = str(edge[0]) + str(edge[1]) + edge[2]
if str(edge_key) not in added_edges:
edges.append(edge)
added_edges[str(edge_key)] = True
for property_node in get_own_properties(property_nodes, property_edges):
edge_key = str(data_point.id) + str(property_node.id) + field_name
if str(edge_key) not in added_edges:
edges.append(
(
data_point.id,
property_node.id,
field_name,
{
"source_node_id": data_point.id,
"target_node_id": property_node.id,
"relationship_name": field_name,
"updated_at": datetime.now(timezone.utc).strftime(
"%Y-%m-%d %H:%M:%S"
),
},
)
)
added_edges[str(edge_key)] = True
return (nodes, edges, added_nodes, added_edges)
def get_own_properties(property_nodes, property_edges):
own_properties = []

View file

@ -1,29 +1,41 @@
from typing import Callable
from pydantic_core import PydanticUndefined
from cognee.infrastructure.engine import DataPoint
from cognee.modules.storage.utils import copy_model
def get_model_instance_from_graph(nodes: list[DataPoint], edges: list, entity_id: str):
node_map = {}
def get_model_instance_from_graph(
nodes: list[DataPoint],
edges: list[tuple[str, str, str, dict[str, str]]],
entity_id: str,
):
node_map = {node.id: node for node in nodes}
for node in nodes:
node_map[node.id] = node
for edge in edges:
source_node = node_map[edge[0]]
target_node = node_map[edge[1]]
edge_label = edge[2]
edge_properties = edge[3] if len(edge) == 4 else {}
for source_node_id, target_node_id, edge_label, edge_properties in edges:
source_node = node_map[source_node_id]
target_node = node_map[target_node_id]
edge_metadata = edge_properties.get("metadata", {})
edge_type = edge_metadata.get("type")
edge_type = edge_metadata.get("type", "default")
if edge_type == "list":
NewModel = copy_model(type(source_node), { edge_label: (list[type(target_node)], PydanticUndefined) })
NewModel = copy_model(
type(source_node),
{edge_label: (list[type(target_node)], PydanticUndefined)},
)
source_node_dict = source_node.model_dump()
source_node_edge_label_values = source_node_dict.get(edge_label, [])
source_node_dict[edge_label] = source_node_edge_label_values + [target_node]
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: [target_node] })
node_map[source_node_id] = NewModel(**source_node_dict)
else:
NewModel = copy_model(type(source_node), { edge_label: (type(target_node), PydanticUndefined) })
NewModel = copy_model(
type(source_node), {edge_label: (type(target_node), PydanticUndefined)}
)
node_map[edge[0]] = NewModel(**source_node.model_dump(), **{ edge_label: target_node })
node_map[target_node_id] = NewModel(
**source_node.model_dump(), **{edge_label: target_node}
)
return node_map[entity_id]

View file

@ -29,7 +29,9 @@ def copy_model(model: DataPoint, include_fields: dict = {}, exclude_fields: list
**include_fields
}
return create_model(model.__name__, **final_fields)
model = create_model(model.__name__, **final_fields)
model.model_rebuild()
return model
def get_own_properties(data_point: DataPoint):
properties = {}

View file

@ -26,7 +26,7 @@ async def main():
await render_graph(None, include_nodes = True, include_labels = True)
search_results = await cognee.search(SearchType.CHUNKS, query = "Student")
search_results = await cognee.search(SearchType.CHUNKS, query_text = "Student")
assert len(search_results) != 0, "The search results list is empty."
print("\n\nExtracted chunks are:\n")
for result in search_results:

View file

@ -1,14 +1,9 @@
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
import pytest
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)
class CarTypeName(Enum):
@ -47,8 +42,8 @@ class Person(DataPoint):
_metadata: dict = dict(index_fields=["name"])
@pytest.fixture(scope="session")
def graph_outputs():
@pytest.fixture(scope="function")
def boris():
boris = Person(
id="boris",
name="Boris",
@ -70,11 +65,4 @@ def graph_outputs():
"expires_on": "2025-11-06",
},
)
nodes, edges = get_graph_from_model(boris)
car, person = nodes[0], nodes[1]
edge = edges[0]
parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
return (car, person, edge, parsed_person)
return boris

View file

@ -0,0 +1,37 @@
import warnings
import pytest
from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
count_society,
create_organization_recursive,
)
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth):
import sys
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth
)
n_organizations, n_persons = count_society(society)
society_counts_total = n_organizations + n_persons
nodes, edges = get_graph_from_model(society)
assert (
len(nodes) == society_counts_total
), f"{society_counts_total = } != {len(nodes) = }, not all DataPoint instances were found"
assert len(edges) == (
len(nodes) - 1
), f"{(len(nodes) - 1) = } != {len(edges) = }, there have to be n_nodes - 1 edges, as each node has exactly one parent node, except for the root node"
else:
warnings.warn(
"The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
)

View file

@ -1,6 +1,19 @@
from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth
EDGE_GROUND_TRUTH = (
CAR_SEDAN_EDGE = (
"car1",
"sedan",
"is_type",
{
"source_node_id": "car1",
"target_node_id": "sedan",
"relationship_name": "is_type",
},
)
BORIS_CAR_EDGE_GROUND_TRUTH = (
"boris",
"car1",
"owns_car",
@ -12,6 +25,8 @@ EDGE_GROUND_TRUTH = (
},
)
CAR_TYPE_GROUND_TRUTH = {"id": "sedan"}
CAR_GROUND_TRUTH = {
"id": "car1",
"brand": "Toyota",
@ -33,22 +48,42 @@ PERSON_GROUND_TRUTH = {
}
def test_extracted_person(graph_outputs):
(_, person, _, _) = graph_outputs
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
def test_extracted_car_type(boris):
nodes, _ = get_graph_from_model(boris)
assert len(nodes) == 3
car_type = nodes[0]
run_test_against_ground_truth("car_type", car_type, CAR_TYPE_GROUND_TRUTH)
def test_extracted_car(graph_outputs):
(car, _, _, _) = graph_outputs
def test_extracted_car(boris):
nodes, _ = get_graph_from_model(boris)
assert len(nodes) == 3
car = nodes[1]
run_test_against_ground_truth("car", car, CAR_GROUND_TRUTH)
def test_extracted_edge(graph_outputs):
(_, _, edge, _) = graph_outputs
def test_extracted_person(boris):
nodes, _ = get_graph_from_model(boris)
assert len(nodes) == 3
person = nodes[2]
run_test_against_ground_truth("person", person, PERSON_GROUND_TRUTH)
def test_extracted_car_sedan_edge(boris):
_, edges = get_graph_from_model(boris)
edge = edges[0]
assert CAR_SEDAN_EDGE[:3] == edge[:3], f"{CAR_SEDAN_EDGE[:3] = } != {edge[:3] = }"
for key, ground_truth in CAR_SEDAN_EDGE[3].items():
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"
def test_extracted_boris_car_edge(boris):
_, edges = get_graph_from_model(boris)
edge = edges[1]
assert (
EDGE_GROUND_TRUTH[:3] == edge[:3]
), f"{EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }"
for key, ground_truth in EDGE_GROUND_TRUTH[3].items():
BORIS_CAR_EDGE_GROUND_TRUTH[:3] == edge[:3]
), f"{BORIS_CAR_EDGE_GROUND_TRUTH[:3] = } != {edge[:3] = }"
for key, ground_truth in BORIS_CAR_EDGE_GROUND_TRUTH[3].items():
assert ground_truth == edge[3][key], f"{ground_truth = } != {edge[3][key] = }"

View file

@ -0,0 +1,33 @@
import warnings
import pytest
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
create_organization_recursive,
show_first_difference,
)
@pytest.mark.parametrize("recursive_depth", [1, 2, 3])
def test_society_nodes_and_edges(recursive_depth):
import sys
if sys.version_info[0] == 3 and sys.version_info[1] >= 11:
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, recursive_depth
)
nodes, edges = get_graph_from_model(society)
parsed_society = get_model_instance_from_graph(nodes, edges, "society")
assert str(society) == (str(parsed_society)), show_first_difference(
str(society), str(parsed_society), "society", "parsed_society"
)
else:
warnings.warn(
"The recursive pydantic data structure cannot be reconstructed from the graph because the 'inner' pydantic class is not defined. Hence this test is skipped. This problem is solved in Python 3.11"
)

View file

@ -1,3 +1,7 @@
from cognee.modules.graph.utils import (
get_graph_from_model,
get_model_instance_from_graph,
)
from cognee.tests.unit.interfaces.graph.util import run_test_against_ground_truth
PARSED_PERSON_GROUND_TRUTH = {
@ -21,8 +25,10 @@ CAR_GROUND_TRUTH = {
}
def test_parsed_person(graph_outputs):
(_, _, _, parsed_person) = graph_outputs
def test_parsed_person(boris):
nodes, edges = get_graph_from_model(boris)
parsed_person = get_model_instance_from_graph(nodes, edges, "boris")
run_test_against_ground_truth(
"parsed_person", parsed_person, PARSED_PERSON_GROUND_TRUTH
)

View file

@ -1,5 +1,9 @@
import random
import string
from datetime import datetime, timezone
from typing import Any, Dict
from typing import Any, Dict, Optional
from cognee.infrastructure.engine import DataPoint
def run_test_against_ground_truth(
@ -21,6 +25,8 @@ def run_test_against_ground_truth(
assert (
ground_truth2 == getattr(test_target_item, key)[key2]
), f"{test_target_item_name}/{key = }/{key2 = }: {ground_truth2 = } != {getattr(test_target_item, key)[key2] = }"
elif isinstance(ground_truth, list):
raise NotImplementedError("Currently not implemented for 'list'")
else:
assert ground_truth == getattr(
test_target_item, key
@ -28,3 +34,117 @@ def run_test_against_ground_truth(
time_delta = datetime.now(timezone.utc) - getattr(test_target_item, "updated_at")
assert time_delta.total_seconds() < 60, f"{ time_delta.total_seconds() = }"
class Organization(DataPoint):
id: str
name: str
members: Optional[list["SocietyPerson"]]
class SocietyPerson(DataPoint):
id: str
name: str
memberships: Optional[list[Organization]]
SocietyPerson.model_rebuild()
Organization.model_rebuild()
ORGANIZATION_NAMES = [
"ChessClub",
"RowingClub",
"TheatreTroupe",
"PoliticalParty",
"Charity",
"FanClub",
"FilmClub",
"NeighborhoodGroup",
"LocalCouncil",
"Band",
]
PERSON_NAMES = ["Sarah", "Anna", "John", "Sam"]
def create_society_person_recursive(id, name, organization_names, max_depth, depth=0):
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
if depth < max_depth:
memberships = [
create_organization_recursive(
f"{org_name}-{depth}-{id_suffix}",
org_name.lower(),
PERSON_NAMES,
max_depth,
depth + 1,
)
for org_name in organization_names
]
else:
memberships = None
return SocietyPerson(id=id, name=f"{name}{depth}", memberships=memberships)
def create_organization_recursive(id, name, member_names, max_depth, depth=0):
id_suffix = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
if depth < max_depth:
members = [
create_society_person_recursive(
f"{member_name}-{depth}-{id_suffix}",
member_name.lower(),
ORGANIZATION_NAMES,
max_depth,
depth + 1,
)
for member_name in member_names
]
else:
members = None
return Organization(id=id, name=f"{name}{depth}", members=members)
def count_society(obj):
if isinstance(obj, SocietyPerson):
if obj.memberships is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.memberships]
)
organization_count = sum(organization_counts)
society_person_count = sum(society_person_counts) + 1
return (organization_count, society_person_count)
else:
return (0, 1)
if isinstance(obj, Organization):
if obj.members is not None:
organization_counts, society_person_counts = zip(
*[count_society(organization) for organization in obj.members]
)
organization_count = sum(organization_counts) + 1
society_person_count = sum(society_person_counts)
return (organization_count, society_person_count)
else:
return (1, 0)
else:
raise Exception("Not allowed")
def show_first_difference(str1, str2, str1_name, str2_name, context=30):
for i, (c1, c2) in enumerate(zip(str1, str2)):
if c1 != c2:
start = max(0, i - context)
end1 = min(len(str1), i + context + 1)
end2 = min(len(str2), i + context + 1)
if i > 0:
return f"identical: '{str1[start:i-1]}' | {str1_name}: '{str1[i-1:end1]}'... != {str2_name}: '{str2[i-1:end2]}'..."
else:
return f"{str1_name} and {str2_name} have no overlap in characters"
if len(str1) > len(str2):
return f"{str2_name} is identical up to the {i}th character, missing afterwards '{str1[i:i+context]}'..."
if len(str2) > len(str1):
return f"{str1_name} is identical up to the {i}th character, missing afterwards '{str2[i:i+context]}'..."
else:
return f"{str1_name} and {str2_name} are identical."

View file

@ -59,7 +59,7 @@ await cognee.add(text) # Add a new piece of information
await cognee.cognify() # Use LLMs and cognee to create knowledge
search_results = await cognee.search("INSIGHTS", {'query': 'Tell me about NLP'}) # Query cognee for the knowledge
search_results = await cognee.search(SearchType.INSIGHTS, query_text='Tell me about NLP') # Query cognee for the knowledge
for result_text in search_results:
print(result_text)

View file

@ -209,7 +209,7 @@ async def main(enable_steps):
if enable_steps.get("search_insights"):
search_results = await cognee.search(
SearchType.INSIGHTS,
{'query': 'Which applicant has the most relevant experience in data science?'}
query_text='Which applicant has the most relevant experience in data science?'
)
print("Search results:")
for result_text in search_results:

View file

@ -0,0 +1,67 @@
import statistics
import time
import tracemalloc
from typing import Any, Callable, Dict
import psutil
def benchmark_function(func: Callable, *args, num_runs: int = 5) -> Dict[str, Any]:
"""
Benchmark a function for memory usage and computational performance.
Args:
func: Function to benchmark
*args: Arguments to pass to the function
num_runs: Number of times to run the benchmark
Returns:
Dictionary containing benchmark metrics
"""
execution_times = []
peak_memory_usages = []
cpu_percentages = []
process = psutil.Process()
for _ in range(num_runs):
# Start memory tracking
tracemalloc.start()
initial_memory = process.memory_info().rss
# Measure execution time and CPU usage
start_time = time.perf_counter()
start_cpu_time = process.cpu_times()
result = func(*args)
end_cpu_time = process.cpu_times()
end_time = time.perf_counter()
# Calculate metrics
execution_time = end_time - start_time
cpu_time = (end_cpu_time.user + end_cpu_time.system) - (
start_cpu_time.user + start_cpu_time.system
)
current, peak = tracemalloc.get_traced_memory()
final_memory = process.memory_info().rss
memory_used = final_memory - initial_memory
# Store results
execution_times.append(execution_time)
peak_memory_usages.append(peak / 1024 / 1024) # Convert to MB
cpu_percentages.append((cpu_time / execution_time) * 100)
tracemalloc.stop()
analysis = {
"mean_execution_time": statistics.mean(execution_times),
"mean_peak_memory_mb": statistics.mean(peak_memory_usages),
"mean_cpu_percent": statistics.mean(cpu_percentages),
"num_runs": num_runs,
}
if num_runs > 1:
analysis["std_execution_time"] = statistics.stdev(execution_times)
return analysis

View file

@ -0,0 +1,43 @@
import argparse
import time
from benchmark_function import benchmark_function
from cognee.modules.graph.utils import get_graph_from_model
from cognee.tests.unit.interfaces.graph.util import (
PERSON_NAMES,
create_organization_recursive,
)
# Example usage:
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark graph model with configurable recursive depth"
)
parser.add_argument(
"--recursive-depth",
type=int,
default=3,
help="Recursive depth for graph generation (default: 3)",
)
parser.add_argument(
"--runs", type=int, default=5, help="Number of benchmark runs (default: 5)"
)
args = parser.parse_args()
society = create_organization_recursive(
"society", "Society", PERSON_NAMES, args.recursive_depth
)
nodes, edges = get_graph_from_model(society)
results = benchmark_function(get_graph_from_model, society, num_runs=args.runs)
print("\nBenchmark Results:")
print(
f"N nodes: {len(nodes)}, N edges: {len(edges)}, Recursion depth: {args.recursive_depth}"
)
print(f"Mean Peak Memory: {results['mean_peak_memory_mb']:.2f} MB")
print(f"Mean CPU Usage: {results['mean_cpu_percent']:.2f}%")
print(f"Mean Execution Time: {results['mean_execution_time']:.4f} seconds")
if "std_execution_time" in results:
print(f"Execution Time Std: {results['std_execution_time']:.4f} seconds")