save_dev
This commit is contained in:
parent f2e66bc276
commit e90cbc43dd
5 changed files with 241 additions and 1 deletion
cognee/modules/chunking/models/DocumentChunk.py

@@ -3,6 +3,8 @@ from typing import List
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.data.processing.document_types import Document
 from cognee.modules.engine.models import Entity
+from typing import Union
+from cognee.temporal_poc.models.models import Event


 class DocumentChunk(DataPoint):
@@ -30,6 +32,6 @@ class DocumentChunk(DataPoint):
     chunk_index: int
     cut_type: str
     is_part_of: Document
-    contains: List[Entity] = None
+    contains: List[Union[Entity, Event]] = None

     metadata: dict = {"index_fields": ["text"]}
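A minimal sketch (not part of this commit) of what the widened annotation allows: a populated chunk's contains list can now hold temporal Event data points next to regular entities. The helper name and the placeholder timestamp below are made up for illustration.

from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.temporal_poc.datapoints.datapoints import Event, Timestamp


def attach_example_event(chunk: DocumentChunk) -> None:
    # `chunk` is assumed to be an already-populated DocumentChunk instance.
    if chunk.contains is None:  # the field defaults to None, not []
        chunk.contains = []
    chunk.contains.append(
        Event(name="Example event", at=Timestamp(time_at=0))  # placeholder instant
    )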
23  cognee/temporal_poc/datapoints/datapoints.py  Normal file

@@ -0,0 +1,23 @@
from cognee.infrastructure.engine import DataPoint
from cognee.modules.engine.models.EntityType import EntityType
from typing import Optional
from pydantic import BaseModel, Field, ConfigDict


class Interval(DataPoint):
    time_from: int = Field(..., ge=0)
    time_to: int = Field(..., ge=0)


class Timestamp(DataPoint):
    time_at: int = Field(..., ge=0)


class Event(DataPoint):
    name: str
    description: Optional[str] = None
    at: Optional[Timestamp] = None
    during: Optional[Interval] = None
    location: Optional[str] = None

    metadata: dict = {"index_fields": ["name"]}
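A construction sketch, not part of the commit. The integer times below are placeholders; the commit only constrains them to be non-negative ints (ge=0), e.g. Unix epoch seconds.

from cognee.temporal_poc.datapoints.datapoints import Event, Interval, Timestamp

point_event = Event(
    name="Example announcement",
    description="A single point in time",
    at=Timestamp(time_at=0),  # placeholder instant
    location="Example Hall",
)
span_event = Event(
    name="Example conference",
    during=Interval(time_from=0, time_to=3600),  # placeholder one-hour span
)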
29  cognee/temporal_poc/models/models.py  Normal file

@@ -0,0 +1,29 @@
from typing import Optional, Literal, List
from pydantic import BaseModel, Field, root_validator, ValidationError
from cognee.modules.engine.models.Entity import Entity


class Timestamp(BaseModel):
    year: int = Field(..., ge=1, le=9999)
    month: int = Field(..., ge=1, le=12)
    day: int = Field(..., ge=1, le=31)
    hour: int = Field(..., ge=0, le=23)
    minute: int = Field(..., ge=0, le=59)
    second: int = Field(..., ge=0, le=59)


class Interval(BaseModel):
    starts_at: Timestamp
    ends_at: Timestamp


class Event(BaseModel):
    name: str
    description: Optional[str] = None
    time_from: Optional[Timestamp] = None
    time_to: Optional[Timestamp] = None
    location: Optional[str] = None


class EventList(BaseModel):
    events: List[Event]
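These models double as the structured-output schema handed to the LLM. A sketch, not part of the commit and assuming pydantic v2, of validating a raw JSON response against EventList:

from cognee.temporal_poc.models.models import EventList

raw_response = (
    '{"events": [{"name": "Example launch", "description": null, '
    '"time_from": {"year": 1999, "month": 1, "day": 2, "hour": 3, "minute": 4, "second": 5}, '
    '"time_to": null, "location": null}]}'
)

parsed = EventList.model_validate_json(raw_response)
print(parsed.events[0].time_from.year)  # 1999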
138  cognee/temporal_poc/temporal_cognify.py  Normal file

@@ -0,0 +1,138 @@
import asyncio
from pydantic import BaseModel
from typing import Union, Optional, List, Type
from uuid import UUID

from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.shared.logging_utils import get_logger
from cognee.shared.data_models import KnowledgeGraph
from cognee.infrastructure.llm import get_max_chunk_tokens, get_llm_config

from cognee.api.v1.cognify.cognify import run_cognify_blocking
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.chunking.TextChunker import TextChunker
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk

from cognee.modules.users.models import User
from cognee.tasks.documents import (
    check_permissions_on_dataset,
    classify_documents,
    extract_chunks_from_documents,
)
from cognee.tasks.graph import extract_graph_from_data
from cognee.tasks.storage import add_data_points
from cognee.tasks.summarization import summarize_text
from cognee.temporal_poc.models.models import EventList
from cognee.temporal_poc.datapoints.datapoints import Interval, Timestamp, Event

logger = get_logger("temporal_cognify")

async def extract_event_graph(content: str, response_model: Type[BaseModel]):
    llm_client = get_llm_client()

    system_prompt = """
    You are an extractor. From the input text, pull out:

    Timestamps: concrete points (year, month, day, hour, minute, second).

    Intervals: spans with explicit start and end times; resolve relative durations if anchored.

    Entities: people, organizations, topics, etc., with name, short description, and their type (person/org/location/topic/other). Always attach the type.

    Events: include name, brief description, subject (actor), object (target), time as either a point (at) or a span (during), and location. Prefer during for multi-hour spans; use at for a single point. Omit ambiguous times rather than guessing.

    Output JSON. Reuse entity names when repeated. Use null for missing optional fields.
    """

    content_graph = await llm_client.acreate_structured_output(
        content, system_prompt, response_model
    )

    return content_graph

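A usage sketch, not part of the commit: calling the extractor directly on a single passage. It requires a configured LLM provider, and the sample sentence is made up.

async def demo_extraction() -> None:
    event_list = await extract_event_graph(
        "The summit ran from 9 June 2021 to 11 June 2021 in Cornwall.",
        EventList,
    )
    for event in event_list.events:
        print(event.name, event.time_from, event.time_to, event.location)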
async def extract_events_and_entities(data_chunks: List[DocumentChunk]) -> List[DocumentChunk]:
    """Extract events and entities from each document chunk."""
    events = await asyncio.gather(
        *[extract_event_graph(chunk.text, EventList) for chunk in data_chunks]
    )
    for data_chunk, event_list in zip(data_chunks, events):
        for event in event_list.events:
            if event.time_from and event.time_to:
                event_interval = Interval(
                    time_from=int(event.time_from), time_to=int(event.time_to)
                )
                event_datapoint = Event(
                    name=event.name,
                    description=event.description,
                    during=event_interval,
                    location=event.location,
                )
            elif event.time_from:
                event_time_at = Timestamp(time_at=int(event.time_from))
                event_datapoint = Event(
                    name=event.name,
                    description=event.description,
                    at=event_time_at,
                    location=event.location,
                )
            elif event.time_to:
                event_time_at = Timestamp(time_at=int(event.time_to))
                event_datapoint = Event(
                    name=event.name,
                    description=event.description,
                    at=event_time_at,
                    location=event.location,
                )
            else:
                event_datapoint = Event(
                    name=event.name, description=event.description, location=event.location
                )

            data_chunk.contains.append(event_datapoint)

    return data_chunks

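The int(event.time_from) / int(event.time_to) calls above assume the calendar-style models.Timestamp (year, month, day, ...) can be coerced to a single integer; the commit itself does not define that conversion. One hypothetical way to do it is sketched below. The helper name and the epoch-seconds encoding are assumptions, and pre-1970 dates would yield negative values that the ge=0 constraint on the datapoints rejects.

from datetime import datetime, timezone

from cognee.temporal_poc.models.models import Timestamp as CalendarTimestamp


def to_epoch_seconds(ts: CalendarTimestamp) -> int:
    # Collapse a calendar timestamp into Unix epoch seconds (UTC) so it can feed
    # datapoints.Timestamp / datapoints.Interval, which store plain integers.
    moment = datetime(
        ts.year, ts.month, ts.day, ts.hour, ts.minute, ts.second, tzinfo=timezone.utc
    )
    return int(moment.timestamp())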
async def get_temporal_tasks(
    user: User = None, chunker=TextChunker, chunk_size: int = None
) -> list[Task]:
    temporal_tasks = [
        Task(classify_documents),
        Task(check_permissions_on_dataset, user=user, permissions=["write"]),
        Task(
            extract_chunks_from_documents,
            max_chunk_size=chunk_size or get_max_chunk_tokens(),
            chunker=chunker,
        ),
        Task(extract_events_and_entities, task_config={"chunk_size": 10}),
        Task(add_data_points, task_config={"batch_size": 10}),
    ]

    return temporal_tasks


async def temporal_cognify(
    datasets: Union[str, list[str], list[UUID]] = None,
    user: User = None,
    chunker=TextChunker,
    chunk_size: int = None,
    vector_db_config: dict = None,
    graph_db_config: dict = None,
    incremental_loading: bool = True,
):
    tasks = await get_temporal_tasks(user, chunker, chunk_size)

    return await run_cognify_blocking(
        tasks=tasks,
        user=user,
        datasets=datasets,
        vector_db_config=vector_db_config,
        graph_db_config=graph_db_config,
        incremental_loading=incremental_loading,
    )
48  cognee/temporal_poc/temporal_example.py  Normal file

@@ -0,0 +1,48 @@
import asyncio
import cognee
from cognee.shared.logging_utils import setup_logging, INFO
from cognee.temporal_poc.temporal_cognify import temporal_cognify

import json
from pathlib import Path


async def reading_temporal_data():
    path = Path("cognee/temporal_poc/test_hard.json")
    contexts = []
    seen = set()
    with path.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            entry = json.loads(line)
            ctx = entry.get("context", "")
            if ctx and ctx not in seen:
                seen.add(ctx)
                contexts.append(ctx)
    return contexts

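Illustrative only: reading_temporal_data expects test_hard.json to be JSON Lines, i.e. one JSON object per line carrying at least a "context" field (anything else in the line is ignored here). A made-up line, expressed via json.dumps:

import json

sample_line = json.dumps(
    {"context": "The armistice took effect at 11:00 on 11 November 1918."}
)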
async def main():
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)

    texts = await reading_temporal_data()
    texts = texts[:10]

    await cognee.add(texts)
    await temporal_cognify()

    print()


if __name__ == "__main__":
    logger = setup_logging(log_level=INFO)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())