diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py
index 80b8855d1..66002825f 100644
--- a/cognee/modules/retrieval/graph_completion_retriever.py
+++ b/cognee/modules/retrieval/graph_completion_retriever.py
@@ -1,10 +1,13 @@
 from typing import Any, Optional
+from collections import Counter
+import string
 
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
 from cognee.modules.retrieval.base_retriever import BaseRetriever
 from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
 from cognee.modules.retrieval.utils.completion import generate_completion
+from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
 from cognee.tasks.completion.exceptions import NoRelevantDataFound
 
 
@@ -22,16 +25,34 @@ class GraphCompletionRetriever(BaseRetriever):
         self.system_prompt_path = system_prompt_path
         self.top_k = top_k if top_k is not None else 5
 
+    def _get_nodes(self, retrieved_edges: list) -> dict:
+        """Creates a dictionary of nodes with their names and content."""
+        nodes = {}
+        for edge in retrieved_edges:
+            for node in (edge.node1, edge.node2):
+                if node.id not in nodes:
+                    text = node.attributes.get("text")
+                    if text:
+                        name = self._get_title(text)
+                        content = text
+                    else:
+                        name = node.attributes.get("name", "Unnamed Node")
+                        content = name
+                    nodes[node.id] = {"node": node, "name": name, "content": content}
+        return nodes
+
     async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
         """Converts retrieved graph edges into a human-readable string format."""
-        edge_strings = []
-        for edge in retrieved_edges:
-            node1_string = edge.node1.attributes.get("text") or edge.node1.attributes.get("name")
-            node2_string = edge.node2.attributes.get("text") or edge.node2.attributes.get("name")
-            edge_string = edge.attributes["relationship_type"]
-            edge_str = f"{node1_string} -- {edge_string} -- {node2_string}"
-            edge_strings.append(edge_str)
-        return "\n---\n".join(edge_strings)
+        nodes = self._get_nodes(retrieved_edges)
+        node_section = "\n".join(
+            f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
+            for info in nodes.values()
+        )
+        connection_section = "\n".join(
+            f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
+            for edge in retrieved_edges
+        )
+        return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
 
     async def get_triplets(self, query: str) -> list:
         """Retrieves relevant graph triplets."""
@@ -69,3 +90,23 @@
             system_prompt_path=self.system_prompt_path,
         )
         return [completion]
+
+    def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
+        """Concatenates the top N frequent words in text."""
+        if stop_words is None:
+            stop_words = DEFAULT_STOP_WORDS
+
+        words = [word.lower().strip(string.punctuation) for word in text.split()]
+
+        if stop_words:
+            words = [word for word in words if word and word not in stop_words]
+
+        top_words = [word for word, freq in Counter(words).most_common(top_n)]
+
+        return separator.join(top_words)
+
+    def _get_title(self, text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
+        """Creates a title, by combining first words with most frequent words from the text."""
+        first_n_words = text.split()[:first_n_words]
+        top_n_words = self._top_n_words(text, top_n=top_n_words)
+        return f"{' '.join(first_n_words)}... [{top_n_words}]"
diff --git a/cognee/modules/retrieval/utils/stop_words.py b/cognee/modules/retrieval/utils/stop_words.py
new file mode 100644
index 000000000..3f881e39d
--- /dev/null
+++ b/cognee/modules/retrieval/utils/stop_words.py
@@ -0,0 +1,71 @@
+"""Common stop words for text processing."""
+
+# Common English stop words to filter out in text processing
+DEFAULT_STOP_WORDS = {
+    "a",
+    "an",
+    "the",
+    "and",
+    "or",
+    "but",
+    "is",
+    "are",
+    "was",
+    "were",
+    "in",
+    "on",
+    "at",
+    "to",
+    "for",
+    "with",
+    "by",
+    "about",
+    "of",
+    "from",
+    "as",
+    "that",
+    "this",
+    "these",
+    "those",
+    "it",
+    "its",
+    "them",
+    "they",
+    "their",
+    "he",
+    "she",
+    "his",
+    "her",
+    "him",
+    "we",
+    "our",
+    "you",
+    "your",
+    "not",
+    "be",
+    "been",
+    "being",
+    "have",
+    "has",
+    "had",
+    "do",
+    "does",
+    "did",
+    "can",
+    "could",
+    "will",
+    "would",
+    "shall",
+    "should",
+    "may",
+    "might",
+    "must",
+    "when",
+    "where",
+    "which",
+    "who",
+    "whom",
+    "whose",
+    "why",
+    "how",
+}
diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py
index acf6f4ece..7befa8243 100644
--- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py
+++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py
@@ -40,23 +40,37 @@ class TestGraphCompletionRetriever:
 
     @pytest.mark.asyncio
     async def test_resolve_edges_to_text(self, mock_retriever):
+        node_a = AsyncMock(id="node_a_id", attributes={"text": "Node A text content"})
+        node_b = AsyncMock(id="node_b_id", attributes={"text": "Node B text content"})
+        node_c = AsyncMock(id="node_c_id", attributes={"name": "Node C"})
+
         triplets = [
             AsyncMock(
-                node1=AsyncMock(attributes={"text": "Node A"}),
+                node1=node_a,
                 attributes={"relationship_type": "connects"},
-                node2=AsyncMock(attributes={"text": "Node B"}),
+                node2=node_b,
             ),
             AsyncMock(
-                node1=AsyncMock(attributes={"text": "Node X"}),
+                node1=node_a,
                 attributes={"relationship_type": "links"},
-                node2=AsyncMock(attributes={"text": "Node Y"}),
+                node2=node_c,
             ),
         ]
 
-        result = await mock_retriever.resolve_edges_to_text(triplets)
+        with patch.object(mock_retriever, "_get_title", return_value="Test Title"):
+            result = await mock_retriever.resolve_edges_to_text(triplets)
 
-        expected_output = "Node A -- connects -- Node B\n---\nNode X -- links -- Node Y"
-        assert result == expected_output
+        assert "Nodes:" in result
+        assert "Connections:" in result
+
+        assert "Node: Test Title" in result
+        assert "__node_content_start__" in result
+        assert "Node A text content" in result
+        assert "__node_content_end__" in result
+        assert "Node: Node C" in result
+
+        assert "Test Title --[connects]--> Test Title" in result
+        assert "Test Title --[links]--> Node C" in result
 
     @pytest.mark.asyncio
     @patch(
@@ -124,16 +138,13 @@
         mock_get_llm_client,
        mock_retriever,
     ):
-        # Setup
         query = "test query with empty graph"
 
-        # Mock graph engine with empty graph
         mock_graph_engine = MagicMock()
         mock_graph_engine.get_graph_data = AsyncMock()
         mock_graph_engine.get_graph_data.return_value = ([], [])
         mock_get_graph_engine.return_value = mock_graph_engine
 
-        # Mock LLM client
         mock_llm_client = MagicMock()
         mock_llm_client.acreate_structured_output = AsyncMock()
         mock_llm_client.acreate_structured_output.return_value = (
@@ -141,9 +152,85 @@
         )
         mock_get_llm_client.return_value = mock_llm_client
 
-        # Execute
         with pytest.raises(EntityNotFoundError):
             await mock_retriever.get_completion(query)
 
-        # Verify graph engine was called
         mock_graph_engine.get_graph_data.assert_called_once()
+
+    def test_top_n_words(self, mock_retriever):
+        """Test extraction of top frequent words from text."""
+        text = "The quick brown fox jumps over the lazy dog. The fox is quick."
+
+        result = mock_retriever._top_n_words(text)
+        assert len(result.split(", ")) <= 3
+        assert "fox" in result
+        assert "quick" in result
+
+        result = mock_retriever._top_n_words(text, top_n=2)
+        assert len(result.split(", ")) <= 2
+
+        result = mock_retriever._top_n_words(text, separator=" | ")
+        assert " | " in result
+
+        result = mock_retriever._top_n_words(text, stop_words={"fox", "quick"})
+        assert "fox" not in result
+        assert "quick" not in result
+
+    def test_get_title(self, mock_retriever):
+        """Test title generation from text."""
+        text = "This is a long paragraph about various topics that should generate a title. The main topics are AI, programming and data science."
+
+        title = mock_retriever._get_title(text)
+        assert "..." in title
+        assert "[" in title and "]" in title
+
+        title = mock_retriever._get_title(text, first_n_words=3)
+        first_part = title.split("...")[0].strip()
+        assert len(first_part.split()) == 3
+
+        title = mock_retriever._get_title(text, top_n_words=2)
+        top_part = title.split("[")[1].split("]")[0]
+        assert len(top_part.split(", ")) <= 2
+
+    def test_get_nodes(self, mock_retriever):
+        """Test node processing and deduplication."""
+        node_with_text = AsyncMock(id="text_node", attributes={"text": "This is a text node"})
+        node_with_name = AsyncMock(id="name_node", attributes={"name": "Named Node"})
+        node_without_attrs = AsyncMock(id="empty_node", attributes={})
+
+        edges = [
+            AsyncMock(
+                node1=node_with_text, node2=node_with_name, attributes={"relationship_type": "rel1"}
+            ),
+            AsyncMock(
+                node1=node_with_text,
+                node2=node_without_attrs,
+                attributes={"relationship_type": "rel2"},
+            ),
+            AsyncMock(
+                node1=node_with_name,
+                node2=node_without_attrs,
+                attributes={"relationship_type": "rel3"},
+            ),
+        ]
+
+        with patch.object(mock_retriever, "_get_title", return_value="Generated Title"):
+            nodes = mock_retriever._get_nodes(edges)
+
+        assert len(nodes) == 3
+
+        for node_id, info in nodes.items():
+            assert "node" in info
+            assert "name" in info
+            assert "content" in info
+
+        text_node_info = nodes[node_with_text.id]
+        assert text_node_info["name"] == "Generated Title"
+        assert text_node_info["content"] == "This is a text node"
+
+        name_node_info = nodes[node_with_name.id]
+        assert name_node_info["name"] == "Named Node"
+        assert name_node_info["content"] == "Named Node"
+
+        empty_node_info = nodes[node_without_attrs.id]
+        assert empty_node_info["name"] == "Unnamed Node"
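
Reviewer note: below is a rough, self-contained sketch of the output shape the new formatting produces. It mirrors _top_n_words / _get_title as module-level functions and uses SimpleNamespace stand-ins for graph nodes and edges; the stand-ins, the trimmed stop-word set, and the sample text are illustrative only and not part of the patch.

# Illustrative only: mirrors the patch's title + "Nodes:/Connections:" formatting
# outside of cognee so the output shape can be inspected without a graph store.
import string
from collections import Counter
from types import SimpleNamespace

STOP_WORDS = {"a", "an", "the", "and", "or", "is", "are", "of", "to", "in"}  # trimmed, illustrative set


def top_n_words(text, stop_words=STOP_WORDS, top_n=3, separator=", "):
    # Lowercase, strip punctuation, drop stop words, keep the N most frequent words.
    words = [word.lower().strip(string.punctuation) for word in text.split()]
    words = [word for word in words if word and word not in stop_words]
    return separator.join(word for word, _ in Counter(words).most_common(top_n))


def get_title(text, first_n_words=7, top_n=3):
    # First few words of the text, followed by its most frequent words in brackets.
    head = " ".join(text.split()[:first_n_words])
    return f"{head}... [{top_n_words(text, top_n=top_n)}]"


# Two fake nodes and one fake edge, shaped like the objects the retriever receives.
node_a = SimpleNamespace(
    id="a", attributes={"text": "Graphs store entities and the relationships between entities."}
)
node_b = SimpleNamespace(id="b", attributes={"name": "Entity"})
edge = SimpleNamespace(node1=node_a, node2=node_b, attributes={"relationship_type": "mentions"})

nodes = {}
for node in (edge.node1, edge.node2):
    text = node.attributes.get("text")
    name = get_title(text) if text else node.attributes.get("name", "Unnamed Node")
    nodes[node.id] = {"name": name, "content": text or name}

node_section = "\n".join(
    f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
    for info in nodes.values()
)
connection_section = (
    f"{nodes[edge.node1.id]['name']} "
    f"--[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
)
print(f"Nodes:\n{node_section}\nConnections:\n{connection_section}")

Inside cognee these live as GraphCompletionRetriever._top_n_words / _get_title / _get_nodes and feed resolve_edges_to_text, with DEFAULT_STOP_WORDS from the new stop_words module as the default filter.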