<!-- .github/pull_request_template.md -->

## Description

This PR fixes the distributed pipeline and updates core distributed-execution logic.

## Type of Change

<!-- Please check the relevant option -->

- [x] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [x] Performance improvement
- [ ] Other (please specify):

## Changes Made

Fixes the distributed pipeline:

- Changed the spawning logic and added incremental loading to `run_tasks_distributed`
- Added batching to consumer nodes
- Fixed the consumer stopping criteria by adding a stop signal and its handling (illustrated in the sketch after this description)
- Changed the edge embedding solution to avoid heavy network load in multi-container environments

## Testing

Tested by running a 1 GB dataset on Modal, plus manual verification.

## Screenshots/Videos (if applicable)

None

## Pre-submission Checklist

<!-- Please check all boxes that apply before submitting your PR -->

- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the issue/feature**
- [ ] My code follows the project's coding standards and style guidelines
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## Related Issues

None

## Additional Notes

None

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Boris <boris@topoteretes.com>
Co-authored-by: Boris Arzentar <borisarzentar@gmail.com>
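The PR body doesn't include the implementation, so purely as illustration: a minimal asyncio sketch of the consumer pattern the changes describe (batch work items, flush on an explicit stop signal). Every name here (`consume_in_batches`, `handle_batch`, the `STOP` sentinel) is hypothetical and not taken from the actual `run_tasks_distributed` code.

```python
import asyncio

# Sentinel used as the stop signal; compared by identity, so it can never
# collide with a legitimate work item.
STOP = object()


async def consume_in_batches(queue: asyncio.Queue, batch_size: int, handle_batch) -> None:
    """Drain the queue in batches until the stop signal arrives."""
    batch = []
    while True:
        item = await queue.get()
        if item is STOP:
            break
        batch.append(item)
        if len(batch) >= batch_size:
            await handle_batch(batch)
            batch = []
    # Flush the final partial batch so no work is lost at shutdown.
    if batch:
        await handle_batch(batch)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()

    async def handle_batch(batch):
        print(f"processing batch of {len(batch)} items")

    consumer = asyncio.create_task(
        consume_in_batches(queue, batch_size=4, handle_batch=handle_batch)
    )
    for i in range(10):
        await queue.put(i)
    await queue.put(STOP)  # producer signals that no more work is coming
    await consumer


asyncio.run(main())
```

The sentinel gives the consumer an unambiguous termination condition (the stopping-criteria bug class this PR addresses), and flushing the trailing partial batch ensures batching never drops items at shutdown.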
from uuid import NAMESPACE_OID, uuid5
from typing import Any, List, Optional

from pydantic import Field

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm import LLMGateway
from cognee.low_level import DataPoint
from cognee.modules.engine.models import NodeSet
from cognee.shared.logging_utils import get_logger
from cognee.tasks.storage import add_data_points, index_graph_edges

logger = get_logger("coding_rule_association")


class Rule(DataPoint):
    """A single developer rule extracted from text."""

    text: str = Field(..., description="The coding rule associated with the conversation")
    belongs_to_set: Optional[NodeSet] = None
    # "index_fields" must name an existing attribute; the rule content lives in
    # the "text" field, so that is what gets embedded and indexed.
    metadata: dict = {"index_fields": ["text"]}


class RuleSet(DataPoint):
    """Collection of parsed rules."""

    rules: List[Rule] = Field(
        ...,
        description="List of developer rules extracted from the input text. Each rule represents a coding best practice or guideline.",
    )


async def get_existing_rules(rules_nodeset_name: str) -> str:
    """Return the rules already stored in the given node set, one bullet per line."""
    graph_engine = await get_graph_engine()
    nodes_data, _ = await graph_engine.get_nodeset_subgraph(
        node_type=NodeSet, node_name=[rules_nodeset_name]
    )

    # Each node is expected as an (id, properties) tuple; skip anything malformed.
    rule_texts = [
        item[1]["text"]
        for item in nodes_data
        if isinstance(item, tuple)
        and len(item) == 2
        and isinstance(item[1], dict)
        and "text" in item[1]
    ]

    return "\n".join(f"- {rule}" for rule in rule_texts)


async def get_origin_edges(data: str, rules: List[Rule]) -> List[Any]:
    """Build edges linking each rule to the document chunk it originated from."""
    vector_engine = get_vector_engine()

    # The chunk most similar to the input text is treated as the origin.
    origin_chunk = await vector_engine.search("DocumentChunk_text", data, limit=1)

    try:
        origin_id = origin_chunk[0].id
    except (AttributeError, KeyError, TypeError, IndexError):
        origin_id = None

    relationships = []

    if origin_id and isinstance(rules, (list, tuple)) and len(rules) > 0:
        for rule in rules:
            try:
                rule_id = getattr(rule, "id", None)
                if rule_id is not None:
                    rel_name = "rule_associated_from"
                    relationships.append(
                        (
                            rule_id,
                            origin_id,
                            rel_name,
                            {
                                "relationship_name": rel_name,
                                "source_node_id": rule_id,
                                "target_node_id": origin_id,
                                "ontology_valid": False,
                            },
                        )
                    )
            except Exception as error:
                logger.warning(f"Skipping invalid rule: {error}")
    else:
        logger.info("No valid origin_id or rules provided.")

    return relationships


async def add_rule_associations(data: str, rules_nodeset_name: str):
    """Extract rules from a chat, link them to their origin chunk, and persist both."""
    graph_engine = await get_graph_engine()
    existing_rules = await get_existing_rules(rules_nodeset_name=rules_nodeset_name)

    # Pass the already-known rules to the LLM so it only extracts new ones.
    user_context = {"chat": data, "rules": existing_rules}

    user_prompt = LLMGateway.render_prompt(
        "coding_rule_association_agent_user.txt", context=user_context
    )
    system_prompt = LLMGateway.render_prompt("coding_rule_association_agent_system.txt", context={})

    rule_list = await LLMGateway.acreate_structured_output(
        text_input=user_prompt, system_prompt=system_prompt, response_model=RuleSet
    )

    # A deterministic UUID keeps every run attached to the same named node set.
    rules_nodeset = NodeSet(
        id=uuid5(NAMESPACE_OID, name=rules_nodeset_name), name=rules_nodeset_name
    )
    for rule in rule_list.rules:
        rule.belongs_to_set = rules_nodeset

    edges_to_save = await get_origin_edges(data=data, rules=rule_list.rules)

    await add_data_points(data_points=rule_list.rules)

    if edges_to_save:
        await graph_engine.add_edges(edges_to_save)
        await index_graph_edges(edges_to_save)
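

# Illustrative usage sketch (an assumption, not part of the original module):
# running this requires a configured cognee graph and vector backend, and the
# prompt templates referenced above; the nodeset name and chat text below are
# made-up values.
if __name__ == "__main__":
    import asyncio

    asyncio.run(
        add_rule_associations(
            data="User: please always pin dependency versions in requirements.txt.",
            rules_nodeset_name="developer_rules",
        )
    )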