diff --git a/cognee/modules/chunking/TextChunker.py b/cognee/modules/chunking/TextChunker.py
index f9a664d8b..57663f95e 100644
--- a/cognee/modules/chunking/TextChunker.py
+++ b/cognee/modules/chunking/TextChunker.py
@@ -46,6 +46,7 @@ class TextChunker:
                     id=chunk_data["chunk_id"],
                     text=chunk_data["text"],
                     word_count=chunk_data["word_count"],
+                    token_count=chunk_data["token_count"],
                     is_part_of=self.document,
                     chunk_index=self.chunk_index,
                     cut_type=chunk_data["cut_type"],
@@ -65,6 +66,7 @@ class TextChunker:
                     ),
                     text=chunk_text,
                     word_count=self.chunk_size,
+                    token_count=self.token_count,
                     is_part_of=self.document,
                     chunk_index=self.chunk_index,
                     cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
@@ -87,6 +89,7 @@ class TextChunker:
                 id=uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
                 text=" ".join(chunk["text"] for chunk in paragraph_chunks),
                 word_count=self.chunk_size,
+                token_count=self.token_count,
                 is_part_of=self.document,
                 chunk_index=self.chunk_index,
                 cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
diff --git a/cognee/modules/chunking/models/DocumentChunk.py b/cognee/modules/chunking/models/DocumentChunk.py
index a232d50a1..894a810d2 100644
--- a/cognee/modules/chunking/models/DocumentChunk.py
+++ b/cognee/modules/chunking/models/DocumentChunk.py
@@ -9,6 +9,7 @@ class DocumentChunk(DataPoint):
     __tablename__ = "document_chunk"
     text: str
     word_count: int
+    token_count: int
     chunk_index: int
     cut_type: str
     is_part_of: Document
diff --git a/cognee/modules/data/models/Data.py b/cognee/modules/data/models/Data.py
index 0c0d60d0d..54d8bce1b 100644
--- a/cognee/modules/data/models/Data.py
+++ b/cognee/modules/data/models/Data.py
@@ -1,6 +1,6 @@
 from datetime import datetime, timezone
 from uuid import uuid4
-from sqlalchemy import UUID, Column, DateTime, String, JSON
+from sqlalchemy import UUID, Column, DateTime, String, JSON, Integer
 from sqlalchemy.orm import relationship

 from cognee.infrastructure.databases.relational import Base
@@ -20,6 +20,7 @@ class Data(Base):
     owner_id = Column(UUID, index=True)
     content_hash = Column(String)
     external_metadata = Column(JSON)
+    token_count = Column(Integer)
     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
     updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))

diff --git a/cognee/tasks/documents/extract_chunks_from_documents.py b/cognee/tasks/documents/extract_chunks_from_documents.py
index 4a089c7bc..a65f32fc9 100644
--- a/cognee/tasks/documents/extract_chunks_from_documents.py
+++ b/cognee/tasks/documents/extract_chunks_from_documents.py
@@ -1,6 +1,25 @@
-from typing import Optional, AsyncGenerator

 from cognee.modules.data.processing.document_types.Document import Document
+from sqlalchemy import select
+from cognee.modules.data.models import Data
+from cognee.infrastructure.databases.relational import get_relational_engine
+from uuid import UUID
+
+
+async def update_document_token_count(document_id: UUID, token_count: int) -> None:
+    db_engine = get_relational_engine()
+    async with db_engine.get_async_session() as session:
+        document_data_point = (
+            await session.execute(select(Data).filter(Data.id == document_id))
+        ).scalar_one_or_none()
+
+        if document_data_point:
+            document_data_point.token_count = token_count
+            await session.merge(document_data_point)
+            await session.commit()
+        else:
+            raise ValueError(f"Document with id {document_id} not found.")


 async def extract_chunks_from_documents(
@@ -17,7 +36,11 @@ async def extract_chunks_from_documents(
     - The `chunker` parameter determines the chunking logic and should align with the document type.
     """
     for document in documents:
+        document_token_count = 0
         for document_chunk in document.read(
             chunk_size=chunk_size, chunker=chunker, max_chunk_tokens=max_chunk_tokens
         ):
+            document_token_count += document_chunk.token_count
             yield document_chunk
+
+        await update_document_token_count(document.id, document_token_count)
diff --git a/cognee/tasks/ingestion/ingest_data.py b/cognee/tasks/ingestion/ingest_data.py
index 924ef10b0..b19786d4e 100644
--- a/cognee/tasks/ingestion/ingest_data.py
+++ b/cognee/tasks/ingestion/ingest_data.py
@@ -107,6 +107,7 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
                 owner_id=user.id,
                 content_hash=file_metadata["content_hash"],
                 external_metadata=get_external_metadata_dict(data_item),
+                token_count=-1,
             )

             # Check if data is already in dataset
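
A minimal sketch of how the persisted counts could be inspected after a pipeline run, using only the APIs visible in this diff (get_relational_engine, get_async_session, and the Data model). The helper name print_token_counts and the "not chunked yet" wording are illustrative, not part of the change; the -1 sentinel is the default written by ingest_data before extract_chunks_from_documents updates the row.

from sqlalchemy import select

from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Data


async def print_token_counts() -> None:
    # Hypothetical helper, not part of the diff: list every Data row with the
    # token_count persisted by update_document_token_count(); rows still at the
    # ingest_data default of -1 have not been chunked yet.
    db_engine = get_relational_engine()
    async with db_engine.get_async_session() as session:
        for row in (await session.execute(select(Data))).scalars():
            label = "not chunked yet" if row.token_count == -1 else row.token_count
            print(f"{row.id}: token_count={label}")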