feat: reimplement resolve_edges_to_text with cleaner formatting (#652)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> - Optimized to deduplicate nodes appearing in multiple triplets, avoiding redundant text repetition - Reimplemented `resolve_edges_to_text` with cleaner formatting - Added `_top_n_words` method for extracting frequent words from text - Created `_get_title` function to generate titles from text content based on first words and word frequency - Extracted node processing logic to `_get_nodes` helper method - Created dedicated `stop_words` utility with common English stopwords ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Improved text output formatting that organizes content into clearly defined sections for enhanced readability. - Enhanced text processing capabilities, including refined title generation and key phrase extraction. - Introduced a comprehensive utility for managing common stop words, further optimizing text analysis. - **Bug Fixes** - Updated tests to ensure accurate validation of new functionalities and improved existing test coverage. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
This commit is contained in:
parent
164cb581ec
commit
ee88fcf5d3
3 changed files with 219 additions and 20 deletions
|
|
@ -1,10 +1,13 @@
|
|||
from typing import Any, Optional
|
||||
from collections import Counter
|
||||
import string
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
||||
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
||||
from cognee.modules.retrieval.utils.completion import generate_completion
|
||||
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
||||
from cognee.tasks.completion.exceptions import NoRelevantDataFound
|
||||
|
||||
|
||||
|
|
@ -22,16 +25,34 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
self.system_prompt_path = system_prompt_path
|
||||
self.top_k = top_k if top_k is not None else 5
|
||||
|
||||
def _get_nodes(self, retrieved_edges: list) -> dict:
|
||||
"""Creates a dictionary of nodes with their names and content."""
|
||||
nodes = {}
|
||||
for edge in retrieved_edges:
|
||||
for node in (edge.node1, edge.node2):
|
||||
if node.id not in nodes:
|
||||
text = node.attributes.get("text")
|
||||
if text:
|
||||
name = self._get_title(text)
|
||||
content = text
|
||||
else:
|
||||
name = node.attributes.get("name", "Unnamed Node")
|
||||
content = name
|
||||
nodes[node.id] = {"node": node, "name": name, "content": content}
|
||||
return nodes
|
||||
|
||||
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
    """Converts retrieved graph edges into a human-readable string format.

    The output has two sections: a "Nodes:" section listing each distinct
    node once (its name plus delimited content), and a "Connections:"
    section with one "name --[relationship]--> name" line per edge.

    Args:
        retrieved_edges: Edge objects exposing ``node1``, ``node2`` and an
            ``attributes`` dict with a ``relationship_type`` key.

    Returns:
        The formatted multi-line string.
    """
    # NOTE: the superseded single-line "A -- rel -- B" rendering that used
    # to live here returned early and made this sectioned output dead code;
    # it has been removed so the deduplicated format below takes effect.
    nodes = self._get_nodes(retrieved_edges)
    node_section = "\n".join(
        f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
        for info in nodes.values()
    )
    connection_section = "\n".join(
        f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
        for edge in retrieved_edges
    )
    return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
|
||||
|
||||
async def get_triplets(self, query: str) -> list:
|
||||
"""Retrieves relevant graph triplets."""
|
||||
|
|
@ -69,3 +90,23 @@ class GraphCompletionRetriever(BaseRetriever):
|
|||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
return [completion]
|
||||
|
||||
def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
|
||||
"""Concatenates the top N frequent words in text."""
|
||||
if stop_words is None:
|
||||
stop_words = DEFAULT_STOP_WORDS
|
||||
|
||||
words = [word.lower().strip(string.punctuation) for word in text.split()]
|
||||
|
||||
if stop_words:
|
||||
words = [word for word in words if word and word not in stop_words]
|
||||
|
||||
top_words = [word for word, freq in Counter(words).most_common(top_n)]
|
||||
|
||||
return separator.join(top_words)
|
||||
|
||||
def _get_title(self, text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
|
||||
"""Creates a title, by combining first words with most frequent words from the text."""
|
||||
first_n_words = text.split()[:first_n_words]
|
||||
top_n_words = self._top_n_words(text, top_n=top_n_words)
|
||||
return f"{' '.join(first_n_words)}... [{top_n_words}]"
|
||||
|
|
|
|||
71
cognee/modules/retrieval/utils/stop_words.py
Normal file
71
cognee/modules/retrieval/utils/stop_words.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""Common stop words for text processing."""
|
||||
|
||||
# Common English stop words to filter out in text processing.
# Articles, conjunctions, auxiliaries, pronouns, prepositions and
# wh-words; membership tests expect lowercased tokens.
DEFAULT_STOP_WORDS = set(
    """
    a an the and or but is are was were in on at to for with by about of
    from as that this these those it its them they their he she his her
    him we our you your not be been being have has had do does did can
    could will would shall should may might must when where which who
    whom whose why how
    """.split()
)
|
||||
|
|
@ -40,23 +40,37 @@ class TestGraphCompletionRetriever:
|
|||
|
||||
@pytest.mark.asyncio
async def test_resolve_edges_to_text(self, mock_retriever):
    """resolve_edges_to_text renders deduplicated nodes plus connection lines."""
    node_a = AsyncMock(id="node_a_id", attributes={"text": "Node A text content"})
    node_b = AsyncMock(id="node_b_id", attributes={"text": "Node B text content"})
    node_c = AsyncMock(id="node_c_id", attributes={"name": "Node C"})

    # node_a appears in both triplets to exercise deduplication.
    # (Unmerged diff lines previously duplicated the node1=/node2= keyword
    # arguments here, which is a SyntaxError; only the new form is kept.)
    triplets = [
        AsyncMock(
            node1=node_a,
            attributes={"relationship_type": "connects"},
            node2=node_b,
        ),
        AsyncMock(
            node1=node_a,
            attributes={"relationship_type": "links"},
            node2=node_c,
        ),
    ]

    with patch.object(mock_retriever, "_get_title", return_value="Test Title"):
        result = await mock_retriever.resolve_edges_to_text(triplets)

    assert "Nodes:" in result
    assert "Connections:" in result

    assert "Node: Test Title" in result
    assert "__node_content_start__" in result
    assert "Node A text content" in result
    assert "__node_content_end__" in result
    assert "Node: Node C" in result

    assert "Test Title --[connects]--> Test Title" in result
    assert "Test Title --[links]--> Node C" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
|
|
@ -124,16 +138,13 @@ class TestGraphCompletionRetriever:
|
|||
mock_get_llm_client,
|
||||
mock_retriever,
|
||||
):
|
||||
# Setup
|
||||
query = "test query with empty graph"
|
||||
|
||||
# Mock graph engine with empty graph
|
||||
mock_graph_engine = MagicMock()
|
||||
mock_graph_engine.get_graph_data = AsyncMock()
|
||||
mock_graph_engine.get_graph_data.return_value = ([], [])
|
||||
mock_get_graph_engine.return_value = mock_graph_engine
|
||||
|
||||
# Mock LLM client
|
||||
mock_llm_client = MagicMock()
|
||||
mock_llm_client.acreate_structured_output = AsyncMock()
|
||||
mock_llm_client.acreate_structured_output.return_value = (
|
||||
|
|
@ -141,9 +152,85 @@ class TestGraphCompletionRetriever:
|
|||
)
|
||||
mock_get_llm_client.return_value = mock_llm_client
|
||||
|
||||
# Execute
|
||||
with pytest.raises(EntityNotFoundError):
|
||||
await mock_retriever.get_completion(query)
|
||||
|
||||
# Verify graph engine was called
|
||||
mock_graph_engine.get_graph_data.assert_called_once()
|
||||
|
||||
def test_top_n_words(self, mock_retriever):
    """Test extraction of top frequent words from text."""
    sample = "The quick brown fox jumps over the lazy dog. The fox is quick."

    # Defaults: at most three words, frequent words present.
    default_result = mock_retriever._top_n_words(sample)
    assert len(default_result.split(", ")) <= 3
    assert "fox" in default_result
    assert "quick" in default_result

    # Explicit top_n limits the number of returned words.
    assert len(mock_retriever._top_n_words(sample, top_n=2).split(", ")) <= 2

    # Custom separator is used between words.
    assert " | " in mock_retriever._top_n_words(sample, separator=" | ")

    # Caller-supplied stop words are filtered out.
    filtered = mock_retriever._top_n_words(sample, stop_words={"fox", "quick"})
    assert "fox" not in filtered
    assert "quick" not in filtered
|
||||
def test_get_title(self, mock_retriever):
    """Test title generation from text."""
    sample = (
        "This is a long paragraph about various topics that should generate a title. "
        "The main topics are AI, programming and data science."
    )

    # Default form: "<leading words>... [<frequent words>]".
    default_title = mock_retriever._get_title(sample)
    assert "..." in default_title
    assert "[" in default_title and "]" in default_title

    # first_n_words controls how many leading words appear before the ellipsis.
    short_title = mock_retriever._get_title(sample, first_n_words=3)
    assert len(short_title.split("...")[0].strip().split()) == 3

    # top_n_words bounds the bracketed frequent-word list.
    limited_title = mock_retriever._get_title(sample, top_n_words=2)
    bracketed = limited_title.split("[")[1].split("]")[0]
    assert len(bracketed.split(", ")) <= 2
|
||||
|
||||
def test_get_nodes(self, mock_retriever):
    """Test node processing and deduplication."""
    text_node = AsyncMock(id="text_node", attributes={"text": "This is a text node"})
    named_node = AsyncMock(id="name_node", attributes={"name": "Named Node"})
    bare_node = AsyncMock(id="empty_node", attributes={})

    # Each node appears in two edges; the result must still hold it once.
    sample_edges = [
        AsyncMock(
            node1=text_node, node2=named_node, attributes={"relationship_type": "rel1"}
        ),
        AsyncMock(
            node1=text_node,
            node2=bare_node,
            attributes={"relationship_type": "rel2"},
        ),
        AsyncMock(
            node1=named_node,
            node2=bare_node,
            attributes={"relationship_type": "rel3"},
        ),
    ]

    with patch.object(mock_retriever, "_get_title", return_value="Generated Title"):
        node_map = mock_retriever._get_nodes(sample_edges)

    assert len(node_map) == 3

    # Every entry carries the full node/name/content triple.
    for entry in node_map.values():
        assert "node" in entry
        assert "name" in entry
        assert "content" in entry

    # Text-bearing node gets a generated title and keeps its text as content.
    assert node_map[text_node.id]["name"] == "Generated Title"
    assert node_map[text_node.id]["content"] == "This is a text node"

    # Named node uses its name for both fields.
    assert node_map[named_node.id]["name"] == "Named Node"
    assert node_map[named_node.id]["content"] == "Named Node"

    # Attribute-less node falls back to the placeholder name.
    assert node_map[bare_node.id]["name"] == "Unnamed Node"
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue