Count the number of tokens in documents [COG-1071] (#476)

* Count the number of tokens in documents

* save token count to relational db

---------

Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
This commit is contained in:
alekszievr 2025-01-29 11:29:09 +01:00 committed by GitHub
parent d900060e2b
commit edae2771a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 31 additions and 2 deletions

View file

@@ -46,6 +46,7 @@ class TextChunker:
 id=chunk_data["chunk_id"],
 text=chunk_data["text"],
 word_count=chunk_data["word_count"],
+token_count=chunk_data["token_count"],
 is_part_of=self.document,
 chunk_index=self.chunk_index,
 cut_type=chunk_data["cut_type"],
@@ -65,6 +66,7 @@ class TextChunker:
 ),
 text=chunk_text,
 word_count=self.chunk_size,
+token_count=self.token_count,
 is_part_of=self.document,
 chunk_index=self.chunk_index,
 cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],
@@ -87,6 +89,7 @@ class TextChunker:
 id=uuid5(NAMESPACE_OID, f"{str(self.document.id)}-{self.chunk_index}"),
 text=" ".join(chunk["text"] for chunk in paragraph_chunks),
 word_count=self.chunk_size,
+token_count=self.token_count,
 is_part_of=self.document,
 chunk_index=self.chunk_index,
 cut_type=paragraph_chunks[len(paragraph_chunks) - 1]["cut_type"],

View file

@@ -9,6 +9,7 @@ class DocumentChunk(DataPoint):
 __tablename__ = "document_chunk"
 text: str
 word_count: int
+token_count: int
 chunk_index: int
 cut_type: str
 is_part_of: Document

View file

@@ -1,6 +1,6 @@
 from datetime import datetime, timezone
 from uuid import uuid4
-from sqlalchemy import UUID, Column, DateTime, String, JSON
+from sqlalchemy import UUID, Column, DateTime, String, JSON, Integer
 from sqlalchemy.orm import relationship
 from cognee.infrastructure.databases.relational import Base
@@ -20,6 +20,7 @@ class Data(Base):
 owner_id = Column(UUID, index=True)
 content_hash = Column(String)
 external_metadata = Column(JSON)
+token_count = Column(Integer)
 created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
 updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))

View file

@@ -1,6 +1,25 @@
-from typing import Optional, AsyncGenerator
+from typing import AsyncGenerator
 from cognee.modules.data.processing.document_types.Document import Document
+from sqlalchemy import select
+from cognee.modules.data.models import Data
+from cognee.infrastructure.databases.relational import get_relational_engine
+from uuid import UUID
async def update_document_token_count(document_id: UUID, token_count: int) -> None:
    """Persist the computed token count onto a document's ``Data`` row.

    Looks up the ``Data`` record matching ``document_id`` in the relational
    database, sets its ``token_count`` column, and commits.

    Args:
        document_id: UUID of the document whose row should be updated.
        token_count: Total number of tokens counted for the document.

    Raises:
        ValueError: If no ``Data`` row exists for ``document_id``.
    """
    engine = get_relational_engine()
    async with engine.get_async_session() as session:
        query_result = await session.execute(select(Data).filter(Data.id == document_id))
        data_row = query_result.scalar_one_or_none()

        # Guard clause: nothing to update if the document row is missing.
        if data_row is None:
            raise ValueError(f"Document with id {document_id} not found.")

        data_row.token_count = token_count
        await session.merge(data_row)
        await session.commit()
 async def extract_chunks_from_documents(
@@ -17,7 +36,11 @@ async def extract_chunks_from_documents(
 - The `chunker` parameter determines the chunking logic and should align with the document type.
 """
 for document in documents:
+document_token_count = 0
 for document_chunk in document.read(
 chunk_size=chunk_size, chunker=chunker, max_chunk_tokens=max_chunk_tokens
 ):
+document_token_count += document_chunk.token_count
 yield document_chunk
+await update_document_token_count(document.id, document_token_count)

View file

@@ -107,6 +107,7 @@ async def ingest_data(data: Any, dataset_name: str, user: User):
 owner_id=user.id,
 content_hash=file_metadata["content_hash"],
 external_metadata=get_external_metadata_dict(data_item),
+token_count=-1,
 )
# Check if data is already in dataset # Check if data is already in dataset