Compare commits
7 commits
main ... feature/co

| Author | SHA1 | Date |
|---|---|---|
| | 2149bb5601 | |
| | 74acf029a9 | |
| | 85e32de418 | |
| | 13769ce6fb | |
| | 6119ac08de | |
| | e90cbc43dd | |
| | f2e66bc276 | |

11 changed files with 748 additions and 2 deletions
```diff
@@ -3,6 +3,8 @@ from typing import List
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.data.processing.document_types import Document
 from cognee.modules.engine.models import Entity
+from typing import Union
+from cognee.temporal_poc.models.models import Event


 class DocumentChunk(DataPoint):
@@ -30,6 +32,6 @@ class DocumentChunk(DataPoint):
     chunk_index: int
     cut_type: str
     is_part_of: Document
-    contains: List[Entity] = None
+    contains: List[Union[Entity, Event]] = None

     metadata: dict = {"index_fields": ["text"]}
```
```diff
@@ -27,6 +27,7 @@ from cognee.modules.users.models import User
 from cognee.modules.data.models import Dataset
 from cognee.shared.utils import send_telemetry
 from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
+from cognee.temporal_poc.temporal_retriever import TemporalRetriever
 from cognee.modules.search.operations import log_query, log_result, select_search_type


@@ -127,6 +128,7 @@ async def specific_search(
         SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
         SearchType.CYPHER: CypherSearchRetriever().get_completion,
         SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
+        SearchType.TEMPORAL: TemporalRetriever().get_completion,
     }

     # If the query type is FEELING_LUCKY, select the search type intelligently
```
```diff
@@ -13,4 +13,5 @@ class SearchType(Enum):
     NATURAL_LANGUAGE = "NATURAL_LANGUAGE"
     GRAPH_COMPLETION_COT = "GRAPH_COMPLETION_COT"
     GRAPH_COMPLETION_CONTEXT_EXTENSION = "GRAPH_COMPLETION_CONTEXT_EXTENSION"
+    TEMPORAL = "TEMPORAL"
     FEELING_LUCKY = "FEELING_LUCKY"
```
```diff
@@ -1,11 +1,11 @@
 import asyncio

 from cognee.shared.logging_utils import get_logger

 from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.engine import DataPoint

 # A single lock shared by all coroutines
 vector_index_lock = asyncio.Lock()
 logger = get_logger("index_data_points")
```
cognee/temporal_poc/datapoints/datapoints.py (new file, 33 lines)

```python
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models.EntityType import EntityType
from typing import Optional, List, Any
from pydantic import BaseModel, Field, ConfigDict, SkipValidation
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.engine.models.Entity import Entity


class Timestamp(DataPoint):
    time_at: int = Field(...)
    year: int = Field(...)
    month: int = Field(...)
    day: int = Field(...)
    hour: int = Field(...)
    minute: int = Field(...)
    second: int = Field(...)
    timestamp_str: str = Field(...)


class Interval(DataPoint):
    time_from: Timestamp = Field(...)
    time_to: Timestamp = Field(...)


class Event(DataPoint):
    name: str
    description: Optional[str] = None
    at: Optional[Timestamp] = None
    during: Optional[Interval] = None
    location: Optional[str] = None
    attributes: SkipValidation[Any] = None  # (Edge, list[Entity])

    metadata: dict = {"index_fields": ["name"]}
```
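For orientation, a minimal sketch of how these datapoints nest. The field values are illustrative rather than taken from the PR, and IDs are left to the DataPoint defaults; in the pipeline, Timestamp instances are built by create_timestamp_datapoint in the next file.

```python
from cognee.temporal_poc.datapoints.datapoints import Timestamp, Interval, Event

# Illustrative values; in the pipeline these come from create_timestamp_datapoint.
start = Timestamp(
    time_at=1577836800000,  # 2020-01-01 00:00:00 UTC in epoch milliseconds
    year=2020, month=1, day=1, hour=0, minute=0, second=0,
    timestamp_str="2020-01-01 00:00:00",
)
end = Timestamp(
    time_at=1577923200000,  # one day later
    year=2020, month=1, day=2, hour=0, minute=0, second=0,
    timestamp_str="2020-01-02 00:00:00",
)

# An instantaneous event carries `at`; an event with duration carries `during`.
signing = Event(name="contract signed", at=start)
summit = Event(name="two-day summit", during=Interval(time_from=start, time_to=end))
```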
cognee/temporal_poc/event_extraction.py (new file, 130 lines)

````python
import asyncio

from pydantic import BaseModel
from typing import Type, List
from datetime import datetime, timezone

from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.engine.utils import generate_node_id
from cognee.temporal_poc.models.models import EventList
from cognee.temporal_poc.datapoints.datapoints import Interval, Timestamp, Event


# Global system prompt for event extraction
EVENT_EXTRACTION_SYSTEM_PROMPT = """
For the purposes of building event-based knowledge graphs, you are tasked with extracting highly granular stream events from a text. The events are defined as follows:
## Event Definition
- Anything with a date or a timestamp is an event
- Anything that took place in time (even if the time is unknown) is an event
- Anything that lasted over a period of time, or happened in an instant, is an event: from historical milestones (wars, presidencies, olympiads) to personal milestones (birth, death, employment, etc.), to mundane actions (a walk, a conversation, etc.)
- **ANY action or verb represents an event** - this is the most important rule
- Every single verb in the text corresponds to an event that must be extracted
- This includes: thinking, feeling, seeing, hearing, moving, speaking, writing, reading, eating, sleeping, working, playing, studying, traveling, meeting, calling, texting, buying, selling, creating, destroying, building, breaking, starting, stopping, beginning, ending, etc.
- Even the most mundane or obvious actions are events: "he walked", "she sat", "they talked", "I thought", "we waited"
## Requirements
- **Be extremely thorough** - extract EVERY event mentioned, no matter how small or obvious
- **Timestamped first** - every timestamp or date should have at least one event
- **Verbs/actions = one event** - after you are done with timestamped events, every verb that is an action should have a corresponding event
- We expect long streams of events from any piece of text, easily reaching a hundred events
- Granularity and richness of the stream is key to our success and is of utmost importance
- Not all events will have timestamps; add timestamps only to known events
- For events that were instantaneous, attach only the time_from or time_to property, not both
- **Do not skip any events** - if you're unsure whether something is an event, extract it anyway
- **Quantity over filtering** - it's better to extract too many events than to miss any
- **Descriptions** - always include the event description together with entities (Who did what? What happened? What is the event?). If you can, include the corresponding part of the text.
## Output Format
Your reply should be a JSON list of dictionaries with the following structure:
```python
class Event(BaseModel):
    name: str  # concise
    description: Optional[str] = None
    time_from: Optional[Timestamp] = None
    time_to: Optional[Timestamp] = None
    location: Optional[str] = None
```
"""


def date_to_int(ts: Timestamp) -> int:
    """Convert timestamp to integer milliseconds."""
    dt = datetime(ts.year, ts.month, ts.day, ts.hour, ts.minute, ts.second, tzinfo=timezone.utc)
    time = int(dt.timestamp() * 1000)
    return time


def create_timestamp_datapoint(ts: Timestamp) -> Timestamp:
    """Create a Timestamp datapoint from a Timestamp model."""
    time_at = date_to_int(ts)
    timestamp_str = (
        f"{ts.year:04d}-{ts.month:02d}-{ts.day:02d} {ts.hour:02d}:{ts.minute:02d}:{ts.second:02d}"
    )
    return Timestamp(
        id=generate_node_id(str(time_at)),
        time_at=time_at,
        year=ts.year,
        month=ts.month,
        day=ts.day,
        hour=ts.hour,
        minute=ts.minute,
        second=ts.second,
        timestamp_str=timestamp_str,
    )


def create_event_datapoint(event) -> Event:
    """Create an Event datapoint from an event model."""
    # Base event data
    event_data = {
        "name": event.name,
        "description": event.description,
        "location": event.location,
    }

    # Create timestamps if they exist
    time_from = create_timestamp_datapoint(event.time_from) if event.time_from else None
    time_to = create_timestamp_datapoint(event.time_to) if event.time_to else None

    # Add temporal information
    if time_from and time_to:
        event_data["during"] = Interval(time_from=time_from, time_to=time_to)
        # Enrich description with temporal info
        temporal_info = f"\n---\nTime data: {time_from.timestamp_str} to {time_to.timestamp_str}"
        event_data["description"] = (event_data["description"] or "Event") + temporal_info
    elif time_from or time_to:
        timestamp = time_from or time_to
        event_data["at"] = timestamp
        # Enrich description with temporal info
        temporal_info = f"\n---\nTime data: {timestamp.timestamp_str}"
        event_data["description"] = (event_data["description"] or "Event") + temporal_info

    return Event(**event_data)


async def extract_event_graph(
    content: str, response_model: Type[BaseModel], system_prompt: str = None
):
    """Extract event graph from content using LLM."""

    if system_prompt is None:
        system_prompt = EVENT_EXTRACTION_SYSTEM_PROMPT

    content_graph = await LLMGateway.acreate_structured_output(
        content, system_prompt, response_model
    )

    return content_graph


async def extract_events_and_entities(data_chunks: List[DocumentChunk]) -> List[DocumentChunk]:
    """Extracts events and entities from a chunk of documents."""
    events = await asyncio.gather(
        *[extract_event_graph(chunk.text, EventList) for chunk in data_chunks]
    )

    for data_chunk, event_list in zip(data_chunks, events):
        for event in event_list.events:
            event_datapoint = create_event_datapoint(event)
            data_chunk.contains.append(event_datapoint)

    return data_chunks
````
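As a quick sanity check of the millisecond conversion, a small standalone sketch. date_to_int only reads the calendar fields, so it also accepts the pydantic Timestamp from models.py, which is what create_event_datapoint passes it; values are interpreted as UTC, as in the code above.

```python
from cognee.temporal_poc.models.models import Timestamp as TimestampModel
from cognee.temporal_poc.event_extraction import date_to_int, create_timestamp_datapoint

ts = TimestampModel(year=2020, month=1, day=1, hour=0, minute=0, second=0)

print(date_to_int(ts))  # 1577836800000 == 2020-01-01 00:00:00 UTC in epoch milliseconds

dp = create_timestamp_datapoint(ts)
print(dp.timestamp_str)  # "2020-01-01 00:00:00"
```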
cognee/temporal_poc/event_knowledge_graph.py (new file, 153 lines)

```python
from typing import List, Type
from pydantic import BaseModel

from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.engine.models.Entity import Entity
from cognee.modules.engine.models.EntityType import EntityType
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.engine.utils import generate_node_id, generate_node_name
from cognee.temporal_poc.models.models import EventEntityList
from cognee.temporal_poc.datapoints.datapoints import Event
from cognee.temporal_poc.models.models import EventWithEntities

ENTITY_EXTRACTION_SYSTEM_PROMPT = """For the purposes of building event-based knowledge graphs, you are tasked with extracting highly granular entities from events text. An entity is any distinct, identifiable thing, person, place, object, organization, concept, or phenomenon that can be named, referenced, or described in the event context. This includes but is not limited to: people, places, objects, organizations, concepts, events, processes, states, conditions, properties, attributes, roles, functions, and any other meaningful referents that contribute to understanding the event.
**Temporal Entity Exclusion**: Do not extract timestamp-like entities (dates, times, durations) as these are handled separately. However, extract named temporal periods, eras, historical epochs, and culturally significant time references.
## Input Format
The input will be a list of dictionaries, each containing:
- `event_name`: The name of the event
- `description`: The description of the event

## Task
For each event, extract all entities mentioned in the event description and determine their relationship to the event.

## Output Format
Return the same enriched JSON with an additional key in each dictionary: `attributes`.

The `attributes` should be a list of dictionaries, each containing:
- `entity`: The name of the entity
- `entity_type`: The type/category of the entity (person, place, organization, object, concept, etc.)
- `relationship`: A concise description of how the entity relates to the event

## Requirements
- **Be extremely thorough** - extract EVERY non-temporal entity mentioned, no matter how small, obvious, or seemingly insignificant
- **After you are done with obvious entities, every noun, pronoun, proper noun, and named reference = one entity**
- We expect rich entity networks from any event, easily reaching dozens of entities per event
- Granularity and richness of the entity extraction is key to our success and is of utmost importance
- **Do not skip any entities** - if you're unsure whether something is an entity, extract it anyway
- Use the event name for context when determining relationships
- Relationships should be technical, with one or at most two words; if two words, join them with an underscore
- Relationships can express general roles such as: subject, object, participant, recipient, agent, instrument, tool, source, cause, effect, purpose, manner, resource, etc.
- You can combine two words to form a relationship name: subject_role, previous_owner, etc.
- Focus on how the entity specifically relates to the event
"""


async def extract_event_entities(
    content: str, response_model: Type[BaseModel], system_prompt: str = None
):
    """Extract event entities from content using LLM."""

    if system_prompt is None:
        system_prompt = ENTITY_EXTRACTION_SYSTEM_PROMPT

    content_graph = await LLMGateway.acreate_structured_output(
        content, system_prompt, response_model
    )

    return content_graph


async def enrich_events(events: List[Event]) -> List[EventWithEntities]:
    """Extract entities from events and return enriched events."""
    import json

    # Convert events to JSON format for LLM processing
    events_json = [
        {"event_name": event.name, "description": event.description or ""} for event in events
    ]

    events_json_str = json.dumps(events_json)

    # Extract entities from events
    entity_result = await extract_event_entities(events_json_str, EventEntityList)

    return entity_result.events


def add_entities_to_event(event: Event, event_with_entities: EventWithEntities) -> None:
    """Add entities to event via attributes field."""
    if not event_with_entities.attributes:
        return

    # Create entity types cache
    entity_types = {}

    # Process each attribute
    for attribute in event_with_entities.attributes:
        # Get or create entity type
        entity_type = get_or_create_entity_type(entity_types, attribute.entity_type)

        # Create entity
        entity_id = generate_node_id(attribute.entity)
        entity_name = generate_node_name(attribute.entity)
        entity = Entity(
            id=entity_id,
            name=entity_name,
            is_a=entity_type,
            description=f"Entity {attribute.entity} of type {attribute.entity_type}",
            ontology_valid=False,
            belongs_to_set=None,
        )

        # Create edge
        edge = Edge(relationship_type=attribute.relationship)

        # Add to event attributes
        if event.attributes is None:
            event.attributes = []
        event.attributes.append((edge, [entity]))


def get_or_create_entity_type(entity_types: dict, entity_type_name: str) -> EntityType:
    """Get existing entity type or create new one."""
    if entity_type_name not in entity_types:
        type_id = generate_node_id(entity_type_name)
        type_name = generate_node_name(entity_type_name)
        entity_type = EntityType(
            id=type_id,
            name=type_name,
            type=type_name,
            description=f"Type for {entity_type_name}",
            ontology_valid=False,
        )
        entity_types[entity_type_name] = entity_type

    return entity_types[entity_type_name]


async def extract_event_knowledge_graph(data_chunks: List[DocumentChunk]) -> List[DocumentChunk]:
    """Extract events from chunks and enrich them with entities."""
    # Extract events from chunks
    all_events = []
    for chunk in data_chunks:
        for item in chunk.contains:
            if isinstance(item, Event):
                all_events.append(item)

    if not all_events:
        return data_chunks

    # Enrich events with entities
    enriched_events = await enrich_events(all_events)

    # Add entities to events
    for event, enriched_event in zip(all_events, enriched_events):
        add_entities_to_event(event, enriched_event)

    return data_chunks


async def process_event_knowledge_graph(data_chunks: List[DocumentChunk]) -> List[DocumentChunk]:
    """Process document chunks for event knowledge graph construction."""
    return await extract_event_knowledge_graph(data_chunks)
```
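To make the enriched structure concrete, a hypothetical sketch of what add_entities_to_event leaves on an event: attributes becomes a list of (Edge, [Entity]) tuples. The entity "Alice" and relationship "signatory" are invented for illustration; the constructors mirror the ones used in the file above.

```python
from cognee.infrastructure.engine.models.Edge import Edge
from cognee.modules.engine.models.Entity import Entity
from cognee.modules.engine.models.EntityType import EntityType
from cognee.modules.engine.utils import generate_node_id, generate_node_name

# One extracted attribute for an event such as "contract signed":
person_type = EntityType(
    id=generate_node_id("person"),
    name=generate_node_name("person"),
    type=generate_node_name("person"),
    description="Type for person",
    ontology_valid=False,
)
alice = Entity(
    id=generate_node_id("Alice"),
    name=generate_node_name("Alice"),
    is_a=person_type,
    description="Entity Alice of type person",
    ontology_valid=False,
    belongs_to_set=None,
)

# This is the shape event.attributes holds after enrichment.
attributes = [(Edge(relationship_type="signatory"), [alice])]
```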
cognee/temporal_poc/models/models.py (new file, 50 lines)

```python
from typing import Optional, Literal, List
from pydantic import BaseModel, Field, root_validator, ValidationError
from cognee.modules.engine.models.Entity import Entity


class Timestamp(BaseModel):
    year: int = Field(..., ge=1, le=9999)
    month: int = Field(..., ge=1, le=12)
    day: int = Field(..., ge=1, le=31)
    hour: int = Field(..., ge=0, le=23)
    minute: int = Field(..., ge=0, le=59)
    second: int = Field(..., ge=0, le=59)


class Interval(BaseModel):
    starts_at: Timestamp
    ends_at: Timestamp


class QueryInterval(BaseModel):
    starts_at: Optional[Timestamp] = None
    ends_at: Optional[Timestamp] = None


class Event(BaseModel):
    name: str
    description: Optional[str] = None
    time_from: Optional[Timestamp] = None
    time_to: Optional[Timestamp] = None
    location: Optional[str] = None


class EventList(BaseModel):
    events: List[Event]


class EntityAttribute(BaseModel):
    entity: str
    entity_type: str
    relationship: str


class EventWithEntities(BaseModel):
    event_name: str
    description: Optional[str] = None
    attributes: List[EntityAttribute] = []


class EventEntityList(BaseModel):
    events: List[EventWithEntities]
```
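A small sketch of validating the JSON shapes the two extraction prompts ask for against these models, assuming pydantic v2 (model_validate); the payload contents are illustrative, not from the PR.

```python
from cognee.temporal_poc.models.models import EventList, EventEntityList

# Shape requested by EVENT_EXTRACTION_SYSTEM_PROMPT:
raw_events = {
    "events": [
        {
            "name": "moon landing",
            "description": "Apollo 11 lands on the Moon",
            "time_from": {"year": 1969, "month": 7, "day": 20, "hour": 20, "minute": 17, "second": 0},
            "time_to": None,
            "location": "Sea of Tranquility",
        }
    ]
}
event_list = EventList.model_validate(raw_events)

# Shape requested by ENTITY_EXTRACTION_SYSTEM_PROMPT (enriched with attributes):
raw_enriched = {
    "events": [
        {
            "event_name": "moon landing",
            "description": "Apollo 11 lands on the Moon",
            "attributes": [
                {"entity": "Apollo 11", "entity_type": "spacecraft", "relationship": "instrument"}
            ],
        }
    ]
}
enriched = EventEntityList.model_validate(raw_enriched)
```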
cognee/temporal_poc/temporal_cognify.py (new file, 64 lines)

```python
from typing import Union, Optional, List
from uuid import UUID
from cognee.shared.logging_utils import get_logger
from cognee.shared.data_models import KnowledgeGraph
from cognee.infrastructure.llm import get_max_chunk_tokens, get_llm_config

from cognee.api.v1.cognify.cognify import run_cognify_blocking
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver

from cognee.modules.users.models import User
from cognee.tasks.documents import (
    check_permissions_on_dataset,
    classify_documents,
    extract_chunks_from_documents,
)
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_text
from cognee.temporal_poc.event_extraction import extract_events_and_entities
from cognee.temporal_poc.event_knowledge_graph import process_event_knowledge_graph

logger = get_logger("temporal_cognify")


async def get_temporal_tasks(
    user: User = None, chunker=TextChunker, chunk_size: int = None
) -> list[Task]:
    temporal_tasks = [
        Task(classify_documents),
        Task(check_permissions_on_dataset, user=user, permissions=["write"]),
        Task(
            extract_chunks_from_documents,
            max_chunk_size=chunk_size or get_max_chunk_tokens(),
            chunker=chunker,
        ),
        Task(extract_events_and_entities, task_config={"chunk_size": 10}),
        Task(process_event_knowledge_graph),
        Task(add_data_points, task_config={"batch_size": 10}),
    ]

    return temporal_tasks


async def temporal_cognify(
    datasets: Union[str, list[str], list[UUID]] = None,
    user: User = None,
    chunker=TextChunker,
    chunk_size: int = None,
    vector_db_config: dict = None,
    graph_db_config: dict = None,
    incremental_loading: bool = True,
):
    tasks = await get_temporal_tasks(user, chunker, chunk_size)

    return await run_cognify_blocking(
        tasks=tasks,
        user=user,
        datasets=datasets,
        vector_db_config=vector_db_config,
        graph_db_config=graph_db_config,
        incremental_loading=incremental_loading,
    )
```
cognee/temporal_poc/temporal_example.py (new file, 58 lines)

```python
import asyncio
import cognee
from cognee.shared.logging_utils import setup_logging, INFO
from cognee.temporal_poc.temporal_cognify import temporal_cognify
from cognee.api.v1.search import SearchType

import json
from pathlib import Path


async def reading_temporal_data():
    path = Path("cognee/temporal_poc/test_hard.json")
    contexts = []
    seen = set()
    with path.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            entry = json.loads(line)
            ctx = entry.get("context", "")
            if ctx and ctx not in seen:
                seen.add(ctx)
                contexts.append(ctx)
    return contexts


async def main():
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    texts = await reading_temporal_data()
    texts = texts[:5]

    # texts = ["Buzz Aldrin (born January 20, 1930) is an American former astronaut."]

    await cognee.add(texts)
    await temporal_cognify()

    search_results = await cognee.search(
        query_type=SearchType.TEMPORAL, query_text="What happened in the 1930s?"
    )

    print(search_results)

    print()


if __name__ == "__main__":
    logger = setup_logging(log_level=INFO)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
```
cognee/temporal_poc/temporal_retriever.py (new file, 253 lines)

````python
from typing import Any, Optional, Type, List
from collections import Counter
import string

from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
from cognee.modules.retrieval.base_retriever import BaseRetriever
from cognee.modules.retrieval.utils.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.utils.completion import generate_completion
from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
from cognee.shared.logging_utils import get_logger
from cognee.temporal_poc.models.models import QueryInterval
from cognee.temporal_poc.event_extraction import date_to_int

logger = get_logger("TemporalRetriever")


class TemporalRetriever(BaseRetriever):
    def __init__(
        self,
        user_prompt_path: str = "graph_context_for_question.txt",
        system_prompt_path: str = "answer_simple_question.txt",
        top_k: Optional[int] = 5,
        node_type: Optional[Type] = None,
        node_name: Optional[List[str]] = None,
    ):
        """Initialize retriever with prompt paths and search parameters."""
        self.user_prompt_path = user_prompt_path
        self.system_prompt_path = system_prompt_path
        self.top_k = top_k if top_k is not None else 5
        self.node_type = node_type
        self.node_name = node_name

    def _get_nodes(self, retrieved_edges: list) -> dict:
        """Creates a dictionary of nodes with their names and content."""
        nodes = {}
        for edge in retrieved_edges:
            for node in (edge.node1, edge.node2):
                if node.id not in nodes:
                    text = node.attributes.get("text")
                    if text:
                        name = self._get_title(text)
                        content = text
                    else:
                        name = node.attributes.get("name", "Unnamed Node")
                        content = node.attributes.get("description", name)
                    nodes[node.id] = {"node": node, "name": name, "content": content}
        return nodes

    async def resolve_edges_to_text(self, retrieved_edges: list) -> str:
        nodes = self._get_nodes(retrieved_edges)
        node_section = "\n".join(
            f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
            for info in nodes.values()
        )
        connection_section = "\n".join(
            f"{nodes[edge.node1.id]['name']} --[{edge.attributes['relationship_type']}]--> {nodes[edge.node2.id]['name']}"
            for edge in retrieved_edges
        )
        return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"

    async def get_triplets(self, query: str) -> list:
        """
        Retrieves relevant graph triplets based on a query string.

        Parameters:
        -----------
        - query (str): The query string used to search for relevant triplets in the graph.

        Returns:
        --------
        - list: A list of found triplets that match the query.
        """
        subclasses = get_all_subclasses(DataPoint)
        vector_index_collections = []

        for subclass in subclasses:
            if "metadata" in subclass.model_fields:
                metadata_field = subclass.model_fields["metadata"]
                if hasattr(metadata_field, "default") and metadata_field.default is not None:
                    if isinstance(metadata_field.default, dict):
                        index_fields = metadata_field.default.get("index_fields", [])
                        for field_name in index_fields:
                            vector_index_collections.append(f"{subclass.__name__}_{field_name}")

        found_triplets = await brute_force_triplet_search(
            query,
            top_k=self.top_k,
            collections=vector_index_collections or None,
            node_type=self.node_type,
            node_name=self.node_name,
        )

        return found_triplets

    async def extract_time_from_query(self, query: str):
        system_prompt = """
For the purposes of identifying timestamps in a query, you are tasked with extracting relevant timestamps from the query.
## Timestamp requirements
- If the query contains an interval, extract both starts_at and ends_at properties
- If the query contains an instantaneous timestamp, starts_at and ends_at should be the same
- If the query is open-ended (before 2009 or after 2009), the corresponding undefined end of the interval should be None
- For example: "before 2009" -- starts_at: None, ends_at: 2009; "after 2009" -- starts_at: 2009, ends_at: None
- Always put the timestamp that comes first in time as starts_at and the timestamp that comes second as ends_at
## Output Format
Your reply should be a JSON with the following structure:
```python
class QueryInterval(BaseModel):
    starts_at: Optional[Timestamp] = None
    ends_at: Optional[Timestamp] = None
```
"""

        interval = await LLMGateway.acreate_structured_output(query, system_prompt, QueryInterval)

        return interval

    def descriptions_to_string(self, results):
        descs = []
        for entry in results:
            events = entry.get("events", [])
            for ev in events:
                d = ev.get("description")
                if d:
                    descs.append(d.strip())
        return "\n-".join(descs)

    async def get_context(self, query: str) -> str:
        # TODO: This is a POC and yes, this method is far, far from nice :D

        graph_engine = await get_graph_engine()
        interval = await self.extract_time_from_query(query=query)

        time_from = interval.starts_at
        time_to = interval.ends_at

        event_collection_cypher = """UNWIND [{quoted}] AS uid
        MATCH (start {{id: uid}})
        MATCH (start)-[*1..2]-(event)
        WHERE event.type = 'Event'
        WITH DISTINCT event
        RETURN collect(event) AS events;
        """

        if time_from and time_to:
            time_from = date_to_int(time_from)
            time_to = date_to_int(time_to)

            cypher = """
            MATCH (n)
            WHERE n.type = 'Timestamp'
              AND n.time_at >= $time_from
              AND n.time_at <= $time_to
            RETURN n.id AS id
            """
            params = {"time_from": time_from, "time_to": time_to}
            time_nodes = await graph_engine.query(cypher, params)

            time_ids_list = [item["id"] for item in time_nodes if "id" in item]

            ids = ", ".join("'{0}'".format(uid) for uid in time_ids_list)

            event_collection_cypher = event_collection_cypher.format(quoted=ids)
            relevant_events = await graph_engine.query(event_collection_cypher)

            context = self.descriptions_to_string(relevant_events)

            return context
        elif time_from:
            time_from = date_to_int(time_from)

            cypher = """
            MATCH (n)
            WHERE n.type = 'Timestamp'
              AND n.time_at >= $time_from
            RETURN n.id AS id
            """
            params = {"time_from": time_from}
            time_nodes = await graph_engine.query(cypher, params)

            time_ids_list = [item["id"] for item in time_nodes if "id" in item]

            ids = ", ".join("'{0}'".format(uid) for uid in time_ids_list)

            event_collection_cypher = event_collection_cypher.format(quoted=ids)
            relevant_events = await graph_engine.query(event_collection_cypher)

            context = self.descriptions_to_string(relevant_events)

            return context

        elif time_to:
            time_to = date_to_int(time_to)

            cypher = """
            MATCH (n)
            WHERE n.type = 'Timestamp'
              AND n.time_at <= $time_to
            RETURN n.id AS id
            """
            params = {"time_to": time_to}

            time_nodes = await graph_engine.query(cypher, params)

            time_ids_list = [item["id"] for item in time_nodes if "id" in item]

            ids = ", ".join("'{0}'".format(uid) for uid in time_ids_list)

            event_collection_cypher = event_collection_cypher.format(quoted=ids)
            relevant_events = await graph_engine.query(event_collection_cypher)

            context = self.descriptions_to_string(relevant_events)

            return context
        else:
            logger.info(
                "We couldn't find any timestamps in this query, therefore we return empty context"
            )
            return ""

    async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
        if context is None:
            context = await self.get_context(query)

        completion = await generate_completion(
            query=query,
            context=context,
            user_prompt_path=self.user_prompt_path,
            system_prompt_path=self.system_prompt_path,
        )
        return [completion]

    def _top_n_words(self, text, stop_words=None, top_n=3, separator=", "):
        if stop_words is None:
            stop_words = DEFAULT_STOP_WORDS

        words = [word.lower().strip(string.punctuation) for word in text.split()]

        if stop_words:
            words = [word for word in words if word and word not in stop_words]

        top_words = [word for word, freq in Counter(words).most_common(top_n)]

        return separator.join(top_words)

    def _get_title(self, text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
        first_n_words = text.split()[:first_n_words]
        top_n_words = self._top_n_words(text, top_n=top_n_words)
        return f"{' '.join(first_n_words)}... [{top_n_words}]"
````
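For completeness, a hedged sketch of exercising the retriever directly rather than through cognee.search. It assumes a configured LLM plus graph and vector backends and an already-built temporal graph; it only traces the control flow defined above, and the question text is illustrative.

```python
import asyncio

from cognee.temporal_poc.temporal_retriever import TemporalRetriever


async def demo():
    retriever = TemporalRetriever(top_k=5)
    question = "What happened between 1930 and 1940?"

    # Step 1: pull the time window out of the question via the LLM.
    interval = await retriever.extract_time_from_query(question)
    print(interval.starts_at, interval.ends_at)

    # Step 2: get_context turns that window into two Cypher queries
    # (Timestamp nodes in range, then Events within 1-2 hops) and joins
    # the matching event descriptions into a single context string.
    context = await retriever.get_context(question)

    # Step 3: generate the final answer from query + context.
    answer = await retriever.get_completion(question, context)
    print(answer)


asyncio.run(demo())
```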