<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
116 lines
4.9 KiB
Python
116 lines
4.9 KiB
Python
from typing import Any, Optional
|
|
from collections import Counter
|
|
import string
|
|
|
|
from cognee.infrastructure.engine import DataPoint
|
|
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
|
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
|
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
|
|
|
|
|
class GraphCompletionRetriever(BaseRetriever):
    """Retriever for handling graph-based completion searches.

    Looks up graph triplets (edges with their two endpoint nodes) for a
    query, renders them as a plain-text context block, and passes that
    context to ``generate_completion`` together with the configured prompts.
    """

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        top_k: Optional[int] = 5,
    ):
        """Initialize retriever with prompt paths and search parameters.

        Args:
            user_prompt_path: Prompt template name forwarded to
                ``generate_completion`` as ``user_prompt_path``.
            system_prompt_path: Prompt template name forwarded to
                ``generate_completion`` as ``system_prompt_path``.
            top_k: Value forwarded as ``top_k`` to the triplet search.
                ``None`` falls back to 5.
        """
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.top_k = top_k if top_k is not None else 5

    def _get_nodes(self, retrieved_edges: list) -> dict:
        """Creates a dictionary of nodes with their names and content.

        Each edge is expected to expose ``node1`` and ``node2``, and each node
        an ``id`` and an ``attributes`` mapping. Nodes are keyed by ``id`` so a
        node shared by several edges is only described once.

        Returns:
            Mapping of node id to ``{"node", "name", "content"}``.
        """
        nodes = {}
        for edge in retrieved_edges:
            for node in (edge.node1, edge.node2):
                if node.id in nodes:
                    continue

                text = node.attributes.get("text")
                if text:
                    # Text-bearing nodes get a generated title; the full text
                    # is kept as the content.
                    name = self._get_title(text)
                    content = text
                else:
                    # `or` (rather than a .get default) also covers a "name"
                    # key that is present but None or empty.
                    name = node.attributes.get("name") or "Unnamed Node"
                    content = name

                nodes[node.id] = {"node": node, "name": name, "content": content}
        return nodes

    async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
        """Converts retrieved graph edges into a human-readable string format.

        The result has a ``Nodes:`` section (one delimited entry per distinct
        node) followed by a ``Connections:`` section (one line per edge).

        Raises:
            KeyError: If an edge's ``attributes`` has no ``relationship_type``.
        """
        nodes = self._get_nodes(retrieved_edges)

        node_section = "\n".join(
            f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
            for info in nodes.values()
        )
        connection_section = "\n".join(
            f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
            for edge in retrieved_edges
        )

        return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"

    async def get_triplets(self, query: str) -> list:
        """Retrieves relevant graph triplets.

        Builds collection names of the form ``<SubclassName>_<field>`` from the
        ``index_fields`` entry of each ``DataPoint`` subclass's default
        ``metadata`` dict, then delegates to ``brute_force_triplet_search``.
        """
        vector_index_collections = []

        for subclass in get_all_subclasses(DataPoint):
            metadata_field = subclass.model_fields.get("metadata")
            if metadata_field is None:
                continue

            # Only a dict default can carry "index_fields"; a missing or None
            # default is skipped.
            default_metadata = getattr(metadata_field, "default", None)
            if not isinstance(default_metadata, dict):
                continue

            vector_index_collections.extend(
                f"{subclass.__name__}_{field_name}"
                for field_name in default_metadata.get("index_fields", [])
            )

        # An empty list is passed on as None.
        return await brute_force_triplet_search(
            query, top_k=self.top_k, collections=vector_index_collections or None
        )

    async def get_context(self, query: str) -> str:
        """Retrieves and resolves graph triplets into context.

        Returns:
            The rendered context, or an empty string when no triplets are found.
        """
        triplets = await self.get_triplets(query)

        if not triplets:
            return ""

        return await self.resolve_edges_to_text(triplets)

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generates a completion using graph connections context.

        Args:
            query: The question to answer.
            context: Context to use as-is; when ``None`` it is built via
                ``get_context``.

        Returns:
            A single-element list containing the generated completion.
        """
        if context is None:
            context = await self.get_context(query)

        completion = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )
        return [completion]

    def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
        """Concatenates the top N frequent words in text.

        Words are whitespace-separated tokens, lower-cased and stripped of
        leading/trailing ASCII punctuation.

        Args:
            text: Text to analyse.
            stop_words: Words to ignore. ``None`` uses ``DEFAULT_STOP_WORDS``;
                an empty collection disables stop-word filtering.
            top_n: Maximum number of words to return.
            separator: String placed between the returned words.

        Returns:
            The most frequent words joined by ``separator``; ties keep
            first-seen order.
        """
        if stop_words is None:
            stop_words = DEFAULT_STOP_WORDS

        normalized = (token.lower().strip(string.punctuation) for token in text.split())
        # Punctuation-only tokens collapse to "" and are never real words, so
        # drop them regardless of whether stop-word filtering is enabled.
        words = [word for word in normalized if word]

        if stop_words:
            words = [word for word in words if word not in stop_words]

        top_words = [word for word, _ in Counter(words).most_common(top_n)]
        return separator.join(top_words)

    def _get_title(self, text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
        """Creates a title, by combining first words with most frequent words from the text.

        Args:
            text: Text to summarise.
            first_n_words: How many leading words to include.
            top_n_words: How many most-frequent words to append in brackets.

        Returns:
            A string of the form ``"<leading words>... [<frequent words>]"``.
        """
        leading_words = text.split()[:first_n_words]
        frequent_words = self._top_n_words(text, top_n=top_n_words)
        return f"{' '.join(leading_words)}... [{frequent_words}]"
|