From 611df1e9b9b44b86980b6150848e968f92f33f48 Mon Sep 17 00:00:00 2001 From: Boris Arzentar Date: Wed, 2 Jul 2025 11:39:10 +0200 Subject: [PATCH] fix: unit tests and ruff lint errors --- .github/actions/cognee_setup/action.yml | 2 +- .../databases/graph/graph_db_interface.py | 30 +++++++++++-------- .../llm/generic_llm_api/adapter.py | 4 +-- cognee/infrastructure/llm/openai/adapter.py | 4 +-- .../graph/get_graph_from_model_flat_test.py | 2 +- distributed/tasks/save_data_points.py | 16 +++++++--- distributed/tasks/summarize_text.py | 5 +++- .../workers/data_point_saver_worker.py | 8 +++-- 8 files changed, 44 insertions(+), 27 deletions(-) diff --git a/.github/actions/cognee_setup/action.yml b/.github/actions/cognee_setup/action.yml index 3b122b7f9..17ce7f413 100644 --- a/.github/actions/cognee_setup/action.yml +++ b/.github/actions/cognee_setup/action.yml @@ -23,4 +23,4 @@ runs: pip install poetry - name: Install dependencies shell: bash - run: poetry install --no-interaction -E api -E docs -E evals -E gemini -E codegraph -E ollama -E dev + run: poetry install --no-interaction -E api -E docs -E evals -E gemini -E codegraph -E ollama -E dev -E neo4j diff --git a/cognee/infrastructure/databases/graph/graph_db_interface.py b/cognee/infrastructure/databases/graph/graph_db_interface.py index cc11dde67..16600b386 100644 --- a/cognee/infrastructure/databases/graph/graph_db_interface.py +++ b/cognee/infrastructure/databases/graph/graph_db_interface.py @@ -82,13 +82,15 @@ def record_graph_changes(func): for node in nodes: node_id = UUID(str(node.id)) - relationship_ledgers.append(GraphRelationshipLedger( - id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"), - source_node_id=node_id, - destination_node_id=node_id, - creator_function=f"{creator}.node", - node_label=getattr(node, "name", None) or str(node.id), - )) + relationship_ledgers.append( + GraphRelationshipLedger( + id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"), + source_node_id=node_id, + destination_node_id=node_id, + creator_function=f"{creator}.node", + node_label=getattr(node, "name", None) or str(node.id), + ) + ) try: session.add_all(relationship_ledgers) @@ -106,12 +108,14 @@ def record_graph_changes(func): source_id = UUID(str(edge[0])) target_id = UUID(str(edge[1])) rel_type = str(edge[2]) - relationship_ledgers.append(GraphRelationshipLedger( - id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"), - source_node_id=source_id, - destination_node_id=target_id, - creator_function=f"{creator}.{rel_type}", - )) + relationship_ledgers.append( + GraphRelationshipLedger( + id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"), + source_node_id=source_id, + destination_node_id=target_id, + creator_function=f"{creator}.{rel_type}", + ) + ) try: session.add_all(relationship_ledgers) diff --git a/cognee/infrastructure/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/generic_llm_api/adapter.py index a3ff3c1cf..3d423da6e 100644 --- a/cognee/infrastructure/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/generic_llm_api/adapter.py @@ -107,7 +107,7 @@ class GenericAPIAdapter(LLMInterface): ) as error: if ( isinstance(error, InstructorRetryException) - and not "content management policy" in str(error).lower() + and "content management policy" not in str(error).lower() ): raise error @@ -141,7 +141,7 @@ class GenericAPIAdapter(LLMInterface): ) as error: if ( isinstance(error, InstructorRetryException) - and not "content management policy" in str(error).lower() + and "content management policy" not in str(error).lower() ): raise error else: diff --git a/cognee/infrastructure/llm/openai/adapter.py b/cognee/infrastructure/llm/openai/adapter.py index f366c8a67..f0667a696 100644 --- a/cognee/infrastructure/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/openai/adapter.py @@ -134,7 +134,7 @@ class OpenAIAdapter(LLMInterface): ) as error: if ( isinstance(error, InstructorRetryException) - and not "content management policy" in str(error).lower() + and "content management policy" not in str(error).lower() ): raise error @@ -168,7 +168,7 @@ class OpenAIAdapter(LLMInterface): ) as error: if ( isinstance(error, InstructorRetryException) - and not "content management policy" in str(error).lower() + and "content management policy" not in str(error).lower() ): raise error else: diff --git a/cognee/tests/unit/interfaces/graph/get_graph_from_model_flat_test.py b/cognee/tests/unit/interfaces/graph/get_graph_from_model_flat_test.py index 7b32cde3b..8a1f7b5b8 100644 --- a/cognee/tests/unit/interfaces/graph/get_graph_from_model_flat_test.py +++ b/cognee/tests/unit/interfaces/graph/get_graph_from_model_flat_test.py @@ -68,7 +68,7 @@ async def get_graph_from_model_test(): assert len(nodes) == 4 assert len(edges) == 3 - document_chunk_node = next(filter(lambda node: node.type is "DocumentChunk", nodes)) + document_chunk_node = next(filter(lambda node: node.type == "DocumentChunk", nodes)) assert not hasattr(document_chunk_node, "part_of"), "Expected part_of attribute to be removed" diff --git a/distributed/tasks/save_data_points.py b/distributed/tasks/save_data_points.py index 51b2ee12d..23039f6b6 100644 --- a/distributed/tasks/save_data_points.py +++ b/distributed/tasks/save_data_points.py @@ -1,6 +1,7 @@ # import json # import asyncio from pympler import asizeof + # from cognee.modules.storage.utils import JSONEncoder from distributed.queues import save_data_points_queue # from cognee.modules.graph.utils import get_graph_from_model @@ -62,6 +63,7 @@ async def save_data_points(data_points_and_relationships: tuple[list, list]): # graph_data_deduplication.reset() + class GraphDataDeduplication: nodes_and_edges_map: dict @@ -93,8 +95,11 @@ class GraphDataDeduplication: def try_pushing_nodes_to_queue(node_batch): try: save_data_points_queue.put((node_batch, [])) - except Exception as e: - first_half, second_half = node_batch[:len(node_batch) // 2], node_batch[len(node_batch) // 2:] + except Exception: + first_half, second_half = ( + node_batch[: len(node_batch) // 2], + node_batch[len(node_batch) // 2 :], + ) save_data_points_queue.put((first_half, [])) save_data_points_queue.put((second_half, [])) @@ -102,7 +107,10 @@ def try_pushing_nodes_to_queue(node_batch): def try_pushing_edges_to_queue(edge_batch): try: save_data_points_queue.put(([], edge_batch)) - except Exception as e: - first_half, second_half = edge_batch[:len(edge_batch) // 2], edge_batch[len(edge_batch) // 2:] + except Exception: + first_half, second_half = ( + edge_batch[: len(edge_batch) // 2], + edge_batch[len(edge_batch) // 2 :], + ) save_data_points_queue.put(([], first_half)) save_data_points_queue.put(([], second_half)) diff --git a/distributed/tasks/summarize_text.py b/distributed/tasks/summarize_text.py index 09da9a750..72ad36e49 100644 --- a/distributed/tasks/summarize_text.py +++ b/distributed/tasks/summarize_text.py @@ -52,7 +52,10 @@ async def summarize_text( ) graph_data_deduplication = GraphDataDeduplication() - deduplicated_nodes_and_edges = [graph_data_deduplication.deduplicate_nodes_and_edges(nodes, edges + relationships) for nodes, edges in nodes_and_edges] + deduplicated_nodes_and_edges = [ + graph_data_deduplication.deduplicate_nodes_and_edges(nodes, edges + relationships) + for nodes, edges in nodes_and_edges + ] return deduplicated_nodes_and_edges diff --git a/distributed/workers/data_point_saver_worker.py b/distributed/workers/data_point_saver_worker.py index 85b06d4b0..bf0035c2a 100644 --- a/distributed/workers/data_point_saver_worker.py +++ b/distributed/workers/data_point_saver_worker.py @@ -21,7 +21,9 @@ async def data_point_saver_worker(): return True if len(nodes_and_edges) == 2: - print(f"Processing {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges.") + print( + f"Processing {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges." + ) nodes = nodes_and_edges[0] edges = nodes_and_edges[1] @@ -30,8 +32,8 @@ async def data_point_saver_worker(): if edges: await graph_engine.add_edges(edges) - print(f"Finished processing nodes and edges.") + print("Finished processing nodes and edges.") else: - print(f"No jobs, go to sleep.") + print("No jobs, go to sleep.") await asyncio.sleep(5)