feat: adds cognee-user interactions to search

This commit is contained in:
hajdul88 2025-08-18 13:14:06 +02:00
parent 91d0d38e43
commit 711c805c83
9 changed files with 217 additions and 18 deletions

View file

@ -19,6 +19,7 @@ async def search(
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
) -> list:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -107,6 +108,8 @@ async def search(
node_name: Filter results to specific named entities (for targeted search).
save_interaction: Whether to save the interaction (query, context, and answer linked to triplet endpoints) into the graph.
Returns:
list: Search results in format determined by query_type:
@ -189,6 +192,7 @@ async def search(
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
)
return filtered_search_results

View file

@ -29,6 +29,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -36,6 +37,7 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
)
async def get_completion(
@ -105,11 +107,16 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
round_idx += 1
answer = await generate_completion(
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
return [answer]
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context, triplets=triplets
)
return [completion]

View file

@ -35,6 +35,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
):
super().__init__(
user_prompt_path=user_prompt_path,
@ -42,6 +43,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
)
self.validation_system_prompt_path = validation_system_prompt_path
self.validation_user_prompt_path = validation_user_prompt_path
@ -75,7 +77,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
"""
followup_question = ""
triplets = []
answer = [""]
completion = [""]
for round_idx in range(max_iter + 1):
if round_idx == 0:
@ -85,15 +87,15 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
triplets += await self.get_triplets(followup_question)
context = await self.resolve_edges_to_text(list(set(triplets)))
answer = await generate_completion(
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
logger.info(f"Chain-of-thought: round {round_idx} - answer: {answer}")
logger.info(f"Chain-of-thought: round {round_idx} - answer: {completion}")
if round_idx < max_iter:
valid_args = {"query": query, "answer": answer, "context": context}
valid_args = {"query": query, "answer": completion, "context": context}
valid_user_prompt = LLMGateway.render_prompt(
filename=self.validation_user_prompt_path, context=valid_args
)
@ -106,7 +108,7 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
system_prompt=valid_system_prompt,
response_model=str,
)
followup_args = {"query": query, "answer": answer, "reasoning": reasoning}
followup_args = {"query": query, "answer": completion, "reasoning": reasoning}
followup_prompt = LLMGateway.render_prompt(
filename=self.followup_user_prompt_path, context=followup_args
)
@ -121,4 +123,9 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
f"Chain-of-thought: round {round_idx} - follow-up question: {followup_question}"
)
return [answer]
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context, triplets=triplets
)
return [completion]

View file

@ -1,14 +1,20 @@
from typing import Any, Optional, Type, List
from typing import Any, Optional, Type, List, Coroutine
from collections import Counter
from uuid import NAMESPACE_OID, uuid5
import string
from cognee.infrastructure.engine import DataPoint
from cognee.tasks.storage import add_data_points
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.shared.logging_utils import get_logger
from cognee.modules.retrieval.utils.extract_uuid_from_node import extract_uuid_from_node
from cognee.modules.retrieval.utils.models import CogneeUserInteraction
from cognee.modules.engine.models.node_set import NodeSet
from cognee.infrastructure.databases.graph import get_graph_engine
logger = get_logger("GraphCompletionRetriever")
@ -33,8 +39,10 @@ class GraphCompletionRetriever(BaseRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
):
"""Initialize retriever with prompt paths and search parameters."""
self.save_interaction = save_interaction
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 5
@ -118,7 +126,7 @@ class GraphCompletionRetriever(BaseRetriever):
return found_triplets
async def get_context(self, query: str) -> str:
async def get_context(self, query: str) -> str | tuple[str, list]:
"""
Retrieves and resolves graph triplets into context based on a query.
@ -139,7 +147,9 @@ class GraphCompletionRetriever(BaseRetriever):
logger.warning("Empty context was provided to the completion")
return ""
return await self.resolve_edges_to_text(triplets)
context = await self.resolve_edges_to_text(triplets)
return context, triplets
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
"""
@ -157,8 +167,10 @@ class GraphCompletionRetriever(BaseRetriever):
- Any: A generated completion based on the query and context provided.
"""
triplets = None
if context is None:
context = await self.get_context(query)
context, triplets = await self.get_context(query)
completion = await generate_completion(
query=query,
@ -166,6 +178,12 @@ class GraphCompletionRetriever(BaseRetriever):
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
if self.save_interaction and context and triplets and completion:
await self.save_qa(
question=query, answer=completion, context=context, triplets=triplets
)
return [completion]
def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
@ -187,3 +205,67 @@ class GraphCompletionRetriever(BaseRetriever):
first_n_words = text.split()[:first_n_words]
top_n_words = self._top_n_words(text, top_n=top_n_words)
return f"{' '.join(first_n_words)}... [{top_n_words}]"
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
    """
    Persist a question/answer interaction into the graph.

    Creates a `CogneeUserInteraction` node (grouped under the "Interactions"
    node set) and links it to both endpoints of every retrieved triplet via
    `used_graph_element_to_answer` edges.

    Parameters:
    -----------

        - question (str): The question text.
        - answer (str): The answer text.
        - context (str): The context text.
        - triplets (List): A list of triplets retrieved from the graph.
    """
    nodeset_name = "Interactions"
    interactions_node_set = NodeSet(
        id=uuid5(NAMESPACE_OID, name=nodeset_name), name=nodeset_name
    )

    # Deterministic id: the same (question, answer, context) triple always maps
    # to the same interaction node, so repeated saves do not duplicate nodes.
    source_id = uuid5(NAMESPACE_OID, name=(question + answer + context))

    cognee_user_interaction = CogneeUserInteraction(
        id=source_id,
        question=question,
        answer=answer,
        context=context,
        belongs_to_set=interactions_node_set,
    )

    # Skip the edge-index refresh here; the edges are added explicitly below.
    await add_data_points(data_points=[cognee_user_interaction], update_edge_collection=False)

    relationships = []
    relationship_name = "used_graph_element_to_answer"

    for triplet in triplets:
        target_ids = (
            extract_uuid_from_node(triplet.node1),
            extract_uuid_from_node(triplet.node2),
        )
        # Only link when BOTH endpoints resolved to a UUID, matching the
        # original all-or-nothing semantics for a triplet.
        if all(target_ids):
            for target_id in target_ids:
                relationships.append(
                    (
                        source_id,
                        target_id,
                        relationship_name,
                        {
                            "relationship_name": relationship_name,
                            "source_node_id": source_id,
                            "target_node_id": target_id,
                            "ontology_valid": False,
                        },
                    )
                )

    if relationships:
        graph_engine = await get_graph_engine()
        await graph_engine.add_edges(relationships)

View file

@ -24,6 +24,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
):
"""Initialize retriever with default prompt paths and search parameters."""
super().__init__(
@ -32,6 +33,7 @@ class GraphSummaryCompletionRetriever(GraphCompletionRetriever):
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
)
self.summarize_prompt_path = summarize_prompt_path

View file

@ -0,0 +1,18 @@
from typing import Any, Optional
from uuid import UUID
def extract_uuid_from_node(node: Any) -> Optional[UUID]:
    """
    Extract a UUID from ``node.id`` or ``node.attributes['id']``.

    Checks ``node.id`` first; when that is missing or falsy, falls back to the
    node's ``attributes`` mapping. A value that is already a ``UUID`` instance
    is returned unchanged.

    Parameters:
    -----------

        - node (Any): Graph node object that may expose an ``id`` attribute
          and/or an ``attributes`` dict.

    Returns:
    --------

        - Optional[UUID]: The node's UUID, or None when no usable identifier
          is found.

    Raises:
    -------

        - ValueError: If the identifier string is not a valid UUID.
    """
    raw_id = getattr(node, "id", None)
    if not raw_id and hasattr(node, "attributes"):
        raw_id = node.attributes.get("id", None)

    if isinstance(raw_id, UUID):
        # Already a UUID instance — pass it through without re-parsing.
        return raw_id
    return UUID(raw_id) if isinstance(raw_id, str) else None

View file

@ -0,0 +1,36 @@
from typing import Optional
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.modules.engine.models.node_set import NodeSet
from enum import Enum
from pydantic import BaseModel, ValidationError
class CogneeUserInteraction(DataPoint):
    """A single user-Cognee interaction: a question, its answer, and the context used."""

    # The user's query text.
    question: str
    # The generated answer returned to the user.
    answer: str
    # The textual context (resolved graph triplets) used to produce the answer.
    context: str
    # Optional grouping node set this interaction belongs to
    # (presumably the "Interactions" set — verify against callers).
    belongs_to_set: Optional[NodeSet] = None
class CogneeUserFeedback(DataPoint):
    """Feedback a user gives about a Cognee interaction."""

    # Free-text feedback provided by the user.
    feedback: str
    # Sentiment label for the feedback; UserFeedbackSentiment lists the known values.
    sentiment: str
    # Optional grouping node set this feedback record belongs to.
    belongs_to_set: Optional[NodeSet] = None
class UserFeedbackSentiment(str, Enum):
    """Closed set of sentiment labels for user feedback."""

    # str mixin so members compare and serialize as plain strings.
    positive = "positive"
    negative = "negative"
    neutral = "neutral"
class UserFeedbackEvaluation(BaseModel):
    """Structured result of evaluating user feedback into a sentiment."""

    # Sentiment classification of the feedback
    # (presumably used as an LLM structured-response model — verify against callers).
    evaluation: UserFeedbackSentiment

View file

@ -39,6 +39,7 @@ async def search(
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
):
"""
@ -58,7 +59,7 @@ async def search(
# Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return await authorized_search(
query_text, query_type, user, dataset_ids, system_prompt_path, top_k
query_text, query_type, user, dataset_ids, system_prompt_path, top_k, save_interaction
)
query = await log_query(query_text, query_type.value, user.id)
@ -71,6 +72,7 @@ async def search(
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
)
await log_result(
@ -92,6 +94,7 @@ async def specific_search(
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
) -> list:
search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
@ -105,24 +108,28 @@ async def specific_search(
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
).get_completion,
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
).get_completion,
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
).get_completion,
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion,
@ -154,6 +161,7 @@ async def authorized_search(
dataset_ids: Optional[list[UUID]] = None,
system_prompt_path: str = "answer_simple_question.txt",
top_k: int = 10,
save_interaction: bool = False,
) -> list:
"""
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
@ -167,7 +175,7 @@ async def authorized_search(
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await specific_search_by_context(
search_datasets, query_text, query_type, user, system_prompt_path, top_k
search_datasets, query_text, query_type, user, system_prompt_path, top_k, save_interaction
)
await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id)
@ -182,6 +190,7 @@ async def specific_search_by_context(
user: User,
system_prompt_path: str,
top_k: int,
save_interaction: bool = False,
):
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
@ -192,7 +201,12 @@ async def specific_search_by_context(
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
search_results = await specific_search(
query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k
query_type,
query_text,
user,
system_prompt_path=system_prompt_path,
top_k=top_k,
save_interaction=save_interaction,
)
return {
"search_result": search_results,

View file

@ -7,7 +7,36 @@ from .index_data_points import index_data_points
from .index_graph_edges import index_graph_edges
async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
async def add_data_points(
data_points: List[DataPoint], update_edge_collection=True
) -> List[DataPoint]:
"""
Add a batch of data points to the graph database by extracting nodes and edges,
deduplicating them, and indexing them for retrieval.
This function parallelizes the graph extraction for each data point,
merges the resulting nodes and edges, and ensures uniqueness before
committing them to the underlying graph engine. It also updates the
associated retrieval indices for nodes and (optionally) edges.
Args:
data_points (List[DataPoint]):
A list of data points to process and insert into the graph.
update_edge_collection (bool, optional):
Whether to update the edge index after adding edges.
Defaults to True.
Returns:
List[DataPoint]:
The original list of data points after processing and insertion.
Side Effects:
- Calls `get_graph_from_model` concurrently for each data point.
- Deduplicates nodes and edges across all results.
- Updates the node index via `index_data_points`.
- Inserts nodes and edges into the graph engine.
- Optionally updates the edge index via `index_graph_edges`.
"""
nodes = []
edges = []
@ -40,7 +69,7 @@ async def add_data_points(data_points: List[DataPoint]) -> List[DataPoint]:
await graph_engine.add_nodes(nodes)
await graph_engine.add_edges(edges)
# This step has to happen after adding nodes and edges because we query the graph.
await index_graph_edges()
if update_edge_collection:
await index_graph_edges()
return data_points