cognee/cognee/modules/retrieval/temporal_retriever.py

import os
import asyncio
from typing import Any, Optional, List, Type
from datetime import datetime
from operator import itemgetter

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
from cognee.modules.retrieval.utils.session_cache import (
    save_conversation_history,
    get_conversation_history,
)
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.llm.prompts import render_prompt
from cognee.infrastructure.llm import LLMGateway
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.shared.logging_utils import get_logger
from cognee.context_global_variables import session_user
from cognee.infrastructure.databases.cache.config import CacheConfig
from cognee.tasks.temporal_graph.models import QueryInterval

logger = get_logger()

class TemporalRetriever(GraphCompletionRetriever):
    """
    Retriever for time-aware graph completion. This class extends
    GraphCompletionRetriever: it extracts a time interval from the user query, filters
    graph events by timestamp, ranks them against the query with a vector search, and
    generates an answer from the resulting context. When no usable interval or matching
    events are found, it falls back to plain triplet search on events and entities.
    The public methods are:

    - get_context
    - get_completion

    Instance variables include:

    - user_prompt_path
    - system_prompt_path
    - time_extraction_prompt_path
    - top_k
    - node_type
    - node_name
    """

    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        time_extraction_prompt_path: str = "extract_query_time.txt",
        top_k: Optional[int] = 5,
        node_type: Optional[Type] = None,
        node_name: Optional[List[str]] = None,
        wide_search_top_k: Optional[int] = 100,
        triplet_distance_penalty: Optional[float] = 3.5,
    ):
        super().__init__(
            user_prompt_path=user_prompt_path,
            system_prompt_path=system_prompt_path,
            top_k=top_k,
            node_type=node_type,
            node_name=node_name,
            wide_search_top_k=wide_search_top_k,
            triplet_distance_penalty=triplet_distance_penalty,
        )
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.time_extraction_prompt_path = time_extraction_prompt_path
        self.top_k = top_k if top_k is not None else 5
        self.node_type = node_type
        self.node_name = node_name

    def descriptions_to_string(self, results) -> str:
        """Joins non-empty event descriptions into a single separator-delimited string."""
        descriptions = []
        for entry in results:
            description = entry.get("description")
            if description:
                descriptions.append(description.strip())
        return "\n#####################\n".join(descriptions)

    async def extract_time_from_query(self, query: str):
        """Uses the LLM to extract an optional (time_from, time_to) interval from the query."""
        prompt_path = self.time_extraction_prompt_path

        # Absolute prompt paths are split so render_prompt can resolve them.
        if os.path.isabs(prompt_path):
            base_directory = os.path.dirname(prompt_path)
            prompt_path = os.path.basename(prompt_path)
        else:
            base_directory = None

        time_now = datetime.now().strftime("%d-%m-%Y")
        system_prompt = render_prompt(
            prompt_path, {"time_now": time_now}, base_directory=base_directory
        )

        interval = await LLMGateway.acreate_structured_output(query, system_prompt, QueryInterval)

        time_from = interval.starts_at
        time_to = interval.ends_at
        return time_from, time_to
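
    # QueryInterval is imported from cognee.tasks.temporal_graph.models; based on the
    # attributes used above, it is assumed to look roughly like this (hypothetical
    # sketch, not the actual definition):
    #
    #     class QueryInterval(BaseModel):
    #         starts_at: Optional[datetime]
    #         ends_at: Optional[datetime]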

    async def filter_top_k_events(self, relevant_events, scored_results):
        """Ranks time-filtered events by their vector-search score and keeps the top_k."""
        # Build a score lookup from the vector search results.
        score_lookup = {result.id: result.score for result in scored_results}

        # Scores are distances, so lower is better; events without a score sort last.
        events_with_scores = []
        for event in relevant_events[0]["events"]:
            score = score_lookup.get(event["id"], float("inf"))
            events_with_scores.append({**event, "score": score})

        events_with_scores.sort(key=itemgetter("score"))
        return events_with_scores[: self.top_k]
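
    # Worked example (hypothetical data): given events with ids {1, 2} and vector
    # distances {1: 0.8, 2: 0.2}, the event with id 2 is returned first, since a
    # smaller distance means a closer match; an event absent from the vector results
    # gets float("inf") and sinks to the bottom of the ranking.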

    async def get_context(self, query: str) -> Any:
        """Retrieves context based on the query."""
        time_from, time_to = await self.extract_time_from_query(query)

        graph_engine = await get_graph_engine()

        # Collect the ids of graph nodes that fall inside the extracted interval.
        if time_from and time_to:
            ids = await graph_engine.collect_time_ids(time_from=time_from, time_to=time_to)
        elif time_from:
            ids = await graph_engine.collect_time_ids(time_from=time_from)
        elif time_to:
            ids = await graph_engine.collect_time_ids(time_to=time_to)
        else:
            logger.info(
                "No timestamps identified based on the query, performing retrieval using triplet search on events and entities."
            )
            triplets = await self.get_triplets(query)
            return await self.resolve_edges_to_text(triplets)

        if ids:
            relevant_events = await graph_engine.collect_events(ids=ids)
        else:
            logger.info(
                "No events identified based on timestamp filtering, performing retrieval using triplet search on events and entities."
            )
            triplets = await self.get_triplets(query)
            return await self.resolve_edges_to_text(triplets)

        # Rank the time-filtered events against the query with a vector search.
        vector_engine = get_vector_engine()
        query_vector = (await vector_engine.embedding_engine.embed_text([query]))[0]
        vector_search_results = await vector_engine.search(
            collection_name="Event_name", query_vector=query_vector, limit=None
        )

        top_k_events = await self.filter_top_k_events(relevant_events, vector_search_results)
        return self.descriptions_to_string(top_k_events)
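
    # Illustrative flow (hypothetical query): for "What did the team ship last March?"
    # extract_time_from_query would yield an interval covering March; collect_time_ids
    # and collect_events narrow the graph to that window; the vector search over the
    # "Event_name" collection ranks those events against the query; and the top_k
    # event descriptions are joined into the returned context string.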

    async def get_completion(
        self,
        query: str,
        context: Optional[str] = None,
        session_id: Optional[str] = None,
        response_model: Type = str,
    ) -> List[Any]:
        """
        Generates a response using the query and optional context.

        Parameters:
        -----------

        - query (str): The query string for which a completion is generated.
        - context (Optional[str]): Optional context to use; if None, it will be
          retrieved based on the query. (default None)
        - session_id (Optional[str]): Optional session identifier for caching. If None,
          defaults to 'default_session'. (default None)
        - response_model (Type): The Pydantic model type for structured output. (default str)

        Returns:
        --------

        - List[Any]: A list containing the generated completion, or an empty list if no
          context could be retrieved.
        """
        if not context:
            context = await self.get_context(query=query)

        # Guard against an empty context: without it, no completion can be generated.
        if not context:
            return []

        # Persist the conversation only when a user is present and caching is enabled.
        cache_config = CacheConfig()
        user = session_user.get()
        user_id = getattr(user, "id", None)
        session_save = user_id and cache_config.caching

        if session_save:
            conversation_history = await get_conversation_history(session_id=session_id)

            # Summarize the context for the cache and generate the answer concurrently.
            context_summary, completion = await asyncio.gather(
                summarize_text(context),
                generate_completion(
                    query=query,
                    context=context,
                    user_prompt_path=self.user_prompt_path,
                    system_prompt_path=self.system_prompt_path,
                    conversation_history=conversation_history,
                    response_model=response_model,
                ),
            )
        else:
            completion = await generate_completion(
                query=query,
                context=context,
                user_prompt_path=self.user_prompt_path,
                system_prompt_path=self.system_prompt_path,
                response_model=response_model,
            )

        if session_save:
            await save_conversation_history(
                query=query,
                context_summary=context_summary,
                answer=completion,
                session_id=session_id,
            )

        return [completion]
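
# Example usage (illustrative sketch, not part of the module): assumes a configured
# cognee environment with graph and vector engines populated by the temporal graph
# pipeline, plus LLM credentials.
#
#     import asyncio
#
#     async def main():
#         retriever = TemporalRetriever(top_k=5)
#         answers = await retriever.get_completion("What happened between 2020 and 2022?")
#         print(answers[0])
#
#     asyncio.run(main())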