fix: unit tests and ruff lint errors
parent 86bd3e4a5a
commit 611df1e9b9
8 changed files with 44 additions and 27 deletions
.github/actions/cognee_setup/action.yml (vendored, 2 changes)
@@ -23,4 +23,4 @@ runs:
         pip install poetry
     - name: Install dependencies
       shell: bash
-      run: poetry install --no-interaction -E api -E docs -E evals -E gemini -E codegraph -E ollama -E dev
+      run: poetry install --no-interaction -E api -E docs -E evals -E gemini -E codegraph -E ollama -E dev -E neo4j
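The only functional change in the setup action is the extra -E neo4j appended to the poetry install command, so the optional Neo4j dependencies are installed in CI, presumably because the unit tests this commit fixes exercise the Neo4j graph engine. As a rough, hypothetical illustration (not part of this commit), a job missing that extra would typically fail at import time:

    # Hypothetical smoke check; assumes the Neo4j driver is only pulled in by the
    # project's "neo4j" extra.
    try:
        import neo4j  # installed by `poetry install ... -E neo4j`
    except ImportError as error:
        raise RuntimeError("Neo4j extra not installed; add -E neo4j to poetry install") from error
    print("Neo4j driver available")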
@@ -82,13 +82,15 @@ def record_graph_changes(func):

             for node in nodes:
                 node_id = UUID(str(node.id))
-                relationship_ledgers.append(GraphRelationshipLedger(
-                    id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
-                    source_node_id=node_id,
-                    destination_node_id=node_id,
-                    creator_function=f"{creator}.node",
-                    node_label=getattr(node, "name", None) or str(node.id),
-                ))
+                relationship_ledgers.append(
+                    GraphRelationshipLedger(
+                        id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
+                        source_node_id=node_id,
+                        destination_node_id=node_id,
+                        creator_function=f"{creator}.node",
+                        node_label=getattr(node, "name", None) or str(node.id),
+                    )
+                )

             try:
                 session.add_all(relationship_ledgers)
@@ -106,12 +108,14 @@ def record_graph_changes(func):
                 source_id = UUID(str(edge[0]))
                 target_id = UUID(str(edge[1]))
                 rel_type = str(edge[2])
-                relationship_ledgers.append(GraphRelationshipLedger(
-                    id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
-                    source_node_id=source_id,
-                    destination_node_id=target_id,
-                    creator_function=f"{creator}.{rel_type}",
-                ))
+                relationship_ledgers.append(
+                    GraphRelationshipLedger(
+                        id=uuid5(NAMESPACE_OID, f"{datetime.now(timezone.utc).timestamp()}"),
+                        source_node_id=source_id,
+                        destination_node_id=target_id,
+                        creator_function=f"{creator}.{rel_type}",
+                    )
+                )

             try:
                 session.add_all(relationship_ledgers)
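Both hunks above are pure reformatting: ruff format splits a call that no longer fits on one line, moving the nested GraphRelationshipLedger(...) constructor onto its own line inside append(...) with one extra level of indentation. Behavior is unchanged. A minimal before/after sketch with hypothetical names:

    # Before formatting: one over-long nested call.
    items.append(Record(id=make_id(), label=compute_label(obj), parent=obj.parent))

    # After ruff format: the outer call is split first, then the inner call,
    # which is the shape both hunks above end up with.
    items.append(
        Record(
            id=make_id(),
            label=compute_label(obj),
            parent=obj.parent,
        )
    )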
@@ -107,7 +107,7 @@ class GenericAPIAdapter(LLMInterface):
         ) as error:
             if (
                 isinstance(error, InstructorRetryException)
-                and not "content management policy" in str(error).lower()
+                and "content management policy" not in str(error).lower()
             ):
                 raise error

@@ -141,7 +141,7 @@ class GenericAPIAdapter(LLMInterface):
         ) as error:
             if (
                 isinstance(error, InstructorRetryException)
-                and not "content management policy" in str(error).lower()
+                and "content management policy" not in str(error).lower()
             ):
                 raise error
             else:
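The one-line fix in both hunks (and in the two OpenAIAdapter hunks below) addresses ruff rule E713, not-in-test: "not x in y" is legal but reads awkwardly, and the membership operator "not in" states the intent directly. The two spellings always yield the same boolean, as this small illustration with made-up values shows:

    phrase = "content management policy"
    message = "Request was rejected by the content management policy."

    blocked_old = not phrase in message.lower()   # flagged by ruff E713
    blocked_new = phrase not in message.lower()   # preferred, identical result

    assert blocked_old == blocked_new  # both False here, since the phrase occurs in the message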
@@ -134,7 +134,7 @@ class OpenAIAdapter(LLMInterface):
         ) as error:
             if (
                 isinstance(error, InstructorRetryException)
-                and not "content management policy" in str(error).lower()
+                and "content management policy" not in str(error).lower()
             ):
                 raise error

@@ -168,7 +168,7 @@ class OpenAIAdapter(LLMInterface):
         ) as error:
             if (
                 isinstance(error, InstructorRetryException)
-                and not "content management policy" in str(error).lower()
+                and "content management policy" not in str(error).lower()
             ):
                 raise error
             else:
@@ -68,7 +68,7 @@ async def get_graph_from_model_test():
     assert len(nodes) == 4
     assert len(edges) == 3

-    document_chunk_node = next(filter(lambda node: node.type is "DocumentChunk", nodes))
+    document_chunk_node = next(filter(lambda node: node.type == "DocumentChunk", nodes))
     assert not hasattr(document_chunk_node, "part_of"), "Expected part_of attribute to be removed"

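The test compared a string attribute with "is", which checks object identity rather than equality; ruff flags this as F632 (is-literal), and CPython 3.8+ emits a SyntaxWarning for it because the outcome depends on string interning. Switching to "==" compares the characters, which is what the assertion means. A short illustration with made-up values:

    # Build an equal but distinct string object, so it is not the interned literal.
    node_type = "".join(["Document", "Chunk"])

    assert node_type == "DocumentChunk"   # equality: True
    # node_type is "DocumentChunk"        # identity: False here, plus a SyntaxWarning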
@@ -1,6 +1,7 @@
 # import json
 # import asyncio
 from pympler import asizeof

 # from cognee.modules.storage.utils import JSONEncoder
+from distributed.queues import save_data_points_queue
 # from cognee.modules.graph.utils import get_graph_from_model

@@ -62,6 +63,7 @@ async def save_data_points(data_points_and_relationships: tuple[list, list]):

     # graph_data_deduplication.reset()

+
 class GraphDataDeduplication:
     nodes_and_edges_map: dict

@@ -93,8 +95,11 @@ class GraphDataDeduplication:
 def try_pushing_nodes_to_queue(node_batch):
     try:
         save_data_points_queue.put((node_batch, []))
-    except Exception as e:
-        first_half, second_half = node_batch[:len(node_batch) // 2], node_batch[len(node_batch) // 2:]
+    except Exception:
+        first_half, second_half = (
+            node_batch[: len(node_batch) // 2],
+            node_batch[len(node_batch) // 2 :],
+        )
         save_data_points_queue.put((first_half, []))
         save_data_points_queue.put((second_half, []))

@@ -102,7 +107,10 @@ def try_pushing_nodes_to_queue(node_batch):
 def try_pushing_edges_to_queue(edge_batch):
     try:
         save_data_points_queue.put(([], edge_batch))
-    except Exception as e:
-        first_half, second_half = edge_batch[:len(edge_batch) // 2], edge_batch[len(edge_batch) // 2:]
+    except Exception:
+        first_half, second_half = (
+            edge_batch[: len(edge_batch) // 2],
+            edge_batch[len(edge_batch) // 2 :],
+        )
         save_data_points_queue.put(([], first_half))
         save_data_points_queue.put(([], second_half))
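Two things change in each helper above: the unused "as e" binding is dropped (ruff F841, local variable assigned but never used), and the slice pair is wrapped so each half sits on its own line in ruff format's slice style. The fallback still halves a rejected batch only once; a hedged sketch of a fully recursive variant, with a hypothetical queue parameter standing in for save_data_points_queue, could look like this:

    def push_in_halves(queue, batch, as_nodes=True):
        # Keep halving a batch the queue rejects (for example, an oversized payload)
        # until every piece goes through or a single item still fails.
        payload = (batch, []) if as_nodes else ([], batch)
        try:
            queue.put(payload)
        except Exception:
            if len(batch) <= 1:
                raise  # nothing left to split; surface the real error
            middle = len(batch) // 2
            push_in_halves(queue, batch[:middle], as_nodes)
            push_in_halves(queue, batch[middle:], as_nodes)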
@@ -52,7 +52,10 @@ async def summarize_text(
     )

     graph_data_deduplication = GraphDataDeduplication()
-    deduplicated_nodes_and_edges = [graph_data_deduplication.deduplicate_nodes_and_edges(nodes, edges + relationships) for nodes, edges in nodes_and_edges]
+    deduplicated_nodes_and_edges = [
+        graph_data_deduplication.deduplicate_nodes_and_edges(nodes, edges + relationships)
+        for nodes, edges in nodes_and_edges
+    ]

     return deduplicated_nodes_and_edges

@@ -21,7 +21,9 @@ async def data_point_saver_worker():
             return True

         if len(nodes_and_edges) == 2:
-            print(f"Processing {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges.")
+            print(
+                f"Processing {len(nodes_and_edges[0])} nodes and {len(nodes_and_edges[1])} edges."
+            )
             nodes = nodes_and_edges[0]
             edges = nodes_and_edges[1]

@@ -30,8 +32,8 @@ async def data_point_saver_worker():

             if edges:
                 await graph_engine.add_edges(edges)
-            print(f"Finished processing nodes and edges.")
+            print("Finished processing nodes and edges.")

         else:
-            print(f"No jobs, go to sleep.")
+            print("No jobs, go to sleep.")
             await asyncio.sleep(5)
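The last two edits drop the f prefix from strings that contain no placeholders (ruff F541, f-string without any placeholders); the resulting strings are identical, so the prefix was simply noise. The first print keeps its prefix because it actually interpolates the batch sizes. For example:

    assert f"No jobs, go to sleep." == "No jobs, go to sleep."  # F541 flags the left-hand side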