feat: adds a really naive temporal retriever

This commit is contained in:
hajdul88 2025-08-02 17:09:12 +02:00
parent 6119ac08de
commit 13769ce6fb
5 changed files with 280 additions and 6 deletions

View file

@ -28,6 +28,7 @@ from cognee.modules.data.models import Dataset
from cognee.shared.utils import send_telemetry
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
from cognee.modules.search.operations import log_query, log_result
from cognee.temporal_poc.temporal_retriever import TemporalRetriever
async def search(
@ -127,6 +128,7 @@ async def specific_search(
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion,
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
SearchType.TEMPORAL: TemporalRetriever().get_completion,
}
search_task = search_tasks.get(query_type)

View file

@ -13,3 +13,4 @@ class SearchType(Enum):
NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
TEMPORAL = "TEMPORAL"

View file

@ -17,6 +17,11 @@ class Interval(BaseModel):
ends_at: Timestamp
class QueryInterval(BaseModel):
starts_at: Optional[Timestamp] = None
ends_at: Optional[Timestamp] = None
class Event(BaseModel):
name: str
description: Optional[str] = None

View file

@ -2,6 +2,8 @@ import asyncio
import cognee
from cognee.shared.logging_utils import setup_logging, INFO
from cognee.temporal_poc.temporal_cognify import temporal_cognify
from cognee.api.v1.search import SearchType
import json
from pathlib import Path
@ -25,14 +27,23 @@ async def reading_temporal_data():
async def main():
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
import random
texts = await reading_temporal_data()
texts = texts[:5]
if random.random() > 0.9999999:
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
await cognee.add(texts)
await temporal_cognify()
texts = await reading_temporal_data()
texts = texts[:5]
await cognee.add(texts)
await temporal_cognify()
search_results = await cognee.search(
query_type=SearchType.TEMPORAL, query_text="What happened in 2015"
)
print(search_results)
print()

View file

@ -0,0 +1,255 @@
from typing import Any, Optional, Type, List
from collections import Counter
import string
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.shared.logging_utils import get_logger
from cognee.temporal_poc.models.models import QueryInterval
from cognee.temporal_poc.temporal_cognify import date_to_int
logger = get_logger("TemporalRetriever")
class TemporalRetriever(BaseRetriever):
def __init__(
self,
user_prompt_path: str = "graph_context_for_question.txt",
system_prompt_path: str = "answer_simple_question.txt",
top_k: Optional[int] = 5,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
):
"""Initialize retriever with prompt paths and search parameters."""
self.user_prompt_path = user_prompt_path
self.system_prompt_path = system_prompt_path
self.top_k = top_k if top_k is not None else 5
self.node_type = node_type
self.node_name = node_name
def _get_nodes(self, retrieved_edges: list) -> dict:
"""Creates a dictionary of nodes with their names and content."""
nodes = {}
for edge in retrieved_edges:
for node in (edge.node1, edge.node2):
if node.id not in nodes:
text = node.attributes.get("text")
if text:
name = self._get_title(text)
content = text
else:
name = node.attributes.get("name", "Unnamed Node")
content = node.attributes.get("description", name)
nodes[node.id] = {"node": node, "name": name, "content": content}
return nodes
async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
nodes = self._get_nodes(retrieved_edges)
node_section = "\n".join(
f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
for info in nodes.values()
)
connection_section = "\n".join(
f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
for edge in retrieved_edges
)
return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
async def get_triplets(self, query: str) -> list:
"""
Retrieves relevant graph triplets based on a query string.
Parameters:
-----------
- query (str): The query string used to search for relevant triplets in the graph.
Returns:
--------
- list: A list of found triplets that match the query.
"""
subclasses = get_all_subclasses(DataPoint)
vector_index_collections = []
for subclass in subclasses:
if "metadata" in subclass.model_fields:
metadata_field = subclass.model_fields["metadata"]
if hasattr(metadata_field, "default") and metadata_field.default is not None:
if isinstance(metadata_field.default, dict):
index_fields = metadata_field.default.get("index_fields", [])
for field_name in index_fields:
vector_index_collections.append(f"{subclass.__name__}_{field_name}")
found_triplets = await brute_force_triplet_search(
query,
top_k=self.top_k,
collections=vector_index_collections or None,
node_type=self.node_type,
node_name=self.node_name,
)
return found_triplets
async def extract_time_from_query(self, query: str):
llm_client = get_llm_client()
system_prompt = """
For the purposes of identifying timestamps in a query, you are tasked with extracting relevant timestamps from the query.
## Timestamp requirements
- If the query contains interval extrack both starts_at and ends_at properties
- If the query contains an instantaneous timestamp, starts_at and ends_at should be the same
- If the query its open ended (before 2009 or after 2009), the corresponding non defined end of the time should be none
-For example: "before 2009" -- starts_at: None, ends_at: 2009 or "after 2009" -- starts_at: 2009, ends_at: None
- Put always the data that comes first in time as starts_at and the timestamps that comes second in time as ends_at
## Output Format
Your reply should be a JSON: list of dictionaries with the following structure:
```python
class QueryInterval(BaseModel):
starts_at: Optional[Timestamp] = None
ends_at: Optional[Timestamp] = None
```
"""
interval = await llm_client.acreate_structured_output(query, system_prompt, QueryInterval)
return interval
def descriptions_to_string(self, results):
descs = []
for entry in results:
events = entry.get("events", [])
for ev in events:
d = ev.get("description")
if d:
descs.append(d.strip())
return "\n-".join(descs)
async def get_context(self, query: str) -> str:
# :TODO: This is a POC and yes this method is far far far far from nice :D
graph_engine = await get_graph_engine()
interval = await self.extract_time_from_query(query=query)
time_from = interval.starts_at
time_to = interval.ends_at
event_collection_cypher = """UNWIND [{quoted}] AS uid
MATCH (start {{id: uid}})
MATCH (start)-[*1..2]-(event)
WHERE event.type = 'Event'
WITH DISTINCT event
RETURN collect(event) AS events;
"""
if time_from and time_to:
time_from = date_to_int(time_from)
time_to = date_to_int(time_to)
cypher = """
MATCH (n)
WHERE n.type = 'Timestamp'
AND n.time_at >= $time_from
AND n.time_at <= $time_to
RETURN n.id AS id
"""
params = {"time_from": time_from, "time_to": time_to}
time_nodes = await graph_engine.query(cypher, params)
time_ids_list = [item["id"] for item in time_nodes if "id" in item]
ids = ", ".join("'{0}'".format(uid) for uid in time_ids_list)
event_collection_cypher = event_collection_cypher.format(quoted=ids)
relevant_events = await graph_engine.query(event_collection_cypher)
context = self.descriptions_to_string(relevant_events)
return context
elif time_from:
time_from = date_to_int(time_from)
cypher = """
MATCH (n)
WHERE n.type = 'Timestamp'
AND n.time_at >= $time_from
RETURN n.id AS id
"""
params = {"time_from": time_from}
time_nodes = await graph_engine.query(cypher, params)
time_ids_list = [item["id"] for item in time_nodes if "id" in item]
ids = ", ".join("'{0}'".format(uid) for uid in time_ids_list)
event_collection_cypher = event_collection_cypher.format(quoted=ids)
relevant_events = await graph_engine.query(event_collection_cypher)
context = self.descriptions_to_string(relevant_events)
return context
elif time_to:
time_to = date_to_int(time_to)
cypher = """
MATCH (n)
WHERE n.type = 'Timestamp'
AND n.time_at <= $time_to
RETURN n.id AS id
"""
params = {"time_to": time_to}
time_nodes = await graph_engine.query(cypher, params)
time_ids_list = [item["id"] for item in time_nodes if "id" in item]
ids = ", ".join("'{0}'".format(uid) for uid in time_ids_list)
event_collection_cypher = event_collection_cypher.format(quoted=ids)
relevant_events = await graph_engine.query(event_collection_cypher)
context = self.descriptions_to_string(relevant_events)
return context
else:
logger.info(
"We couldn't find any timestamps in this query therefore we return empty context"
)
return ""
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
if context is None:
context = await self.get_context(query)
completion = await generate_completion(
query=query,
context=context,
user_prompt_path=self.user_prompt_path,
system_prompt_path=self.system_prompt_path,
)
return [completion]
def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
if stop_words is None:
stop_words = DEFAULT_STOP_WORDS
words = [word.lower().strip(string.punctuation) for word in text.split()]
if stop_words:
words = [word for word in words if word and word not in stop_words]
top_words = [word for word, freq in Counter(words).most_common(top_n)]
return separator.join(top_words)
def _get_title(self, text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
first_n_words = text.split()[:first_n_words]
top_n_words = self._top_n_words(text, top_n=top_n_words)
return f"{' '.join(first_n_words)}... [{top_n_words}]"