fix: unit tests and ruff lint errors
This commit is contained in:
parent
86bd3e4a5a
commit
611df1e9b9
8 changed files with 44 additions and 27 deletions
2
.github/actions/cognee_setup/action.yml
vendored
2
.github/actions/cognee_setup/action.yml
vendored
|
|
@ -23,4 +23,4 @@ runs:
|
||||||
pip install poetry
|
pip install poetry
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
shell: bash
|
shell: bash
|
||||||
run: poetry install --no-interaction -E api -E docs -E evals -E gemini -E codegraph -E ollama -E dev
|
run: poetry install --no-interaction -E api -E docs -E evals -E gemini -E codegraph -E ollama -E dev -E neo4j
|
||||||
|
|
|
||||||
|
|
@ -82,13 +82,15 @@ def record_graph_changes(func):
|
||||||
|
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
node_id = UUID(str(node.id))
|
node_id = UUID(str(node.id))
|
||||||
relationship_ledgers.append(GraphRelationshipLedger(
|
relationship_ledgers.append(
|
||||||
id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
|
GraphRelationshipLedger(
|
||||||
source_node_id=node_id,
|
id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
|
||||||
destination_node_id=node_id,
|
source_node_id=node_id,
|
||||||
creator_function=f"{creator}.node",
|
destination_node_id=node_id,
|
||||||
node_label=getattr(node, "name", None) or str(node.id),
|
creator_function=f"{creator}.node",
|
||||||
))
|
node_label=getattr(node, "name", None) or str(node.id),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
session.add_all(relationship_ledgers)
|
session.add_all(relationship_ledgers)
|
||||||
|
|
@ -106,12 +108,14 @@ def record_graph_changes(func):
|
||||||
source_id = UUID(str(edge[0]))
|
source_id = UUID(str(edge[0]))
|
||||||
target_id = UUID(str(edge[1]))
|
target_id = UUID(str(edge[1]))
|
||||||
rel_type = str(edge[2])
|
rel_type = str(edge[2])
|
||||||
relationship_ledgers.append(GraphRelationshipLedger(
|
relationship_ledgers.append(
|
||||||
id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
|
GraphRelationshipLedger(
|
||||||
source_node_id=source_id,
|
id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
|
||||||
destination_node_id=target_id,
|
source_node_id=source_id,
|
||||||
creator_function=f"{creator}.{rel_type}",
|
destination_node_id=target_id,
|
||||||
))
|
creator_function=f"{creator}.{rel_type}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
session.add_all(relationship_ledgers)
|
session.add_all(relationship_ledgers)
|
||||||
|
|
|
||||||
|
|
@ -107,7 +107,7 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
) as error:
|
) as error:
|
||||||
if (
|
if (
|
||||||
isinstance(error, InstructorRetryException)
|
isinstance(error, InstructorRetryException)
|
||||||
and not "content management policy" in str(error).lower()
|
and "content management policy" not in str(error).lower()
|
||||||
):
|
):
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
|
|
@ -141,7 +141,7 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
) as error:
|
) as error:
|
||||||
if (
|
if (
|
||||||
isinstance(error, InstructorRetryException)
|
isinstance(error, InstructorRetryException)
|
||||||
and not "content management policy" in str(error).lower()
|
and "content management policy" not in str(error).lower()
|
||||||
):
|
):
|
||||||
raise error
|
raise error
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -134,7 +134,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
) as error:
|
) as error:
|
||||||
if (
|
if (
|
||||||
isinstance(error, InstructorRetryException)
|
isinstance(error, InstructorRetryException)
|
||||||
and not "content management policy" in str(error).lower()
|
and "content management policy" not in str(error).lower()
|
||||||
):
|
):
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
|
|
@ -168,7 +168,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
) as error:
|
) as error:
|
||||||
if (
|
if (
|
||||||
isinstance(error, InstructorRetryException)
|
isinstance(error, InstructorRetryException)
|
||||||
and not "content management policy" in str(error).lower()
|
and "content management policy" not in str(error).lower()
|
||||||
):
|
):
|
||||||
raise error
|
raise error
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ async def get_graph_from_model_test():
|
||||||
assert len(nodes) == 4
|
assert len(nodes) == 4
|
||||||
assert len(edges) == 3
|
assert len(edges) == 3
|
||||||
|
|
||||||
document_chunk_node = next(filter(lambda node: node.type is "DocumentChunk", nodes))
|
document_chunk_node = next(filter(lambda node: node.type == "DocumentChunk", nodes))
|
||||||
assert not hasattr(document_chunk_node, "part_of"), "Expected part_of attribute to be removed"
|
assert not hasattr(document_chunk_node, "part_of"), "Expected part_of attribute to be removed"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
# import json
|
# import json
|
||||||
# import asyncio
|
# import asyncio
|
||||||
from pympler import asizeof
|
from pympler import asizeof
|
||||||
|
|
||||||
# from cognee.modules.storage.utils import JSONEncoder
|
# from cognee.modules.storage.utils import JSONEncoder
|
||||||
from distributed.queues import save_data_points_queue
|
from distributed.queues import save_data_points_queue
|
||||||
# from cognee.modules.graph.utils import get_graph_from_model
|
# from cognee.modules.graph.utils import get_graph_from_model
|
||||||
|
|
@ -62,6 +63,7 @@ async def save_data_points(data_points_and_relationships: tuple[list, list]):
|
||||||
|
|
||||||
# graph_data_deduplication.reset()
|
# graph_data_deduplication.reset()
|
||||||
|
|
||||||
|
|
||||||
class GraphDataDeduplication:
|
class GraphDataDeduplication:
|
||||||
nodes_and_edges_map: dict
|
nodes_and_edges_map: dict
|
||||||
|
|
||||||
|
|
@ -93,8 +95,11 @@ class GraphDataDeduplication:
|
||||||
def try_pushing_nodes_to_queue(node_batch):
|
def try_pushing_nodes_to_queue(node_batch):
|
||||||
try:
|
try:
|
||||||
save_data_points_queue.put((node_batch, []))
|
save_data_points_queue.put((node_batch, []))
|
||||||
except Exception as e:
|
except Exception:
|
||||||
first_half, second_half = node_batch[:len(node_batch) // 2], node_batch[len(node_batch) // 2:]
|
first_half, second_half = (
|
||||||
|
node_batch[: len(node_batch) // 2],
|
||||||
|
node_batch[len(node_batch) // 2 :],
|
||||||
|
)
|
||||||
save_data_points_queue.put((first_half, []))
|
save_data_points_queue.put((first_half, []))
|
||||||
save_data_points_queue.put((second_half, []))
|
save_data_points_queue.put((second_half, []))
|
||||||
|
|
||||||
|
|
@ -102,7 +107,10 @@ def try_pushing_nodes_to_queue(node_batch):
|
||||||
def try_pushing_edges_to_queue(edge_batch):
|
def try_pushing_edges_to_queue(edge_batch):
|
||||||
try:
|
try:
|
||||||
save_data_points_queue.put(([], edge_batch))
|
save_data_points_queue.put(([], edge_batch))
|
||||||
except Exception as e:
|
except Exception:
|
||||||
first_half, second_half = edge_batch[:len(edge_batch) // 2], edge_batch[len(edge_batch) // 2:]
|
first_half, second_half = (
|
||||||
|
edge_batch[: len(edge_batch) // 2],
|
||||||
|
edge_batch[len(edge_batch) // 2 :],
|
||||||
|
)
|
||||||
save_data_points_queue.put(([], first_half))
|
save_data_points_queue.put(([], first_half))
|
||||||
save_data_points_queue.put(([], second_half))
|
save_data_points_queue.put(([], second_half))
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,10 @@ async def summarize_text(
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_data_deduplication = GraphDataDeduplication()
|
graph_data_deduplication = GraphDataDeduplication()
|
||||||
deduplicated_nodes_and_edges = [graph_data_deduplication.deduplicate_nodes_and_edges(nodes, edges + relationships) for nodes, edges in nodes_and_edges]
|
deduplicated_nodes_and_edges = [
|
||||||
|
graph_data_deduplication.deduplicate_nodes_and_edges(nodes, edges + relationships)
|
||||||
|
for nodes, edges in nodes_and_edges
|
||||||
|
]
|
||||||
|
|
||||||
return deduplicated_nodes_and_edges
|
return deduplicated_nodes_and_edges
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,9 @@ async def data_point_saver_worker():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if len(nodes_and_edges) == 2:
|
if len(nodes_and_edges) == 2:
|
||||||
print(f"Processing {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges.")
|
print(
|
||||||
|
f"Processing {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
|
||||||
|
)
|
||||||
nodes = nodes_and_edges[0]
|
nodes = nodes_and_edges[0]
|
||||||
edges = nodes_and_edges[1]
|
edges = nodes_and_edges[1]
|
||||||
|
|
||||||
|
|
@ -30,8 +32,8 @@ async def data_point_saver_worker():
|
||||||
|
|
||||||
if edges:
|
if edges:
|
||||||
await graph_engine.add_edges(edges)
|
await graph_engine.add_edges(edges)
|
||||||
print(f"Finished processing nodes and edges.")
|
print("Finished processing nodes and edges.")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(f"No jobs, go to sleep.")
|
print("No jobs, go to sleep.")
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue