diff --git a/cognee/modules/chunking/models/DocumentChunk.py b/cognee/modules/chunking/models/DocumentChunk.py
index 475703265..d2c067dbe 100644
--- a/cognee/modules/chunking/models/DocumentChunk.py
+++ b/cognee/modules/chunking/models/DocumentChunk.py
@@ -3,6 +3,8 @@ from typing import List
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.data.processing.document_types import Document
 from cognee.modules.engine.models import Entity
+from typing import Union
+from cognee.temporal_poc.datapoints.datapoints import Event
 
 
 class DocumentChunk(DataPoint):
@@ -30,6 +32,6 @@ class DocumentChunk(DataPoint):
     chunk_index: int
     cut_type: str
     is_part_of: Document
-    contains: List[Entity] = None
+    contains: List[Union[Entity, Event]] = None
 
     metadata: dict = {"index_fields": ["text"]}
diff --git a/cognee/temporal_poc/datapoints/datapoints.py b/cognee/temporal_poc/datapoints/datapoints.py
new file mode 100644
index 000000000..aea493cd0
--- /dev/null
+++ b/cognee/temporal_poc/datapoints/datapoints.py
@@ -0,0 +1,27 @@
+from typing import Optional
+
+from pydantic import Field
+
+from cognee.infrastructure.engine import DataPoint
+
+
+class Interval(DataPoint):
+    # Closed time span; integers are assumed to be unix-epoch seconds.
+    time_from: int = Field(..., ge=0)
+    time_to: int = Field(..., ge=0)
+
+
+class Timestamp(DataPoint):
+    # A single point in time, same integer encoding as Interval.
+    time_at: int = Field(..., ge=0)
+
+
+class Event(DataPoint):
+    # An extracted event, anchored to a point (`at`) or a span (`during`).
+    name: str
+    description: Optional[str] = None
+    at: Optional[Timestamp] = None
+    during: Optional[Interval] = None
+    location: Optional[str] = None
+
+    metadata: dict = {"index_fields": ["name"]}
diff --git a/cognee/temporal_poc/models/models.py b/cognee/temporal_poc/models/models.py
new file mode 100644
index 000000000..74d0c0903
--- /dev/null
+++ b/cognee/temporal_poc/models/models.py
@@ -0,0 +1,30 @@
+from typing import List, Optional
+
+from pydantic import BaseModel, Field
+
+
+class Timestamp(BaseModel):
+    # Calendar-style point as returned by the LLM extractor.
+    year: int = Field(..., ge=1, le=9999)
+    month: int = Field(..., ge=1, le=12)
+    day: int = Field(..., ge=1, le=31)
+    hour: int = Field(..., ge=0, le=23)
+    minute: int = Field(..., ge=0, le=59)
+    second: int = Field(..., ge=0, le=59)
+
+
+class Interval(BaseModel):
+    starts_at: Timestamp
+    ends_at: Timestamp
+
+
+class Event(BaseModel):
+    name: str
+    description: Optional[str] = None
+    time_from: Optional[Timestamp] = None
+    time_to: Optional[Timestamp] = None
+    location: Optional[str] = None
+
+
+class EventList(BaseModel):
+    events: List[Event]
diff --git a/cognee/temporal_poc/temporal_cognify.py b/cognee/temporal_poc/temporal_cognify.py
new file mode 100644
index 000000000..534bd6aa2
--- /dev/null
+++ b/cognee/temporal_poc/temporal_cognify.py
@@ -0,0 +1,135 @@
+import asyncio
+from datetime import datetime, timezone
+from pydantic import BaseModel
+from typing import Union, List, Type
+from uuid import UUID
+
+from cognee.infrastructure.llm.get_llm_client import get_llm_client
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.llm import get_max_chunk_tokens
+
+from cognee.api.v1.cognify.cognify import run_cognify_blocking
+from cognee.modules.pipelines.tasks.task import Task
+from cognee.modules.chunking.TextChunker import TextChunker
+from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
+
+from cognee.modules.users.models import User
+from cognee.tasks.documents import (
+    check_permissions_on_dataset,
+    classify_documents,
+    extract_chunks_from_documents,
+)
+from cognee.tasks.storage import add_data_points
+from cognee.temporal_poc.models.models import EventList
+from cognee.temporal_poc.models.models import Timestamp as CalendarTimestamp
+from cognee.temporal_poc.datapoints.datapoints import Interval, Timestamp, Event
+
+logger = get_logger("temporal_cognify")
+
+
+def timestamp_to_unix(ts: CalendarTimestamp) -> int:
+    # Assumes UTC, since the extractor returns no timezone information.
+    dt = datetime(ts.year, ts.month, ts.day, ts.hour, ts.minute, ts.second, tzinfo=timezone.utc)
+    return int(dt.timestamp())
+
+
+async def extract_event_graph(content: str, response_model: Type[BaseModel]):
+    llm_client = get_llm_client()
+
+    system_prompt = """
+    You are an extractor. From the input text, pull out:
+
+    Timestamps: concrete points (year, month, day, hour, minute, second).
+
+    Intervals: spans with explicit start and end times; resolve relative durations if anchored.
+
+    Entities: people, organizations, topics, etc., with name, short description, and type
+    (person/org/location/topic/other). Always attach the type.
+
+    Events: include name, brief description, subject (actor), object (target), time as either
+    a point (at) or a span (during), and location. Prefer during for multi-hour spans; use at
+    for a single point. Omit ambiguous times rather than guessing.
+
+    Output JSON. Reuse entity names when repeated. Use null for missing optional fields.
+    """
+
+    content_graph = await llm_client.acreate_structured_output(
+        content, system_prompt, response_model
+    )
+
+    return content_graph
+
+
+async def extract_events_and_entities(data_chunks: List[DocumentChunk]) -> List[DocumentChunk]:
+    """Extract events from each chunk and attach them as Event data points."""
+    event_lists = await asyncio.gather(
+        *[extract_event_graph(chunk.text, EventList) for chunk in data_chunks]
+    )
+
+    for data_chunk, event_list in zip(data_chunks, event_lists):
+        if data_chunk.contains is None:
+            data_chunk.contains = []
+
+        for event in event_list.events:
+            at = None
+            during = None
+            if event.time_from and event.time_to:
+                during = Interval(
+                    time_from=timestamp_to_unix(event.time_from),
+                    time_to=timestamp_to_unix(event.time_to),
+                )
+            elif event.time_from or event.time_to:
+                # Only one bound is known, so store it as a single point in time.
+                known_time = event.time_from or event.time_to
+                at = Timestamp(time_at=timestamp_to_unix(known_time))
+
+            data_chunk.contains.append(
+                Event(
+                    name=event.name,
+                    description=event.description,
+                    at=at,
+                    during=during,
+                    location=event.location,
+                )
+            )
+
+    return data_chunks
+
+
+async def get_temporal_tasks(
+    user: User = None, chunker=TextChunker, chunk_size: int = None
+) -> list[Task]:
+    temporal_tasks = [
+        Task(classify_documents),
+        Task(check_permissions_on_dataset, user=user, permissions=["write"]),
+        Task(
+            extract_chunks_from_documents,
+            max_chunk_size=chunk_size or get_max_chunk_tokens(),
+            chunker=chunker,
+        ),
+        Task(extract_events_and_entities, task_config={"chunk_size": 10}),
+        Task(add_data_points, task_config={"batch_size": 10}),
+    ]
+
+    return temporal_tasks
+
+
+async def temporal_cognify(
+    datasets: Union[str, list[str], list[UUID]] = None,
+    user: User = None,
+    chunker=TextChunker,
+    chunk_size: int = None,
+    vector_db_config: dict = None,
+    graph_db_config: dict = None,
+    incremental_loading: bool = True,
+):
+    tasks = await get_temporal_tasks(user, chunker, chunk_size)
+
+    return await run_cognify_blocking(
+        tasks=tasks,
+        user=user,
+        datasets=datasets,
+        vector_db_config=vector_db_config,
+        graph_db_config=graph_db_config,
+        incremental_loading=incremental_loading,
+    )
diff --git a/cognee/temporal_poc/temporal_example.py b/cognee/temporal_poc/temporal_example.py
new file mode 100644
index 000000000..90a16f680
--- /dev/null
+++ b/cognee/temporal_poc/temporal_example.py
@@ -0,0 +1,47 @@
+import asyncio
+import json
+from pathlib import Path
+
+import cognee
+from cognee.shared.logging_utils import setup_logging, INFO
+from cognee.temporal_poc.temporal_cognify import temporal_cognify
+
+
+def read_temporal_data():
+    """Read unique contexts from a JSON-lines file (one JSON object per line)."""
+    path = Path("cognee/temporal_poc/test_hard.json")
+    contexts = []
+    seen = set()
+    with path.open(encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            if not line:
+                continue
+            entry = json.loads(line)
+            ctx = entry.get("context", "")
+            if ctx and ctx not in seen:
+                seen.add(ctx)
+                contexts.append(ctx)
+    return contexts
+
+
+async def main():
+    await cognee.prune.prune_data()
+    await cognee.prune.prune_system(metadata=True)
+
+    texts = read_temporal_data()[:10]
+
+    await cognee.add(texts)
+    await temporal_cognify()
+
+
+if __name__ == "__main__":
+    logger = setup_logging(log_level=INFO)
+
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        loop.run_until_complete(main())
+    finally:
+        loop.run_until_complete(loop.shutdown_asyncgens())
+        loop.close()
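
Note for reviewers: below is a minimal sketch of the schema mapping that extract_events_and_entities performs, runnable in isolation. It assumes the modules land at the paths used in this diff, that the timestamp_to_unix helper introduced above is kept, and that DataPoint subclasses can be constructed from keyword fields alone; the sample event is hypothetical.

    from cognee.temporal_poc.datapoints.datapoints import Event, Timestamp
    from cognee.temporal_poc.models.models import Event as ExtractedEvent
    from cognee.temporal_poc.models.models import Timestamp as CalendarTimestamp
    from cognee.temporal_poc.temporal_cognify import timestamp_to_unix

    # What the LLM returns for a chunk: a calendar-style event with one known bound.
    extracted = ExtractedEvent(
        name="Product launch",
        time_from=CalendarTimestamp(year=2021, month=5, day=5, hour=9, minute=0, second=0),
    )

    # Only time_from is set, so the pipeline stores a point (`at`), not a span (`during`).
    datapoint = Event(
        name=extracted.name,
        at=Timestamp(time_at=timestamp_to_unix(extracted.time_from)),
    )
    print(datapoint.at.time_at)  # 1620205200 == 2021-05-05 09:00:00 UTC

One caveat of the integer encoding: the ge=0 constraint on Interval and Timestamp means any event before 1970 fails validation, which may matter for the historical texts in test_hard.json.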