fix: unit tests and ruff lint errors

Boris Arzentar 2025-07-02 11:39:10 +02:00
parent 86bd3e4a5a
commit 611df1e9b9
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
8 changed files with 44 additions and 27 deletions


@@ -23,4 +23,4 @@ runs:
         pip install poetry
     - name: Install dependencies
       shell: bash
-      run: poetry install --no-interaction -E api -E docs -E evals -E gemini -E codegraph -E ollama -E dev
+      run: poetry install --no-interaction -E api -E docs -E evals -E gemini -E codegraph -E ollama -E dev -E neo4j


@@ -82,13 +82,15 @@ def record_graph_changes(func):
         for node in nodes:
             node_id = UUID(str(node.id))
-            relationship_ledgers.append(GraphRelationshipLedger(
-                id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
-                source_node_id=node_id,
-                destination_node_id=node_id,
-                creator_function=f"{creator}.node",
-                node_label=getattr(node, "name", None) or str(node.id),
-            ))
+            relationship_ledgers.append(
+                GraphRelationshipLedger(
+                    id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
+                    source_node_id=node_id,
+                    destination_node_id=node_id,
+                    creator_function=f"{creator}.node",
+                    node_label=getattr(node, "name", None) or str(node.id),
+                )
+            )
         try:
             session.add_all(relationship_ledgers)
@@ -106,12 +108,14 @@ def record_graph_changes(func):
             source_id = UUID(str(edge[0]))
             target_id = UUID(str(edge[1]))
             rel_type = str(edge[2])
-            relationship_ledgers.append(GraphRelationshipLedger(
-                id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
-                source_node_id=source_id,
-                destination_node_id=target_id,
-                creator_function=f"{creator}.{rel_type}",
-            ))
+            relationship_ledgers.append(
+                GraphRelationshipLedger(
+                    id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
+                    source_node_id=source_id,
+                    destination_node_id=target_id,
+                    creator_function=f"{creator}.{rel_type}",
+                )
+            )
         try:
             session.add_all(relationship_ledgers)
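
Both hunks above are formatting-only: the nested constructor is broken out of the append() call, one call per line, with no change in behavior. As a standalone sketch (same stdlib names as in the hunks, everything else illustrative), the ledger id is a uuid5 derived from the current UTC timestamp string:

from datetime import datetime, timezone
from uuid import NAMESPACE_OID, uuid5

# uuid5 is deterministic: a given namespace and name string always map to the same UUID,
# so the id above varies only with the timestamp string captured at append time.
ledger_id = uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}")
print(ledger_id)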


@@ -107,7 +107,7 @@ class GenericAPIAdapter(LLMInterface):
         ) as error:
             if (
                 isinstance(error, InstructorRetryException)
-                and not "content management policy" in str(error).lower()
+                and "content management policy" not in str(error).lower()
             ):
                 raise error
@@ -141,7 +141,7 @@ class GenericAPIAdapter(LLMInterface):
         ) as error:
             if (
                 isinstance(error, InstructorRetryException)
-                and not "content management policy" in str(error).lower()
+                and "content management policy" not in str(error).lower()
             ):
                 raise error
             else:


@@ -134,7 +134,7 @@ class OpenAIAdapter(LLMInterface):
         ) as error:
             if (
                 isinstance(error, InstructorRetryException)
-                and not "content management policy" in str(error).lower()
+                and "content management policy" not in str(error).lower()
             ):
                 raise error
@@ -168,7 +168,7 @@ class OpenAIAdapter(LLMInterface):
         ) as error:
             if (
                 isinstance(error, InstructorRetryException)
-                and not "content management policy" in str(error).lower()
+                and "content management policy" not in str(error).lower()
             ):
                 raise error
             else:
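
The four adapter hunks above apply the same fix, presumably for ruff's E713 (not-in-test): negate membership with the not in operator instead of negating the whole expression. A standalone illustration of the equivalence, unrelated to the adapters themselves:

message = "request blocked by content management policy"

# Old style, flagged by ruff E713: `not x in y`
old_style = not "content management policy" in message
# Fixed style: `x not in y` reads as a single operator and evaluates identically
new_style = "content management policy" not in message

assert old_style == new_style  # the rewrite never changes the result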


@@ -68,7 +68,7 @@ async def get_graph_from_model_test():
     assert len(nodes) == 4
     assert len(edges) == 3
-    document_chunk_node = next(filter(lambda node: node.type is "DocumentChunk", nodes))
+    document_chunk_node = next(filter(lambda node: node.type == "DocumentChunk", nodes))
     assert not hasattr(document_chunk_node, "part_of"), "Expected part_of attribute to be removed"
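
The test fix above addresses the is-literal pattern (ruff F632): is compares object identity rather than value, and comparing against a string literal emits a SyntaxWarning on Python 3.8+. A standalone sketch of the difference (names are illustrative):

node_type = "".join(["Document", "Chunk"])  # built at runtime, so not interned

print(node_type == "DocumentChunk")  # True: compares the string values
print(node_type is "DocumentChunk")  # False here, with a SyntaxWarning: identity depends on interning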


@ -1,6 +1,7 @@
# import json
# import asyncio
from pympler import asizeof
# from cognee.modules.storage.utils import JSONEncoder
from distributed.queues import save_data_points_queue
# from cognee.modules.graph.utils import get_graph_from_model
@@ -62,6 +63,7 @@ async def save_data_points(data_points_and_relationships: tuple[list, list]):
     # graph_data_deduplication.reset()
 class GraphDataDeduplication:
     nodes_and_edges_map: dict
@@ -93,8 +95,11 @@ class GraphDataDeduplication:
 def try_pushing_nodes_to_queue(node_batch):
     try:
         save_data_points_queue.put((node_batch, []))
-    except Exception as e:
-        first_half, second_half = node_batch[:len(node_batch) // 2], node_batch[len(node_batch) // 2:]
+    except Exception:
+        first_half, second_half = (
+            node_batch[: len(node_batch) // 2],
+            node_batch[len(node_batch) // 2 :],
+        )
         save_data_points_queue.put((first_half, []))
         save_data_points_queue.put((second_half, []))
@@ -102,7 +107,10 @@ def try_pushing_nodes_to_queue(node_batch):
 def try_pushing_edges_to_queue(edge_batch):
     try:
         save_data_points_queue.put(([], edge_batch))
-    except Exception as e:
-        first_half, second_half = edge_batch[:len(edge_batch) // 2], edge_batch[len(edge_batch) // 2:]
+    except Exception:
+        first_half, second_half = (
+            edge_batch[: len(edge_batch) // 2],
+            edge_batch[len(edge_batch) // 2 :],
+        )
         save_data_points_queue.put(([], first_half))
         save_data_points_queue.put(([], second_half))
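
Two separate lint findings appear to be resolved in the two hunks above: the unused exception binding is dropped (ruff F841, except Exception as e where e is never used), and the slices are rewrapped with spaces around the colon because their bounds are expressions. The halving itself is unchanged; a standalone sketch of what it does:

batch = list(range(10))

# Split the batch into two halves, mirroring the fallback used above when a
# single queue.put of the whole batch fails.
first_half, second_half = batch[: len(batch) // 2], batch[len(batch) // 2 :]

assert first_half + second_half == batch
print(first_half, second_half)  # [0, 1, 2, 3, 4] [5, 6, 7, 8, 9]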


@@ -52,7 +52,10 @@ async def summarize_text(
     )
     graph_data_deduplication = GraphDataDeduplication()
-    deduplicated_nodes_and_edges = [graph_data_deduplication.deduplicate_nodes_and_edges(nodes, edges + relationships) for nodes, edges in nodes_and_edges]
+    deduplicated_nodes_and_edges = [
+        graph_data_deduplication.deduplicate_nodes_and_edges(nodes, edges + relationships)
+        for nodes, edges in nodes_and_edges
+    ]
     return deduplicated_nodes_and_edges


@@ -21,7 +21,9 @@ async def data_point_saver_worker():
             return True
         if len(nodes_and_edges) == 2:
-            print(f"Processing {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges.")
+            print(
+                f"Processing {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
+            )
             nodes = nodes_and_edges[0]
             edges = nodes_and_edges[1]
@@ -30,8 +32,8 @@ async def data_point_saver_worker():
             if edges:
                 await graph_engine.add_edges(edges)
-            print(f"Finished processing nodes and edges.")
+            print("Finished processing nodes and edges.")
         else:
-            print(f"No jobs, go to sleep.")
+            print("No jobs, go to sleep.")
             await asyncio.sleep(5)
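
The last two replacements fix ruff F541 (f-string without any placeholders): the f prefix is dropped when the string interpolates nothing. Illustration only:

jobs = 0

print(f"Processing {jobs} jobs.")  # placeholder present, f-string is warranted
print(f"No jobs, go to sleep.")    # F541: no placeholders, the f prefix is inert
print("No jobs, go to sleep.")     # equivalent plain string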