# Conflicts: # cognee/api/v1/add/add.py # cognee/api/v1/search/search.py # cognee/infrastructure/databases/graph/graph_db_interface.py # cognee/infrastructure/engine/models/DataPoint.py # cognee/modules/retrieval/graph_completion_retriever.py # cognee/modules/search/methods/search.py # cognee/modules/visualization/cognee_network_visualization.py # cognee/tasks/documents/classify_documents.py # cognee/tasks/ingestion/ingest_data.py # examples/python/simple_node_set_example.py
132 lines
5.5 KiB
Python
132 lines
5.5 KiB
Python
from typing import Any, Optional, Type, List
|
|
from collections import Counter
|
|
import string
|
|
|
|
from cognee.infrastructure.engine import DataPoint
|
|
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
|
|
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
|
|
from cognee.modules.retrieval.base_retriever import BaseRetriever
|
|
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
|
|
from cognee.modules.retrieval.utils.completion import generate_completion
|
|
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
|
|
|
|
|
|
class GraphCompletionRetriever(BaseRetriever):
    """Retriever for handling graph-based completion searches.

    Fetches graph triplets (edges together with their two endpoint nodes) that
    are relevant to a query, renders them as a plain-text context, and passes
    that context to ``generate_completion`` along with the configured prompt
    templates.
    """

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        top_k: Optional[int] = 5,
        node_type: Optional[Type] = None,
        node_name: Optional[List[str]] = None,
    ):
        """Initialize retriever with prompt paths and search parameters.

        Args:
            user_prompt_path: Template file passed to ``generate_completion``
                as ``user_prompt_path``.
            system_prompt_path: Template file passed to ``generate_completion``
                as ``system_prompt_path``.
            top_k: Passed to ``brute_force_triplet_search`` as ``top_k``;
                ``None`` falls back to 5.
            node_type: Forwarded unchanged to ``brute_force_triplet_search``
                (presumably restricts the search to that node type — confirm
                in the search helper).
            node_name: Forwarded unchanged to ``brute_force_triplet_search``
                (presumably a list of node names to restrict the search to —
                confirm in the search helper).
        """
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.top_k = top_k if top_k is not None else 5
        self.node_type = node_type
        self.node_name = node_name

    def _get_nodes(self, retrieved_edges: list) -> dict:
        """Map every node id seen in ``retrieved_edges`` to its display data.

        Each value is ``{"node": node, "name": name, "content": content}``.
        Nodes with a truthy ``"text"`` attribute get a generated title (see
        ``_get_title``) as their name and the full text as content; other
        nodes use their ``"name"`` attribute (``"Unnamed Node"`` if absent)
        for both. The first occurrence of a node id wins.
        """
        nodes = {}
        for edge in retrieved_edges:
            for node in (edge.node1, edge.node2):
                if node.id in nodes:
                    continue
                text = node.attributes.get("text")
                if text:
                    name = self._get_title(text)
                    content = text
                else:
                    name = node.attributes.get("name", "Unnamed Node")
                    content = name
                nodes[node.id] = {"node": node, "name": name, "content": content}
        return nodes

    async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
        """Convert retrieved graph edges into a human-readable string.

        The output has a ``Nodes:`` section (each node's name followed by its
        content wrapped in ``__node_content_start__`` / ``__node_content_end__``
        markers) and a ``Connections:`` section with one
        ``source --[relationship_type]--> target`` line per edge.

        Raises:
            KeyError: if an edge's ``attributes`` lacks ``"relationship_type"``.
        """
        nodes = self._get_nodes(retrieved_edges)

        node_section = "\n".join(
            f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
            for info in nodes.values()
        )

        connection_section = "\n".join(
            f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
            for edge in retrieved_edges
        )

        return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"

    async def get_triplets(self, query: str) -> list:
        """Retrieve graph triplets relevant to ``query``.

        The vector collections to search are derived from every ``DataPoint``
        subclass whose ``metadata`` field default is a dict: each entry of its
        ``"index_fields"`` list yields a collection named
        ``"<ClassName>_<field_name>"``. When no collection is found, ``None``
        is passed instead (presumably letting the search use its own default
        collections — confirm in ``brute_force_triplet_search``).

        Returns:
            The triplets found, or an empty list when the search returns
            nothing (an empty result or ``None``).
        """
        vector_index_collections = []
        for subclass in get_all_subclasses(DataPoint):
            metadata_field = subclass.model_fields.get("metadata")
            # isinstance() also rules out a missing field or a None default.
            default = getattr(metadata_field, "default", None)
            if isinstance(default, dict):
                for field_name in default.get("index_fields", []):
                    vector_index_collections.append(f"{subclass.__name__}_{field_name}")

        found_triplets = await brute_force_triplet_search(
            query,
            top_k=self.top_k,
            collections=vector_index_collections or None,
            node_type=self.node_type,
            node_name=self.node_name,
        )

        # TODO: quick and dirty solution for the SF demo, as the triplets can be
        # empty; normalize any empty/None result to an empty list.
        return found_triplets or []

    async def get_context(self, query: str) -> str:
        """Retrieve graph triplets for ``query`` and render them as text.

        Returns an empty string when ``get_triplets`` raises
        ``EntityNotFoundError`` or yields no triplets.
        """
        try:
            triplets = await self.get_triplets(query)
        except EntityNotFoundError:
            return ""

        if not triplets:
            return ""

        return await self.resolve_edges_to_text(triplets)

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        """Generate a completion for ``query`` using graph-derived context.

        Args:
            query: The user question.
            context: Precomputed context; when ``None`` it is built with
                ``get_context(query)``.

        Returns:
            A single-element list containing the result of
            ``generate_completion``.
        """
        if context is None:
            context = await self.get_context(query)

        completion = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )
        return [completion]

    def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
        """Join the ``top_n`` most frequent words of ``text`` with ``separator``.

        Words are lower-cased and stripped of surrounding ASCII punctuation.
        Tokens that are empty after stripping are always discarded. Words in
        ``stop_words`` are excluded; ``None`` means ``DEFAULT_STOP_WORDS``,
        and any other falsy value disables stop-word filtering. Ties keep
        first-occurrence order (``Counter.most_common``).
        """
        if stop_words is None:
            stop_words = DEFAULT_STOP_WORDS
        excluded = stop_words or ()

        words = (word.lower().strip(string.punctuation) for word in text.split())
        # Empty tokens (from punctuation-only words such as "--") are dropped
        # unconditionally; previously they slipped through whenever stop_words
        # was empty, so "" could be reported as a frequent word.
        counts = Counter(word for word in words if word and word not in excluded)
        return separator.join(word for word, _ in counts.most_common(top_n))

    def _get_title(self, text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
        """Create a title from the first words and the most frequent words.

        Format: ``"<first first_n_words words>... [<top_n_words frequent words>]"``.
        """
        leading = " ".join(text.split()[:first_n_words])
        frequent = self._top_n_words(text, top_n=top_n_words)
        return f"{leading}... [{frequent}]"
|