feat: reimplement resolve_edges_to_text with cleaner formatting (#652)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->
- Optimized to deduplicate nodes appearing in multiple triplets,
avoiding redundant text repetition
- Reimplemented `resolve_edges_to_text` with cleaner formatting
  - Added `_top_n_words` method for extracting frequent words from text
- Created `_get_title` function to generate titles from text content
based on first words and word frequency
  - Extracted node processing logic to `_get_nodes` helper method
  - Created dedicated `stop_words` utility with common English stopwords

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Improved text output formatting that organizes content into clearly
defined sections for enhanced readability.
- Enhanced text processing capabilities, including refined title
generation and key phrase extraction.
- Introduced a comprehensive utility for managing common stop words,
further optimizing text analysis.
  
- **Bug Fixes**
- Updated tests to ensure accurate validation of new functionalities and
improved existing test coverage.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
This commit is contained in:
lxobr 2025-03-20 14:52:04 +01:00 committed by GitHub
parent 164cb581ec
commit ee88fcf5d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 219 additions and 20 deletions

View file

@ -1,10 +1,13 @@
from typing import Any, Optional
from collections import Counter
import string
from cognee.infrastructure.engine import DataPoint
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.tasks.completion.exceptions import NoRelevantDataFound
@ -22,16 +25,34 @@ class GraphCompletionRetriever(BaseRetriever):
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 5
def _get_nodes(self, retrieved_edges: list) -> dict:
"""Creates a dictionary of nodes with their names and content."""
nodes = {}
for edge in retrieved_edges:
for node in (edge.node1, edge.node2):
if node.id not in nodes:
text = node.attributes.get("text")
if text:
name = self._get_title(text)
content = text
else:
name = node.attributes.get("name", "Unnamed Node")
content = name
nodes[node.id] = {"node": node, "name": name, "content": content}
return nodes
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
    """Converts retrieved graph edges into a human-readable string format.

    Deduplicates nodes shared across edges via ``_get_nodes``, then renders two
    sections: one listing each node's name and content (between explicit
    delimiter markers), and one listing directed connections by node name.

    Args:
        retrieved_edges: Edge objects exposing ``node1``, ``node2`` and an
            ``attributes`` dict containing "relationship_type".

    Returns:
        A formatted string with a "Nodes:" section followed by a
        "Connections:" section.
    """
    # NOTE(review): the diff residue of the superseded implementation
    # (edge_strings / "\n---\n".join) has been removed — its early return
    # made everything below it unreachable.
    nodes = self._get_nodes(retrieved_edges)
    # Each node is rendered exactly once, even if it appears in many edges.
    node_section = "\n".join(
        f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
        for info in nodes.values()
    )
    connection_section = "\n".join(
        f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
        for edge in retrieved_edges
    )
    return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
async def get_triplets(self, query: str) -> list:
"""Retrieves relevant graph triplets."""
@ -69,3 +90,23 @@ class GraphCompletionRetriever(BaseRetriever):
system_prompt_path=self.system_prompt_path,
)
return [completion]
def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
"""Concatenates the top N frequent words in text."""
if stop_words is None:
stop_words = DEFAULT_STOP_WORDS
words = [word.lower().strip(string.punctuation) for word in text.split()]
if stop_words:
words = [word for word in words if word and word not in stop_words]
top_words = [word for word, freq in Counter(words).most_common(top_n)]
return separator.join(top_words)
def _get_title(self, text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
"""Creates a title, by combining first words with most frequent words from the text."""
first_n_words = text.split()[:first_n_words]
top_n_words = self._top_n_words(text, top_n=top_n_words)
return f"{' '.join(first_n_words)}... [{top_n_words}]"

View file

@ -0,0 +1,71 @@
"""Common stop words for text processing."""
# Common English stop words to filter out in text processing
DEFAULT_STOP_WORDS = {
"a",
"an",
"the",
"and",
"or",
"but",
"is",
"are",
"was",
"were",
"in",
"on",
"at",
"to",
"for",
"with",
"by",
"about",
"of",
"from",
"as",
"that",
"this",
"these",
"those",
"it",
"its",
"them",
"they",
"their",
"he",
"she",
"his",
"her",
"him",
"we",
"our",
"you",
"your",
"not",
"be",
"been",
"being",
"have",
"has",
"had",
"do",
"does",
"did",
"can",
"could",
"will",
"would",
"shall",
"should",
"may",
"might",
"must",
"when",
"where",
"which",
"who",
"whom",
"whose",
"why",
"how",
}

View file

@ -40,23 +40,37 @@ class TestGraphCompletionRetriever:
@pytest.mark.asyncio
async def test_resolve_edges_to_text(self, mock_retriever):
    """resolve_edges_to_text renders deduplicated node and connection sections."""
    # NOTE(review): the diff residue has been removed — the span previously
    # contained both the old and new keyword arguments (duplicate node1=/node2=,
    # a SyntaxError) and the stale pre-rewrite expected-output assertions.
    node_a = AsyncMock(id="node_a_id", attributes={"text": "Node A text content"})
    node_b = AsyncMock(id="node_b_id", attributes={"text": "Node B text content"})
    node_c = AsyncMock(id="node_c_id", attributes={"name": "Node C"})
    triplets = [
        AsyncMock(
            node1=node_a,
            attributes={"relationship_type": "connects"},
            node2=node_b,
        ),
        AsyncMock(
            node1=node_a,
            attributes={"relationship_type": "links"},
            node2=node_c,
        ),
    ]
    # Title generation is stubbed so the assertions don't depend on its heuristics.
    with patch.object(mock_retriever, "_get_title", return_value="Test Title"):
        result = await mock_retriever.resolve_edges_to_text(triplets)
    assert "Nodes:" in result
    assert "Connections:" in result
    assert "Node: Test Title" in result
    assert "__node_content_start__" in result
    assert "Node A text content" in result
    assert "__node_content_end__" in result
    assert "Node: Node C" in result
    assert "Test Title --[connects]--> Test Title" in result
    assert "Test Title --[links]--> Node C" in result
@pytest.mark.asyncio
@patch(
@ -124,16 +138,13 @@ class TestGraphCompletionRetriever:
mock_get_llm_client,
mock_retriever,
):
# Setup
query = "test query with empty graph"
# Mock graph engine with empty graph
mock_graph_engine = MagicMock()
mock_graph_engine.get_graph_data = AsyncMock()
mock_graph_engine.get_graph_data.return_value = ([], [])
mock_get_graph_engine.return_value = mock_graph_engine
# Mock LLM client
mock_llm_client = MagicMock()
mock_llm_client.acreate_structured_output = AsyncMock()
mock_llm_client.acreate_structured_output.return_value = (
@ -141,9 +152,85 @@ class TestGraphCompletionRetriever:
)
mock_get_llm_client.return_value = mock_llm_client
# Execute
with pytest.raises(EntityNotFoundError):
await mock_retriever.get_completion(query)
# Verify graph engine was called
mock_graph_engine.get_graph_data.assert_called_once()
def test_top_n_words(self, mock_retriever):
    """Test extraction of top frequent words from text."""
    sample = "The quick brown fox jumps over the lazy dog. The fox is quick."

    # Default settings: at most three comma-separated keywords.
    default_output = mock_retriever._top_n_words(sample)
    assert len(default_output.split(", ")) <= 3
    assert "fox" in default_output
    assert "quick" in default_output

    # A smaller top_n caps the number of keywords returned.
    assert len(mock_retriever._top_n_words(sample, top_n=2).split(", ")) <= 2

    # A custom separator is used between keywords.
    assert " | " in mock_retriever._top_n_words(sample, separator=" | ")

    # Caller-supplied stop words are excluded from the result.
    filtered = mock_retriever._top_n_words(sample, stop_words={"fox", "quick"})
    assert "fox" not in filtered
    assert "quick" not in filtered
def test_get_title(self, mock_retriever):
    """Test title generation from text."""
    sample = (
        "This is a long paragraph about various topics that should generate a title. "
        "The main topics are AI, programming and data science."
    )

    # Default title contains both the ellipsis and a bracketed keyword list.
    default_title = mock_retriever._get_title(sample)
    assert "..." in default_title
    assert "[" in default_title and "]" in default_title

    # first_n_words controls how many leading words precede the ellipsis.
    short_lead = mock_retriever._get_title(sample, first_n_words=3)
    lead_segment = short_lead.split("...")[0].strip()
    assert len(lead_segment.split()) == 3

    # top_n_words caps the bracketed keyword list.
    keyword_title = mock_retriever._get_title(sample, top_n_words=2)
    bracketed = keyword_title.split("[")[1].split("]")[0]
    assert len(bracketed.split(", ")) <= 2
def test_get_nodes(self, mock_retriever):
"""Test node processing and deduplication."""
node_with_text = AsyncMock(id="text_node", attributes={"text": "This is a text node"})
node_with_name = AsyncMock(id="name_node", attributes={"name": "Named Node"})
node_without_attrs = AsyncMock(id="empty_node", attributes={})
edges = [
AsyncMock(
node1=node_with_text, node2=node_with_name, attributes={"relationship_type": "rel1"}
),
AsyncMock(
node1=node_with_text,
node2=node_without_attrs,
attributes={"relationship_type": "rel2"},
),
AsyncMock(
node1=node_with_name,
node2=node_without_attrs,
attributes={"relationship_type": "rel3"},
),
]
with patch.object(mock_retriever, "_get_title", return_value="Generated Title"):
nodes = mock_retriever._get_nodes(edges)
assert len(nodes) == 3
for node_id, info in nodes.items():
assert "node" in info
assert "name" in info
assert "content" in info
text_node_info = nodes[node_with_text.id]
assert text_node_info["name"] == "Generated Title"
assert text_node_info["content"] == "This is a text node"
name_node_info = nodes[node_with_name.id]
assert name_node_info["name"] == "Named Node"
assert name_node_info["content"] == "Named Node"
empty_node_info = nodes[node_without_attrs.id]
assert empty_node_info["name"] == "Unnamed Node"