cognee/cognee/modules/retrieval/graph_completion_retriever.py
lxobr ee88fcf5d3
feat: reimplement resolve_edges_to_text with cleaner formatting (#652)
<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->
- Optimized to deduplicate nodes appearing in multiple triplets,
avoiding redundant text repetition
- Reimplemented `resolve_edges_to_text` with cleaner formatting
  - Added `_top_n_words` method for extracting frequent words from text
- Created `_get_title` function to generate titles from text content
based on first words and word frequency
  - Extracted node processing logic to `_get_nodes` helper method
  - Created dedicated `stop_words` utility with common English stopwords

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Improved text output formatting that organizes content into clearly
defined sections for enhanced readability.
- Enhanced text processing capabilities, including refined title
generation and key phrase extraction.
- Introduced a comprehensive utility for managing common stop words,
further optimizing text analysis.
  
- **Bug Fixes**
- Updated tests to ensure accurate validation of new functionalities and
improved existing test coverage.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
2025-03-20 14:52:04 +01:00

112 lines
4.7 KiB
Python

from typing import Any, Optional
from collections import Counter
import string
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.tasks.completion.exceptions import NoRelevantDataFound
class GraphCompletionRetriever(BaseRetriever):
    """Retriever for handling graph-based completion searches.

    Graph triplets are looked up with ``brute_force_triplet_search``, rendered
    into a text context (a node section followed by a connection section), and
    that context is passed to ``generate_completion`` together with the query.
    """

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        top_k: Optional[int] = 5,
    ):
        """Initialize retriever with prompt paths and search parameters.

        Args:
            user_prompt_path: Forwarded to ``generate_completion`` as
                ``user_prompt_path``.
            system_prompt_path: Forwarded to ``generate_completion`` as
                ``system_prompt_path``.
            top_k: Forwarded to ``brute_force_triplet_search`` as ``top_k``.
                ``None`` falls back to 5.
        """
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.top_k = top_k if top_k is not None else 5

    def _get_nodes(self, retrieved_edges: list) -> dict:
        """Creates a dictionary of nodes with their names and content.

        Each node is recorded once, keyed by ``node.id``, even when it appears
        in several edges. A node with a non-empty ``text`` attribute gets a
        generated title as its name and the text as its content; otherwise its
        ``name`` attribute (or ``"Unnamed Node"``) is used for both.

        Args:
            retrieved_edges: Edges exposing ``node1`` and ``node2``; each node
                must expose ``id`` and an ``attributes`` mapping.

        Returns:
            Mapping of node id to ``{"node", "name", "content"}``.
        """
        nodes = {}
        for edge in retrieved_edges:
            for node in (edge.node1, edge.node2):
                if node.id in nodes:
                    continue
                text = node.attributes.get("text")
                if text:
                    name = self._get_title(text)
                    content = text
                else:
                    # `or` also covers a "name" key that is present but None/empty,
                    # which would otherwise be rendered as "Node: None".
                    name = node.attributes.get("name") or "Unnamed Node"
                    content = name
                nodes[node.id] = {"node": node, "name": name, "content": content}
        return nodes

    async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
        """Converts retrieved graph edges into a human-readable string format.

        Args:
            retrieved_edges: Edges exposing ``node1``, ``node2`` and an
                ``attributes`` mapping that contains ``"relationship_type"``.

        Returns:
            A "Nodes:" section listing every distinct node with its content,
            followed by a "Connections:" section with one line per edge.

        Raises:
            KeyError: If an edge has no ``"relationship_type"`` attribute.
        """
        nodes = self._get_nodes(retrieved_edges)
        node_section = "\n".join(
            f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
            for info in nodes.values()
        )
        connection_section = "\n".join(
            f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
            for edge in retrieved_edges
        )
        return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"

    async def get_triplets(self, query: str) -> list:
        """Retrieves relevant graph triplets.

        Collection names are built as ``"<SubclassName>_<field_name>"`` for
        every index field declared in the default ``metadata`` of each
        ``DataPoint`` subclass. If none are found, ``None`` is passed as
        ``collections`` to ``brute_force_triplet_search``.

        Args:
            query: The search query.

        Returns:
            The triplets returned by ``brute_force_triplet_search``.

        Raises:
            NoRelevantDataFound: If the search returns nothing.
        """
        vector_index_collections = []
        for subclass in get_all_subclasses(DataPoint):
            # A subclass may declare no "metadata" field, or one without a dict
            # default; skip it instead of failing the whole search.
            metadata_field = subclass.model_fields.get("metadata")
            metadata = getattr(metadata_field, "default", None)
            if not isinstance(metadata, dict):
                continue
            for field_name in metadata.get("index_fields") or []:
                vector_index_collections.append(f"{subclass.__name__}_{field_name}")

        found_triplets = await brute_force_triplet_search(
            query, top_k=self.top_k, collections=vector_index_collections or None
        )
        if not found_triplets:
            raise NoRelevantDataFound
        return found_triplets

    async def get_context(self, query: str) -> Any:
        """Retrieves and resolves graph triplets into context.

        Args:
            query: The search query.

        Returns:
            The text produced by ``resolve_edges_to_text`` for the triplets
            found by ``get_triplets``.
        """
        triplets = await self.get_triplets(query)
        return await self.resolve_edges_to_text(triplets)

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates a completion using graph connections context.

        Args:
            query: The question to answer.
            context: Pre-computed context; when ``None`` it is obtained from
                ``get_context(query)``.

        Returns:
            A single-element list holding the result of ``generate_completion``.
        """
        if context is None:
            context = await self.get_context(query)
        completion = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )
        return [completion]

    def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
        """Concatenates the top N frequent words in text.

        Words are lower-cased and stripped of surrounding punctuation before
        counting. Tokens that become empty (e.g. a lone "-") and tokens found
        in ``stop_words`` are ignored.

        Args:
            text: Text to analyse; split on whitespace.
            stop_words: Words to exclude. ``None`` selects
                ``DEFAULT_STOP_WORDS``; an empty collection disables stop-word
                filtering.
            top_n: Maximum number of words to return.
            separator: String placed between the returned words.

        Returns:
            The most frequent words joined by ``separator``, most frequent first.
        """
        if stop_words is None:
            stop_words = DEFAULT_STOP_WORDS
        normalized = (word.lower().strip(string.punctuation) for word in text.split())
        # Empty tokens are dropped unconditionally so that "" can never be
        # reported as a frequent word, even with an empty stop-word collection.
        words = [word for word in normalized if word and word not in stop_words]
        top_words = [word for word, _ in Counter(words).most_common(top_n)]
        return separator.join(top_words)

    def _get_title(self, text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
        """Creates a title, by combining first words with most frequent words from the text.

        Args:
            text: Source text.
            first_n_words: How many leading words of ``text`` to include.
            top_n_words: How many frequent words to append in brackets.

        Returns:
            ``"<leading words>... [<frequent words>]"``.
        """
        # Distinct local names: the parameters are counts, these are the results.
        leading_words = text.split()[:first_n_words]
        frequent_words = self._top_n_words(text, top_n=top_n_words)
        return f"{' '.join(leading_words)}... [{frequent_words}]"