diff --git a/cognee/modules/retrieval/coding_rules_retriever.py b/cognee/modules/retrieval/coding_rules_retriever.py new file mode 100644 index 000000000..2578d1ee1 --- /dev/null +++ b/cognee/modules/retrieval/coding_rules_retriever.py @@ -0,0 +1,19 @@ +from cognee.shared.logging_utils import get_logger +from cognee.tasks.codingagents.coding_rule_associations import get_existing_rules + +logger = get_logger("CodingRulesRetriever") + + +class CodingRulesRetriever: + """Retriever for handling codeing rule based searches.""" + + def __init__(self, rules_nodeset_name): + if isinstance(rules_nodeset_name, list): + rules_nodeset_name = rules_nodeset_name[0] + self.rules_nodeset_name = rules_nodeset_name + """Initialize retriever with search parameters.""" + + async def get_existing_rules(self, query_text): + return await get_existing_rules( + rules_nodeset_name=self.rules_nodeset_name, return_list=True + ) diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 71bf61d6b..b341e4a8a 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -13,6 +13,7 @@ from cognee.modules.retrieval.insights_retriever import InsightsRetriever from cognee.modules.retrieval.summaries_retriever import SummariesRetriever from cognee.modules.retrieval.completion_retriever import CompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever +from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever from cognee.modules.retrieval.graph_summary_completion_retriever import ( GraphSummaryCompletionRetriever, ) @@ -167,6 +168,9 @@ async def specific_search( SearchType.CYPHER: CypherSearchRetriever().get_completion, SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion, SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback, + SearchType.CODING_RULES: CodingRulesRetriever( + rules_nodeset_name=node_name + ).get_existing_rules, } # If the query type is FEELING_LUCKY, select the search type intelligently diff --git a/cognee/modules/search/types/SearchType.py b/cognee/modules/search/types/SearchType.py index c1f0521b2..0a7cae63a 100644 --- a/cognee/modules/search/types/SearchType.py +++ b/cognee/modules/search/types/SearchType.py @@ -15,3 +15,4 @@ class SearchType(Enum): GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION" FEELING_LUCKY = "FEELING_LUCKY" FEEDBACK = "FEEDBACK" + CODING_RULES = "CODING_RULES" diff --git a/cognee/tasks/codingagents/coding_rule_associations.py b/cognee/tasks/codingagents/coding_rule_associations.py index e722e7728..c809bc68f 100644 --- a/cognee/tasks/codingagents/coding_rule_associations.py +++ b/cognee/tasks/codingagents/coding_rule_associations.py @@ -31,7 +31,7 @@ class RuleSet(DataPoint): ) -async def get_existing_rules(rules_nodeset_name: str) -> str: +async def get_existing_rules(rules_nodeset_name: str, return_list: bool = False) -> str: graph_engine = await get_graph_engine() nodes_data, _ = await graph_engine.get_nodeset_subgraph( node_type=NodeSet, node_name=[rules_nodeset_name] @@ -46,7 +46,8 @@ async def get_existing_rules(rules_nodeset_name: str) -> str: and "text" in item[1] ] - existing_rules = "\n".join(f"- {rule}" for rule in existing_rules) + if not return_list: + existing_rules = "\n".join(f"- {rule}" for rule in existing_rules) return existing_rules diff --git a/examples/python/memify_coding_agent_example.py b/examples/python/memify_coding_agent_example.py index 61af467d3..7f8c58802 100644 --- a/examples/python/memify_coding_agent_example.py +++ b/examples/python/memify_coding_agent_example.py @@ -85,8 +85,13 @@ async def main(): ) # Find the new specific coding rules added to graph through memify (created based on chat conversation between team members) - developer_rules = await get_existing_rules(rules_nodeset_name="coding_agent_rules") - print(developer_rules) + print( + await cognee.search( + query_text="List me the coding rules", + query_type=cognee.SearchType.CODING_RULES, + node_name=["coding_agent_rules"], + ) + ) # Visualize new graph with added memify context file_path = os.path.join(