Merge pull request #468 from topoteretes/COG-970-refactor-tokenizing

COG-970: refactor tokenizing
Igor Ilic 2025-01-29 09:02:23 +01:00 committed by GitHub
commit d900060e2b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
44 changed files with 633 additions and 134 deletions

View file

@@ -1,12 +1,28 @@
 ENV="local"
 TOKENIZERS_PARALLELISM="false"
-LLM_API_KEY=
+# LLM settings
+LLM_API_KEY=""
+LLM_MODEL="openai/gpt-4o-mini"
+LLM_PROVIDER="openai"
+LLM_ENDPOINT=""
+LLM_API_VERSION=""
+LLM_MAX_TOKENS="16384"
 GRAPHISTRY_USERNAME=
 GRAPHISTRY_PASSWORD=
 SENTRY_REPORTING_URL=
+# Embedding settings
+EMBEDDING_PROVIDER="openai"
+EMBEDDING_API_KEY=""
+EMBEDDING_MODEL="openai/text-embedding-3-large"
+EMBEDDING_ENDPOINT=""
+EMBEDDING_API_VERSION=""
+EMBEDDING_DIMENSIONS=3072
+EMBEDDING_MAX_TOKENS=8191
 # "neo4j" or "networkx"
 GRAPH_DATABASE_PROVIDER="networkx"
 # Not needed if using networkx

View file

@@ -20,6 +20,7 @@ from cognee.tasks.repo_processor import (
 from cognee.tasks.repo_processor.get_source_code_chunks import get_source_code_chunks
 from cognee.tasks.storage import add_data_points
 from cognee.tasks.summarization import summarize_code, summarize_text
+from cognee.infrastructure.llm import get_max_chunk_tokens

 monitoring = get_base_config().monitoring_tool
 if monitoring == MonitoringTool.LANGFUSE:
@@ -57,7 +58,7 @@ async def run_code_graph_pipeline(repo_path, include_docs=True):
             Task(ingest_data, dataset_name="repo_docs", user=user),
             Task(get_data_list_for_user, dataset_name="repo_docs", user=user),
             Task(classify_documents),
-            Task(extract_chunks_from_documents, max_tokens=cognee_config.max_tokens),
+            Task(extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()),
             Task(
                 extract_graph_from_data, graph_model=KnowledgeGraph, task_config={"batch_size": 50}
             ),

View file

@@ -4,6 +4,7 @@ from typing import Union
 from pydantic import BaseModel

+from cognee.infrastructure.llm import get_max_chunk_tokens
 from cognee.modules.cognify.config import get_cognify_config
 from cognee.modules.data.methods import get_datasets, get_datasets_by_name
 from cognee.modules.data.methods.get_dataset_data import get_dataset_data
@@ -151,7 +152,9 @@ async def get_default_tasks(
     default_tasks = [
         Task(classify_documents),
         Task(check_permissions_on_documents, user=user, permissions=["write"]),
-        Task(extract_chunks_from_documents),  # Extract text chunks based on the document type.
+        Task(
+            extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()
+        ),  # Extract text chunks based on the document type.
         Task(
             extract_graph_from_data, graph_model=graph_model, task_config={"batch_size": 10}
         ),  # Generate knowledge graphs from the document chunks.

View file

@@ -6,6 +6,9 @@ import litellm
 import os
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
 from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
+from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
+from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
+from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer

 litellm.set_verbose = False
 logger = logging.getLogger("LiteLLMEmbeddingEngine")
@@ -15,23 +18,29 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
     api_key: str
     endpoint: str
     api_version: str
+    provider: str
     model: str
     dimensions: int
     mock: bool

     def __init__(
         self,
+        provider: str = "openai",
         model: Optional[str] = "text-embedding-3-large",
         dimensions: Optional[int] = 3072,
         api_key: str = None,
         endpoint: str = None,
         api_version: str = None,
+        max_tokens: int = 512,
     ):
         self.api_key = api_key
         self.endpoint = endpoint
         self.api_version = api_version
+        self.provider = provider
         self.model = model
         self.dimensions = dimensions
+        self.max_tokens = max_tokens
+        self.tokenizer = self.get_tokenizer()

         enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
         if isinstance(enable_mocking, bool):
@@ -104,3 +113,18 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):

     def get_vector_size(self) -> int:
         return self.dimensions
+
+    def get_tokenizer(self):
+        logger.debug(f"Loading tokenizer for model {self.model}...")
+        # If model also contains provider information, extract only model information
+        model = self.model.split("/")[-1]
+
+        if "openai" in self.provider.lower():
+            tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
+        elif "gemini" in self.provider.lower():
+            tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens)
+        else:
+            tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
+
+        logger.debug(f"Tokenizer loaded for model: {self.model}")
+        return tokenizer

View file

@@ -4,12 +4,13 @@ from pydantic_settings import BaseSettings, SettingsConfigDict

 class EmbeddingConfig(BaseSettings):
-    embedding_model: Optional[str] = "text-embedding-3-large"
+    embedding_provider: Optional[str] = "openai"
+    embedding_model: Optional[str] = "openai/text-embedding-3-large"
     embedding_dimensions: Optional[int] = 3072
     embedding_endpoint: Optional[str] = None
     embedding_api_key: Optional[str] = None
     embedding_api_version: Optional[str] = None
+    embedding_max_tokens: Optional[int] = 8191

     model_config = SettingsConfigDict(env_file=".env", extra="allow")
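Because EmbeddingConfig is a pydantic BaseSettings, the new .env keys at the top of this diff bind to these fields by name; a small illustrative check (module path per this repository's layout):

import os

from cognee.infrastructure.databases.vector.embeddings.config import EmbeddingConfig

# Environment values (or .env entries) override the field defaults by name.
os.environ["EMBEDDING_MAX_TOKENS"] = "4096"

config = EmbeddingConfig()
assert config.embedding_max_tokens == 4096
assert config.embedding_provider == "openai"  # default when the variable is unset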

View file

@@ -10,9 +10,11 @@ def get_embedding_engine() -> EmbeddingEngine:
     return LiteLLMEmbeddingEngine(
         # If OpenAI API is used for embeddings, litellm needs only the api_key.
+        provider=config.embedding_provider,
         api_key=config.embedding_api_key or llm_config.llm_api_key,
         endpoint=config.embedding_endpoint,
         api_version=config.embedding_api_version,
         model=config.embedding_model,
         dimensions=config.embedding_dimensions,
+        max_tokens=config.embedding_max_tokens,
     )

View file

@@ -1 +1,2 @@
 from .config import get_llm_config
+from .utils import get_max_chunk_tokens

View file

@@ -14,11 +14,12 @@ class AnthropicAdapter(LLMInterface):
     name = "Anthropic"
     model: str

-    def __init__(self, model: str = None):
+    def __init__(self, max_tokens: int, model: str = None):
         self.aclient = instructor.patch(
             create=anthropic.Anthropic().messages.create, mode=instructor.Mode.ANTHROPIC_TOOLS
         )

         self.model = model
+        self.max_tokens = max_tokens

     async def acreate_structured_output(
         self, text_input: str, system_prompt: str, response_model: Type[BaseModel]

View file

@@ -11,6 +11,7 @@ class LLMConfig(BaseSettings):
     llm_api_version: Optional[str] = None
     llm_temperature: float = 0.0
     llm_streaming: bool = False
+    llm_max_tokens: int = 16384

     transcription_model: str = "whisper-1"

     model_config = SettingsConfigDict(env_file=".env", extra="allow")
@@ -24,6 +25,7 @@ class LLMConfig(BaseSettings):
             "api_version": self.llm_api_version,
             "temperature": self.llm_temperature,
             "streaming": self.llm_streaming,
+            "max_tokens": self.llm_max_tokens,
             "transcription_model": self.transcription_model,
         }

View file

@@ -2,6 +2,7 @@
 import asyncio
 from typing import List, Type
 from pydantic import BaseModel
+
 import instructor
 from cognee.infrastructure.llm.llm_interface import LLMInterface
@@ -16,11 +17,12 @@ class GenericAPIAdapter(LLMInterface):
     model: str
     api_key: str

-    def __init__(self, endpoint, api_key: str, model: str, name: str):
+    def __init__(self, endpoint, api_key: str, model: str, name: str, max_tokens: int):
         self.name = name
         self.model = model
         self.api_key = api_key
         self.endpoint = endpoint
+        self.max_tokens = max_tokens

         llm_config = get_llm_config()

View file

@@ -20,6 +20,15 @@ def get_llm_client():
     provider = LLMProvider(llm_config.llm_provider)

+    # Check if a max_tokens value is defined in liteLLM for the given model;
+    # if not, use the value from the cognee configuration.
+    from cognee.infrastructure.llm.utils import (
+        get_model_max_tokens,
+    )  # imported here to avoid circular imports
+
+    model_max_tokens = get_model_max_tokens(llm_config.llm_model)
+    max_tokens = model_max_tokens if model_max_tokens else llm_config.llm_max_tokens
+
     if provider == LLMProvider.OPENAI:
         if llm_config.llm_api_key is None:
             raise InvalidValueError(message="LLM API key is not set.")
@@ -32,6 +41,7 @@ def get_llm_client():
             api_version=llm_config.llm_api_version,
             model=llm_config.llm_model,
             transcription_model=llm_config.transcription_model,
+            max_tokens=max_tokens,
             streaming=llm_config.llm_streaming,
         )
@@ -42,13 +52,17 @@ def get_llm_client():
         from .generic_llm_api.adapter import GenericAPIAdapter

         return GenericAPIAdapter(
-            llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama"
+            llm_config.llm_endpoint,
+            llm_config.llm_api_key,
+            llm_config.llm_model,
+            "Ollama",
+            max_tokens=max_tokens,
         )
     elif provider == LLMProvider.ANTHROPIC:
         from .anthropic.adapter import AnthropicAdapter

-        return AnthropicAdapter(llm_config.llm_model)
+        return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model)
     elif provider == LLMProvider.CUSTOM:
         if llm_config.llm_api_key is None:
@@ -57,7 +71,11 @@ def get_llm_client():
         from .generic_llm_api.adapter import GenericAPIAdapter

         return GenericAPIAdapter(
-            llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom"
+            llm_config.llm_endpoint,
+            llm_config.llm_api_key,
+            llm_config.llm_model,
+            "Custom",
+            max_tokens=max_tokens,
         )
     else:
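The fallback above leans on litellm's built-in model metadata table; a quick sketch of that lookup (available keys and values depend on the installed litellm version):

import litellm

# litellm.model_cost maps model names to metadata, including "max_tokens".
info = litellm.model_cost.get("gpt-4o-mini")

# Known model: take the model-defined limit; unknown model: fall back to
# the llm_max_tokens value from LLMConfig (16384 by default in this PR).
max_tokens = info["max_tokens"] if info else 16384
print(max_tokens)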

View file

@@ -32,6 +32,7 @@ class OpenAIAdapter(LLMInterface):
         api_version: str,
         model: str,
         transcription_model: str,
+        max_tokens: int,
         streaming: bool = False,
     ):
         self.aclient = instructor.from_litellm(litellm.acompletion)
@@ -41,6 +42,7 @@ class OpenAIAdapter(LLMInterface):
         self.api_key = api_key
         self.endpoint = endpoint
         self.api_version = api_version
+        self.max_tokens = max_tokens
         self.streaming = streaming

     @observe(as_type="generation")

View file

@@ -0,0 +1 @@
+from .adapter import GeminiTokenizer

View file

@@ -0,0 +1,44 @@
from typing import List, Any

from ..tokenizer_interface import TokenizerInterface


class GeminiTokenizer(TokenizerInterface):
    def __init__(
        self,
        model: str,
        max_tokens: int = 3072,
    ):
        self.model = model
        self.max_tokens = max_tokens

        # Get LLM API key from config
        from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
        from cognee.infrastructure.llm.config import get_llm_config

        config = get_embedding_config()
        llm_config = get_llm_config()

        import google.generativeai as genai

        genai.configure(api_key=config.embedding_api_key or llm_config.llm_api_key)

    def extract_tokens(self, text: str) -> List[Any]:
        raise NotImplementedError

    def count_tokens(self, text: str) -> int:
        """
        Returns the number of tokens in the given text.
        Args:
            text: str

        Returns:
            number of tokens in the given text
        """
        import google.generativeai as genai

        # Ask the API for a token count; len() over the embed_content response
        # would measure the response mapping's keys, not the tokens in the text.
        return genai.GenerativeModel(f"models/{self.model}").count_tokens(text).total_tokens

    def trim_text_to_max_tokens(self, text: str) -> str:
        raise NotImplementedError

View file

@@ -0,0 +1 @@
+from .adapter import HuggingFaceTokenizer

View file

@@ -0,0 +1,36 @@
from typing import List, Any

from transformers import AutoTokenizer

from ..tokenizer_interface import TokenizerInterface


class HuggingFaceTokenizer(TokenizerInterface):
    def __init__(
        self,
        model: str,
        max_tokens: int = 512,
    ):
        self.model = model
        self.max_tokens = max_tokens
        self.tokenizer = AutoTokenizer.from_pretrained(model)

    def extract_tokens(self, text: str) -> List[Any]:
        tokens = self.tokenizer.tokenize(text)
        return tokens

    def count_tokens(self, text: str) -> int:
        """
        Returns the number of tokens in the given text.
        Args:
            text: str

        Returns:
            number of tokens in the given text
        """
        return len(self.tokenizer.tokenize(text))

    def trim_text_to_max_tokens(self, text: str) -> str:
        raise NotImplementedError
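A short usage sketch; the checkpoint name is illustrative, and any Hugging Face model that ships a tokenizer works:

from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer

# Illustrative checkpoint, not a project default; it is downloaded on first use.
tokenizer = HuggingFaceTokenizer(model="sentence-transformers/all-MiniLM-L6-v2", max_tokens=256)

print(tokenizer.count_tokens("Tokenizers vary widely between model families."))
print(tokenizer.extract_tokens("Tokenizers vary"))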

View file

@@ -0,0 +1 @@
+from .adapter import TikTokenTokenizer

View file

@@ -0,0 +1,69 @@
from typing import List, Any

import tiktoken

from ..tokenizer_interface import TokenizerInterface


class TikTokenTokenizer(TokenizerInterface):
    """
    Tokenizer adapter for OpenAI.
    Intended to be used as part of LLM Embedding and LLM Adapter classes.
    """

    def __init__(
        self,
        model: str,
        max_tokens: int = 8191,
    ):
        self.model = model
        self.max_tokens = max_tokens
        # Initialize TikToken for GPT based on model
        self.tokenizer = tiktoken.encoding_for_model(self.model)

    def extract_tokens(self, text: str) -> List[Any]:
        tokens = []
        # Use TikToken to tokenize the text into token ids
        token_ids = self.tokenizer.encode(text)
        # Go through the token ids and decode each back to its text value
        for token_id in token_ids:
            token = self.tokenizer.decode([token_id])
            tokens.append(token)
        return tokens

    def count_tokens(self, text: str) -> int:
        """
        Returns the number of tokens in the given text.
        Args:
            text: str

        Returns:
            number of tokens in the given text
        """
        num_tokens = len(self.tokenizer.encode(text))
        return num_tokens

    def trim_text_to_max_tokens(self, text: str) -> str:
        """
        Trims the text so that the number of tokens does not exceed max_tokens.
        Args:
            text (str): Original text string to be trimmed.

        Returns:
            str: Trimmed version of text or original text if under the limit.
        """
        # First check the number of tokens
        num_tokens = self.count_tokens(text)

        # If the number of tokens is within the limit, return the text as is
        if num_tokens <= self.max_tokens:
            return text

        # If the number exceeds the limit, trim the text.
        # This is a simple trim; it may cut words in half. Consider using word
        # boundaries for a cleaner cut.
        encoded_text = self.tokenizer.encode(text)
        trimmed_encoded_text = encoded_text[: self.max_tokens]

        # Decode the trimmed token ids back into text
        trimmed_text = self.tokenizer.decode(trimmed_encoded_text)
        return trimmed_text
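An illustrative round-trip showing the trim; max_tokens is deliberately tiny so the cut is visible:

from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer

tokenizer = TikTokenTokenizer(model="text-embedding-3-large", max_tokens=5)

text = "This sentence is long enough to be trimmed by the tokenizer."
assert tokenizer.count_tokens(text) > 5

trimmed = tokenizer.trim_text_to_max_tokens(text)
assert tokenizer.count_tokens(trimmed) <= 5
print(trimmed)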

View file

@@ -0,0 +1 @@
+from .tokenizer_interface import TokenizerInterface

View file

@@ -0,0 +1,18 @@
from typing import List, Protocol, Any
from abc import abstractmethod


class TokenizerInterface(Protocol):
    """Tokenizer interface"""

    @abstractmethod
    def extract_tokens(self, text: str) -> List[Any]:
        raise NotImplementedError

    @abstractmethod
    def count_tokens(self, text: str) -> int:
        raise NotImplementedError

    @abstractmethod
    def trim_text_to_max_tokens(self, text: str) -> str:
        raise NotImplementedError
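Because this is a typing.Protocol, any class exposing these three methods satisfies it structurally; a toy illustration (whitespace splitting, not part of the PR):

from typing import Any, List

from cognee.infrastructure.llm.tokenizer import TokenizerInterface


class WhitespaceTokenizer:
    """Toy tokenizer used only to show the protocol's shape."""

    def __init__(self, max_tokens: int = 512):
        self.max_tokens = max_tokens

    def extract_tokens(self, text: str) -> List[Any]:
        return text.split()

    def count_tokens(self, text: str) -> int:
        return len(text.split())

    def trim_text_to_max_tokens(self, text: str) -> str:
        return " ".join(text.split()[: self.max_tokens])


tokenizer: TokenizerInterface = WhitespaceTokenizer(max_tokens=3)
print(tokenizer.trim_text_to_max_tokens("one two three four five"))  # one two three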

View file

@@ -0,0 +1,38 @@
import logging

import litellm

from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.llm.get_llm_client import get_llm_client

logger = logging.getLogger(__name__)


def get_max_chunk_tokens():
    # Calculate the max chunk size based on the following formula
    embedding_engine = get_vector_engine().embedding_engine
    llm_client = get_llm_client()

    # Make sure a chunk won't take more than half of the LLM's max context size,
    # but it also can't be bigger than the embedding engine's max token size.
    llm_cutoff_point = llm_client.max_tokens // 2  # Round down the division
    max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)

    return max_chunk_tokens


def get_model_max_tokens(model_name: str):
    """
    Args:
        model_name: name of LLM or embedding model

    Returns: Number of max tokens of the model, or None if the model is unknown
    """
    max_tokens = None

    if model_name in litellm.model_cost:
        max_tokens = litellm.model_cost[model_name]["max_tokens"]
        logger.debug(f"Max input tokens for {model_name}: {max_tokens}")
    else:
        logger.info("Model not found in LiteLLM's model_cost.")

    return max_tokens
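With this PR's defaults (LLM_MAX_TOKENS=16384, EMBEDDING_MAX_TOKENS=8191), the formula resolves as follows; a self-contained sketch of the same arithmetic:

llm_max_tokens = 16384       # LLMConfig default in this PR
embedding_max_tokens = 8191  # EmbeddingConfig default in this PR

llm_cutoff_point = llm_max_tokens // 2  # 8192: at most half the LLM context
max_chunk_tokens = min(embedding_max_tokens, llm_cutoff_point)  # 8191
print(max_chunk_tokens)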

View file

@@ -14,17 +14,15 @@ class TextChunker:
     chunk_size = 0
     token_count = 0

-    def __init__(
-        self, document, get_text: callable, max_tokens: Optional[int] = None, chunk_size: int = 1024
-    ):
+    def __init__(self, document, get_text: callable, max_chunk_tokens: int, chunk_size: int = 1024):
         self.document = document
         self.max_chunk_size = chunk_size
         self.get_text = get_text
-        self.max_tokens = max_tokens if max_tokens else float("inf")
+        self.max_chunk_tokens = max_chunk_tokens

     def check_word_count_and_token_count(self, word_count_before, token_count_before, chunk_data):
         word_count_fits = word_count_before + chunk_data["word_count"] <= self.max_chunk_size
-        token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_tokens
+        token_count_fits = token_count_before + chunk_data["token_count"] <= self.max_chunk_tokens
         return word_count_fits and token_count_fits

     def read(self):
@@ -32,7 +30,7 @@ class TextChunker:
         for content_text in self.get_text():
             for chunk_data in chunk_by_paragraph(
                 content_text,
-                self.max_tokens,
+                self.max_chunk_tokens,
                 self.max_chunk_size,
                 batch_paragraphs=True,
             ):
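The dual budget is easiest to see in isolation: a chunk is accepted only when both the word budget and the token budget still fit (numbers are illustrative):

def fits(word_count_before, token_count_before, chunk_data,
         max_chunk_size=1024, max_chunk_tokens=8191):
    # Mirrors TextChunker.check_word_count_and_token_count: both limits must hold.
    word_count_fits = word_count_before + chunk_data["word_count"] <= max_chunk_size
    token_count_fits = token_count_before + chunk_data["token_count"] <= max_chunk_tokens
    return word_count_fits and token_count_fits


print(fits(1000, 8000, {"word_count": 20, "token_count": 30}))   # True: both budgets fit
print(fits(1000, 8000, {"word_count": 20, "token_count": 300}))  # False: token budget exceeded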

View file

@@ -8,7 +8,6 @@ import os
 class CognifyConfig(BaseSettings):
     classification_model: object = DefaultContentPrediction
     summarization_model: object = SummarizedContent
-    max_tokens: Optional[int] = os.getenv("MAX_TOKENS")

     model_config = SettingsConfigDict(env_file=".env", extra="allow")

     def to_dict(self) -> dict:

View file

@@ -13,14 +13,14 @@ class AudioDocument(Document):
         result = get_llm_client().create_transcript(self.raw_data_location)
         return result.text

-    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
+    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
         # Transcribe the audio file
         text = self.create_transcript()

         chunker_func = ChunkerConfig.get_chunker(chunker)
         chunker = chunker_func(
-            self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
+            self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
         )

         yield from chunker.read()

View file

@@ -11,5 +11,5 @@ class Document(DataPoint):
     mime_type: str
     _metadata: dict = {"index_fields": ["name"], "type": "Document"}

-    def read(self, chunk_size: int, chunker=str, max_tokens: Optional[int] = None) -> str:
+    def read(self, chunk_size: int, chunker=str, max_chunk_tokens: Optional[int] = None) -> str:
         pass

View file

@@ -13,13 +13,13 @@ class ImageDocument(Document):
         result = get_llm_client().transcribe_image(self.raw_data_location)
         return result.choices[0].message.content

-    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
+    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
         # Transcribe the image file
         text = self.transcribe_image()

         chunker_func = ChunkerConfig.get_chunker(chunker)
         chunker = chunker_func(
-            self, chunk_size=chunk_size, get_text=lambda: [text], max_tokens=max_tokens
+            self, chunk_size=chunk_size, get_text=lambda: [text], max_chunk_tokens=max_chunk_tokens
         )

         yield from chunker.read()

View file

@@ -9,7 +9,7 @@ from .Document import Document
 class PdfDocument(Document):
     type: str = "pdf"

-    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
+    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
         file = PdfReader(self.raw_data_location)

         def get_text():
@@ -19,7 +19,7 @@ class PdfDocument(Document):
         chunker_func = ChunkerConfig.get_chunker(chunker)
         chunker = chunker_func(
-            self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
+            self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
         )

         yield from chunker.read()

View file

@@ -7,7 +7,7 @@ from .Document import Document
 class TextDocument(Document):
     type: str = "text"

-    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None):
+    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int):
         def get_text():
             with open(self.raw_data_location, mode="r", encoding="utf-8") as file:
                 while True:
@@ -21,7 +21,7 @@ class TextDocument(Document):
         chunker_func = ChunkerConfig.get_chunker(chunker)
         chunker = chunker_func(
-            self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens
+            self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
         )

         yield from chunker.read()

View file

@@ -10,7 +10,7 @@ from .Document import Document
 class UnstructuredDocument(Document):
     type: str = "unstructured"

-    def read(self, chunk_size: int, chunker: str, max_tokens: Optional[int] = None) -> str:
+    def read(self, chunk_size: int, chunker: str, max_chunk_tokens: int) -> str:
         def get_text():
             try:
                 from unstructured.partition.auto import partition
@@ -29,6 +29,8 @@ class UnstructuredDocument(Document):
                 yield text

-        chunker = TextChunker(self, chunk_size=chunk_size, get_text=get_text, max_tokens=max_tokens)
+        chunker = TextChunker(
+            self, chunk_size=chunk_size, get_text=get_text, max_chunk_tokens=max_chunk_tokens
+        )

         yield from chunker.read()

View file

@@ -10,8 +10,6 @@ import graphistry
 import networkx as nx
 import pandas as pd
 import matplotlib.pyplot as plt
-import tiktoken
-import time

 import logging
 import sys
@@ -100,15 +98,6 @@ def send_telemetry(event_name: str, user_id, additional_properties: dict = {}):
         print(f"Error sending telemetry through proxy: {response.status_code}")


-def num_tokens_from_string(string: str, encoding_name: str) -> int:
-    """Returns the number of tokens in a text string."""
-    # tiktoken.get_encoding("cl100k_base")
-    encoding = tiktoken.encoding_for_model(encoding_name)
-    num_tokens = len(encoding.encode(string))
-    return num_tokens
-
-
 def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
     h = hashlib.md5()
@@ -134,34 +123,6 @@ def get_file_content_hash(file_obj: Union[str, BinaryIO]) -> str:
         raise IngestionError(message=f"Failed to load data from {file}: {e}")


-def trim_text_to_max_tokens(text: str, max_tokens: int, encoding_name: str) -> str:
-    """
-    Trims the text so that the number of tokens does not exceed max_tokens.
-    Args:
-        text (str): Original text string to be trimmed.
-        max_tokens (int): Maximum number of tokens allowed.
-        encoding_name (str): The name of the token encoding to use.
-
-    Returns:
-        str: Trimmed version of text or original text if under the limit.
-    """
-    # First check the number of tokens
-    num_tokens = num_tokens_from_string(text, encoding_name)
-
-    # If the number of tokens is within the limit, return the text as is
-    if num_tokens <= max_tokens:
-        return text
-
-    # If the number exceeds the limit, trim the text
-    # This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut
-    encoded_text = tiktoken.get_encoding(encoding_name).encode(text)
-    trimmed_encoded_text = encoded_text[:max_tokens]
-
-    # Decoding the trimmed text
-    trimmed_text = tiktoken.get_encoding(encoding_name).decode(trimmed_encoded_text)
-    return trimmed_text
-
-
 def generate_color_palette(unique_layers):
     colormap = plt.cm.get_cmap("viridis", len(unique_layers))
     colors = [colormap(i) for i in range(len(unique_layers))]

View file

@@ -10,7 +10,7 @@ from .chunk_by_sentence import chunk_by_sentence

 def chunk_by_paragraph(
     data: str,
-    max_tokens: Optional[Union[int, float]] = None,
+    max_chunk_tokens,
     paragraph_length: int = 1024,
     batch_paragraphs: bool = True,
 ) -> Iterator[Dict[str, Any]]:
@@ -30,8 +30,6 @@ def chunk_by_paragraph(
     paragraph_ids = []
     last_cut_type = None
     current_token_count = 0
-    if not max_tokens:
-        max_tokens = float("inf")

     vector_engine = get_vector_engine()
     embedding_model = vector_engine.embedding_engine.model
@@ -47,7 +45,7 @@ def chunk_by_paragraph(
         if current_word_count > 0 and (
             current_word_count + word_count > paragraph_length
-            or current_token_count + token_count > max_tokens
+            or current_token_count + token_count > max_chunk_tokens
         ):
             # Yield current chunk
             chunk_dict = {
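A call-site sketch after the signature change. The text is illustrative, the import path is assumed from the test modules, and the function still resolves the embedding model via get_vector_engine, so it presumes a configured environment:

from cognee.tasks.chunks import chunk_by_paragraph

text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."

# max_chunk_tokens no longer defaults to infinity; callers must pass a budget.
for chunk in chunk_by_paragraph(data=text, max_chunk_tokens=512, paragraph_length=64):
    print(chunk["chunk_index"], chunk["word_count"], chunk["cut_type"])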

View file

@@ -5,9 +5,9 @@ from cognee.modules.data.processing.document_types.Document import Document

 async def extract_chunks_from_documents(
     documents: list[Document],
+    max_chunk_tokens: int,
     chunk_size: int = 1024,
     chunker="text_chunker",
-    max_tokens: Optional[int] = None,
 ) -> AsyncGenerator:
     """
     Extracts chunks of data from a list of documents based on the specified chunking parameters.
@@ -18,6 +18,6 @@ async def extract_chunks_from_documents(
     """
     for document in documents:
         for document_chunk in document.read(
-            chunk_size=chunk_size, chunker=chunker, max_tokens=max_tokens
+            chunk_size=chunk_size, chunker=chunker, max_chunk_tokens=max_chunk_tokens
        ):
             yield document_chunk
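Taken together with the pipeline changes earlier in this diff, callers now resolve the budget once and thread it through; a condensed sketch (the Task import path is assumed):

from cognee.infrastructure.llm import get_max_chunk_tokens
from cognee.modules.pipelines.tasks.Task import Task
from cognee.tasks.documents import extract_chunks_from_documents

# One budget, computed from the LLM and embedding limits, reused for every document.
chunk_task = Task(extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens())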

View file

@@ -89,26 +89,29 @@ def _get_subchunk_token_counts(

 def _get_chunk_source_code(
-    code_token_counts: list[tuple[str, int]], overlap: float, max_tokens: int
+    code_token_counts: list[tuple[str, int]], overlap: float
 ) -> tuple[list[tuple[str, int]], str]:
     """Generates a chunk of source code from tokenized subchunks with overlap handling."""
     current_count = 0
     cumulative_counts = []
     current_source_code = ""

+    # Get embedding engine used in vector database
+    embedding_engine = get_vector_engine().embedding_engine
+
     for i, (child_code, token_count) in enumerate(code_token_counts):
         current_count += token_count
         cumulative_counts.append(current_count)
-        if current_count > max_tokens:
+        if current_count > embedding_engine.max_tokens:
             break
         current_source_code += f"\n{child_code}"

-    if current_count <= max_tokens:
+    if current_count <= embedding_engine.max_tokens:
         return [], current_source_code.strip()

     cutoff = 1
     for i, cum_count in enumerate(cumulative_counts):
-        if cum_count > (1 - overlap) * max_tokens:
+        if cum_count > (1 - overlap) * embedding_engine.max_tokens:
             break
         cutoff = i
@@ -117,21 +120,18 @@ def _get_chunk_source_code(

 def get_source_code_chunks_from_code_part(
     code_file_part: CodePart,
-    max_tokens: int = 8192,
     overlap: float = 0.25,
     granularity: float = 0.1,
-    model_name: str = "text-embedding-3-large",
 ) -> Generator[SourceCodeChunk, None, None]:
     """Yields source code chunks from a CodePart object, with configurable token limits and overlap."""
     if not code_file_part.source_code:
         logger.error(f"No source code in CodeFile {code_file_part.id}")
         return

-    vector_engine = get_vector_engine()
-    embedding_model = vector_engine.embedding_engine.model
-    model_name = embedding_model.split("/")[-1]
-    tokenizer = tiktoken.encoding_for_model(model_name)
+    embedding_engine = get_vector_engine().embedding_engine
+    tokenizer = embedding_engine.tokenizer

-    max_subchunk_tokens = max(1, int(granularity * max_tokens))
+    max_subchunk_tokens = max(1, int(granularity * embedding_engine.max_tokens))
     subchunk_token_counts = _get_subchunk_token_counts(
         tokenizer, code_file_part.source_code, max_subchunk_tokens
     )
@@ -139,7 +139,7 @@ def get_source_code_chunks_from_code_part(
     previous_chunk = None
     while subchunk_token_counts:
         subchunk_token_counts, chunk_source_code = _get_chunk_source_code(
-            subchunk_token_counts, overlap, max_tokens
+            subchunk_token_counts, overlap
         )
         if not chunk_source_code:
             continue
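To see the overlap arithmetic in _get_chunk_source_code: with a budget of 100 tokens and overlap=0.25, the cutoff lands on the last subchunk whose cumulative count stays at or below 75 tokens, so roughly the final quarter of the budget repeats in the next chunk (illustrative numbers):

overlap = 0.25
max_tokens = 100  # stands in for embedding_engine.max_tokens
cumulative_counts = [30, 60, 90, 120]  # running token totals per subchunk

cutoff = 1
for i, cum_count in enumerate(cumulative_counts):
    if cum_count > (1 - overlap) * max_tokens:  # threshold: 75
        break
    cutoff = i

print(cutoff)  # 1: the index where the next chunk resumes, re-covering later subchunks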

View file

@@ -34,7 +34,7 @@ def test_AudioDocument():
     )
     with patch.object(AudioDocument, "create_transcript", return_value=TEST_TEXT):
         for ground_truth, paragraph_data in zip(
-            GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
+            GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512)
         ):
             assert ground_truth["word_count"] == paragraph_data.word_count, (
                 f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'

View file

@@ -23,7 +23,7 @@ def test_ImageDocument():
     )
     with patch.object(ImageDocument, "transcribe_image", return_value=TEST_TEXT):
         for ground_truth, paragraph_data in zip(
-            GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker")
+            GROUND_TRUTH, document.read(chunk_size=64, chunker="text_chunker", max_chunk_tokens=512)
         ):
             assert ground_truth["word_count"] == paragraph_data.word_count, (
                 f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'

View file

@@ -25,7 +25,7 @@ def test_PdfDocument():
     )
     for ground_truth, paragraph_data in zip(
-        GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker")
+        GROUND_TRUTH, document.read(chunk_size=1024, chunker="text_chunker", max_chunk_tokens=2048)
     ):
         assert ground_truth["word_count"] == paragraph_data.word_count, (
             f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'

View file

@@ -37,7 +37,8 @@ def test_TextDocument(input_file, chunk_size):
     )
     for ground_truth, paragraph_data in zip(
-        GROUND_TRUTH[input_file], document.read(chunk_size=chunk_size, chunker="text_chunker")
+        GROUND_TRUTH[input_file],
+        document.read(chunk_size=chunk_size, chunker="text_chunker", max_chunk_tokens=1024),
     ):
         assert ground_truth["word_count"] == paragraph_data.word_count, (
             f'{ground_truth["word_count"] = } != {paragraph_data.word_count = }'

View file

@@ -68,7 +68,9 @@ def test_UnstructuredDocument():
     )

     # Test PPTX
-    for paragraph_data in pptx_document.read(chunk_size=1024, chunker="text_chunker"):
+    for paragraph_data in pptx_document.read(
+        chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
+    ):
         assert 19 == paragraph_data.word_count, f" 19 != {paragraph_data.word_count = }"
         assert 104 == len(paragraph_data.text), f" 104 != {len(paragraph_data.text) = }"
         assert "sentence_cut" == paragraph_data.cut_type, (
@@ -76,7 +78,9 @@
     )

     # Test DOCX
-    for paragraph_data in docx_document.read(chunk_size=1024, chunker="text_chunker"):
+    for paragraph_data in docx_document.read(
+        chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
+    ):
         assert 16 == paragraph_data.word_count, f" 16 != {paragraph_data.word_count = }"
         assert 145 == len(paragraph_data.text), f" 145 != {len(paragraph_data.text) = }"
         assert "sentence_end" == paragraph_data.cut_type, (
@@ -84,7 +88,9 @@
     )

     # TEST CSV
-    for paragraph_data in csv_document.read(chunk_size=1024, chunker="text_chunker"):
+    for paragraph_data in csv_document.read(
+        chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
+    ):
         assert 15 == paragraph_data.word_count, f" 15 != {paragraph_data.word_count = }"
         assert "A A A A A A A A A,A A A A A A,A A" == paragraph_data.text, (
             f"Read text doesn't match expected text: {paragraph_data.text}"
@@ -94,7 +100,9 @@
     )

     # Test XLSX
-    for paragraph_data in xlsx_document.read(chunk_size=1024, chunker="text_chunker"):
+    for paragraph_data in xlsx_document.read(
+        chunk_size=1024, chunker="text_chunker", max_chunk_tokens=1024
+    ):
         assert 36 == paragraph_data.word_count, f" 36 != {paragraph_data.word_count = }"
         assert 171 == len(paragraph_data.text), f" 171 != {len(paragraph_data.text) = }"
         assert "sentence_cut" == paragraph_data.cut_type, (

View file

@@ -8,14 +8,24 @@ from cognee.tests.unit.processing.chunks.test_input import INPUT_TEXTS

 paragraph_lengths = [64, 256, 1024]
 batch_paragraphs_vals = [True, False]
+max_chunk_tokens_vals = [512, 1024, 4096]


 @pytest.mark.parametrize(
-    "input_text,paragraph_length,batch_paragraphs",
-    list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
+    "input_text,max_chunk_tokens,paragraph_length,batch_paragraphs",
+    list(
+        product(
+            list(INPUT_TEXTS.values()),
+            max_chunk_tokens_vals,
+            paragraph_lengths,
+            batch_paragraphs_vals,
+        )
+    ),
 )
-def test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_paragraphs):
-    chunks = chunk_by_paragraph(input_text, paragraph_length, batch_paragraphs)
+def test_chunk_by_paragraph_isomorphism(
+    input_text, max_chunk_tokens, paragraph_length, batch_paragraphs
+):
+    chunks = chunk_by_paragraph(input_text, max_chunk_tokens, paragraph_length, batch_paragraphs)
     reconstructed_text = "".join([chunk["text"] for chunk in chunks])
     assert reconstructed_text == input_text, (
         f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
@@ -23,13 +33,23 @@ test_chunk_by_paragraph_isomorphism(input_text, paragraph_length, batch_para

 @pytest.mark.parametrize(
-    "input_text,paragraph_length,batch_paragraphs",
-    list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
+    "input_text,max_chunk_tokens,paragraph_length,batch_paragraphs",
+    list(
+        product(
+            list(INPUT_TEXTS.values()),
+            max_chunk_tokens_vals,
+            paragraph_lengths,
+            batch_paragraphs_vals,
+        )
+    ),
 )
-def test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):
+def test_paragraph_chunk_length(input_text, max_chunk_tokens, paragraph_length, batch_paragraphs):
     chunks = list(
         chunk_by_paragraph(
-            data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
+            data=input_text,
+            max_chunk_tokens=max_chunk_tokens,
+            paragraph_length=paragraph_length,
+            batch_paragraphs=batch_paragraphs,
         )
     )
@@ -42,12 +62,24 @@ test_paragraph_chunk_length(input_text, paragraph_length, batch_paragraphs):

 @pytest.mark.parametrize(
-    "input_text,paragraph_length,batch_paragraphs",
-    list(product(list(INPUT_TEXTS.values()), paragraph_lengths, batch_paragraphs_vals)),
+    "input_text,max_chunk_tokens,paragraph_length,batch_paragraphs",
+    list(
+        product(
+            list(INPUT_TEXTS.values()),
+            max_chunk_tokens_vals,
+            paragraph_lengths,
+            batch_paragraphs_vals,
+        )
+    ),
 )
-def test_chunk_by_paragraph_chunk_numbering(input_text, paragraph_length, batch_paragraphs):
+def test_chunk_by_paragraph_chunk_numbering(
+    input_text, max_chunk_tokens, paragraph_length, batch_paragraphs
+):
     chunks = chunk_by_paragraph(
-        data=input_text, paragraph_length=paragraph_length, batch_paragraphs=batch_paragraphs
+        data=input_text,
+        max_chunk_tokens=max_chunk_tokens,
+        paragraph_length=paragraph_length,
+        batch_paragraphs=batch_paragraphs,
     )
     chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
     assert np.all(chunk_indices == np.arange(len(chunk_indices))), (

View file

@@ -50,7 +50,7 @@ Third paragraph is cut and is missing the dot at the end""",
 def run_chunking_test(test_text, expected_chunks):
     chunks = []
     for chunk_data in chunk_by_paragraph(
-        data=test_text, paragraph_length=12, batch_paragraphs=False
+        data=test_text, paragraph_length=12, batch_paragraphs=False, max_chunk_tokens=512
     ):
         chunks.append(chunk_data)

View file

@@ -11,9 +11,7 @@ from cognee.shared.exceptions import IngestionError
 from cognee.shared.utils import (
     get_anonymous_id,
     send_telemetry,
-    num_tokens_from_string,
     get_file_content_hash,
-    trim_text_to_max_tokens,
     prepare_edges,
     prepare_nodes,
     create_cognee_style_network_with_logo,
@@ -45,15 +43,6 @@ def test_get_anonymous_id(mock_open_file, mock_makedirs, temp_dir):
 #     args, kwargs = mock_post.call_args
 #     assert kwargs["json"]["event_name"] == "test_event"
 #
-# @patch("tiktoken.encoding_for_model")
-# def test_num_tokens_from_string(mock_encoding):
-#     mock_encoding.return_value.encode = lambda x: list(x)
-#
-#     assert num_tokens_from_string("hello", "test_encoding") == 5
-#     assert num_tokens_from_string("world", "test_encoding") == 5
-#

 @patch("builtins.open", new_callable=mock_open, read_data=b"test_data")
 def test_get_file_content_hash_file(mock_open_file):
@@ -73,18 +62,6 @@ def test_get_file_content_hash_stream():
     assert result == expected_hash

-# def test_trim_text_to_max_tokens():
-#     text = "This is a test string with multiple words."
-#     encoding_name = "test_encoding"
-#
-#     with patch("tiktoken.get_encoding") as mock_get_encoding:
-#         mock_get_encoding.return_value.encode = lambda x: list(x)
-#         mock_get_encoding.return_value.decode = lambda x: "".join(x)
-#
-#         result = trim_text_to_max_tokens(text, 5, encoding_name)
-#         assert result == text[:5]

 def test_prepare_edges():
     graph = nx.MultiDiGraph()
     graph.add_edge("A", "B", key="AB", weight=1)

View file

@@ -650,6 +650,7 @@
     "from cognee.modules.pipelines import run_tasks\n",
     "from cognee.modules.users.models import User\n",
     "from cognee.tasks.documents import check_permissions_on_documents, classify_documents, extract_chunks_from_documents\n",
+    "from cognee.infrastructure.llm import get_max_chunk_tokens\n",
     "from cognee.tasks.graph import extract_graph_from_data\n",
     "from cognee.tasks.storage import add_data_points\n",
     "from cognee.tasks.summarization import summarize_text\n",
@@ -663,7 +664,7 @@
     "    tasks = [\n",
     "        Task(classify_documents),\n",
     "        Task(check_permissions_on_documents, user = user, permissions = [\"write\"]),\n",
-    "        Task(extract_chunks_from_documents), # Extract text chunks based on the document type.\n",
+    "        Task(extract_chunks_from_documents, max_chunk_tokens=get_max_chunk_tokens()), # Extract text chunks based on the document type.\n",
     "        Task(extract_graph_from_data, graph_model = KnowledgeGraph, task_config = { \"batch_size\": 10 }), # Generate knowledge graphs from the document chunks.\n",
     "        Task(\n",
     "            summarize_text,\n",

poetry.lock (generated; 246 changed lines)
View file

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.

 [[package]]
 name = "aiofiles"
@@ -645,6 +645,17 @@ urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >
 [package.extras]
 crt = ["awscrt (==0.23.4)"]

+[[package]]
+name = "cachetools"
+version = "5.5.1"
+description = "Extensible memoizing collections and decorators"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "cachetools-5.5.1-py3-none-any.whl", hash = "sha256:b76651fdc3b24ead3c648bbdeeb940c1b04d365b38b4af66788f9ec4a81d42bb"},
+    {file = "cachetools-5.5.1.tar.gz", hash = "sha256:70f238fbba50383ef62e55c6aff6d9673175fe59f7c6782c7a0b9e38f4a9df95"},
+]
+
 [[package]]
 name = "certifi"
 version = "2024.12.14"
@@ -1995,6 +2006,135 @@ files = [
     {file = "giturlparse-0.12.0.tar.gz", hash = "sha256:c0fff7c21acc435491b1779566e038757a205c1ffdcb47e4f81ea52ad8c3859a"},
 ]

+[[package]]
+name = "google-ai-generativelanguage"
+version = "0.6.15"
+description = "Google Ai Generativelanguage API client library"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "google_ai_generativelanguage-0.6.15-py3-none-any.whl", hash = "sha256:5a03ef86377aa184ffef3662ca28f19eeee158733e45d7947982eb953c6ebb6c"},
+    {file = "google_ai_generativelanguage-0.6.15.tar.gz", hash = "sha256:8f6d9dc4c12b065fe2d0289026171acea5183ebf2d0b11cefe12f3821e159ec3"},
+]
+
+[package.dependencies]
+google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
+google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
+proto-plus = ">=1.22.3,<2.0.0dev"
+protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
+
+[[package]]
+name = "google-api-core"
+version = "2.24.0"
+description = "Google API client core library"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "google_api_core-2.24.0-py3-none-any.whl", hash = "sha256:10d82ac0fca69c82a25b3efdeefccf6f28e02ebb97925a8cce8edbfe379929d9"},
+    {file = "google_api_core-2.24.0.tar.gz", hash = "sha256:e255640547a597a4da010876d333208ddac417d60add22b6851a0c66a831fcaf"},
+]
+
+[package.dependencies]
+google-auth = ">=2.14.1,<3.0.dev0"
+googleapis-common-protos = ">=1.56.2,<2.0.dev0"
+grpcio = [
+    {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
+]
+grpcio-status = [
+    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
+]
+proto-plus = ">=1.22.3,<2.0.0dev"
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
+requests = ">=2.18.0,<3.0.0.dev0"
+
+[package.extras]
+async-rest = ["google-auth[aiohttp] (>=2.35.0,<3.0.dev0)"]
+grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"]
+grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
+grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
+
+[[package]]
+name = "google-api-python-client"
+version = "2.159.0"
+description = "Google API Client Library for Python"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "google_api_python_client-2.159.0-py2.py3-none-any.whl", hash = "sha256:baef0bb631a60a0bd7c0bf12a5499e3a40cd4388484de7ee55c1950bf820a0cf"},
+    {file = "google_api_python_client-2.159.0.tar.gz", hash = "sha256:55197f430f25c907394b44fa078545ffef89d33fd4dca501b7db9f0d8e224bd6"},
+]
+
+[package.dependencies]
+google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0"
+google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0"
+google-auth-httplib2 = ">=0.2.0,<1.0.0"
+httplib2 = ">=0.19.0,<1.dev0"
+uritemplate = ">=3.0.1,<5"
+
+[[package]]
+name = "google-auth"
+version = "2.38.0"
+description = "Google Authentication Library"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "google_auth-2.38.0-py2.py3-none-any.whl", hash = "sha256:e7dae6694313f434a2727bf2906f27ad259bae090d7aa896590d86feec3d9d4a"},
+    {file = "google_auth-2.38.0.tar.gz", hash = "sha256:8285113607d3b80a3f1543b75962447ba8a09fe85783432a784fdeef6ac094c4"},
+]
+
+[package.dependencies]
+cachetools = ">=2.0.0,<6.0"
+pyasn1-modules = ">=0.2.1"
+rsa = ">=3.1.4,<5"
+
+[package.extras]
+aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"]
+enterprise-cert = ["cryptography", "pyopenssl"]
+pyjwt = ["cryptography (>=38.0.3)", "pyjwt (>=2.0)"]
+pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"]
+reauth = ["pyu2f (>=0.1.5)"]
+requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
+
+[[package]]
+name = "google-auth-httplib2"
+version = "0.2.0"
+description = "Google Authentication Library: httplib2 transport"
+optional = true
+python-versions = "*"
+files = [
+    {file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"},
+    {file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"},
+]
+
+[package.dependencies]
+google-auth = "*"
+httplib2 = ">=0.19.0"
+
+[[package]]
+name = "google-generativeai"
+version = "0.8.4"
+description = "Google Generative AI High level API client library and tools."
+optional = true
+python-versions = ">=3.9"
+files = [
+    {file = "google_generativeai-0.8.4-py3-none-any.whl", hash = "sha256:e987b33ea6decde1e69191ddcaec6ef974458864d243de7191db50c21a7c5b82"},
+]
+
+[package.dependencies]
+google-ai-generativelanguage = "0.6.15"
+google-api-core = "*"
+google-api-python-client = "*"
+google-auth = ">=2.15.0"
+protobuf = "*"
+pydantic = "*"
+tqdm = "*"
+typing-extensions = "*"
+
+[package.extras]
+dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
+
 [[package]]
 name = "googleapis-common-protos"
 version = "1.66.0"
@@ -2251,6 +2391,22 @@ files = [
 grpcio = ">=1.67.1"
 protobuf = ">=5.26.1,<6.0dev"

+[[package]]
+name = "grpcio-status"
+version = "1.67.1"
+description = "Status proto mapping for gRPC"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "grpcio_status-1.67.1-py3-none-any.whl", hash = "sha256:16e6c085950bdacac97c779e6a502ea671232385e6e37f258884d6883392c2bd"},
+    {file = "grpcio_status-1.67.1.tar.gz", hash = "sha256:2bf38395e028ceeecfd8866b081f61628114b384da7d51ae064ddc8d766a5d11"},
+]
+
+[package.dependencies]
+googleapis-common-protos = ">=1.5.5"
+grpcio = ">=1.67.1"
+protobuf = ">=5.26.1,<6.0dev"
+
 [[package]]
 name = "grpcio-tools"
 version = "1.67.1"
@@ -2445,6 +2601,20 @@ http2 = ["h2 (>=3,<5)"]
 socks = ["socksio (==1.*)"]
 trio = ["trio (>=0.22.0,<1.0)"]

+[[package]]
+name = "httplib2"
+version = "0.22.0"
+description = "A comprehensive HTTP client library."
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+    {file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"},
+    {file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"},
+]
+
+[package.dependencies]
+pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""}
+
 [[package]]
 name = "httpx"
 version = "0.27.0"
@@ -4998,8 +5168,8 @@ files = [
 [package.dependencies]
 numpy = [
     {version = ">=1.22.4", markers = "python_version < \"3.11\""},
-    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
     {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
+    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
 ]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
@@ -5520,6 +5690,23 @@ files = [
     {file = "propcache-0.2.1.tar.gz", hash = "sha256:3f77ce728b19cb537714499928fe800c3dda29e8d9428778fc7c186da4c09a64"},
 ]

+[[package]]
+name = "proto-plus"
+version = "1.25.0"
+description = "Beautiful, Pythonic protocol buffers."
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "proto_plus-1.25.0-py3-none-any.whl", hash = "sha256:c91fc4a65074ade8e458e95ef8bac34d4008daa7cce4a12d6707066fca648961"},
+    {file = "proto_plus-1.25.0.tar.gz", hash = "sha256:fbb17f57f7bd05a68b7707e745e26528b0b3c34e378db91eef93912c54982d91"},
+]
+
+[package.dependencies]
+protobuf = ">=3.19.0,<6.0.0dev"
+
+[package.extras]
+testing = ["google-api-core (>=1.31.5)"]
+
 [[package]]
 name = "protobuf"
 version = "5.29.3"
@@ -5686,6 +5873,31 @@ files = [
 [package.extras]
 test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"]

+[[package]]
+name = "pyasn1"
+version = "0.6.1"
+description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"},
+    {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"},
+]
+
+[[package]]
+name = "pyasn1-modules"
+version = "0.4.1"
+description = "A collection of ASN.1-based protocols modules"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"},
+    {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"},
+]
+
+[package.dependencies]
+pyasn1 = ">=0.4.6,<0.7.0"
+
 [[package]]
 name = "pycparser"
 version = "2.22"
@@ -5927,8 +6139,8 @@ astroid = ">=3.3.8,<=3.4.0-dev0"
 colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
 dill = [
     {version = ">=0.2", markers = "python_version < \"3.11\""},
-    {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
     {version = ">=0.3.7", markers = "python_version >= \"3.12\""},
+    {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
 ]
 isort = ">=4.2.5,<5.13.0 || >5.13.0,<6"
 mccabe = ">=0.6,<0.8"
@@ -6967,6 +7179,20 @@ files = [
     {file = "rpds_py-0.22.3.tar.gz", hash = "sha256:e32fee8ab45d3c2db6da19a5323bc3362237c8b653c70194414b892fd06a080d"},
 ]

+[[package]]
+name = "rsa"
+version = "4.9"
+description = "Pure-Python RSA implementation"
+optional = true
+python-versions = ">=3.6,<4"
+files = [
+    {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
+    {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
+]
+
+[package.dependencies]
+pyasn1 = ">=0.1.3"
+
 [[package]]
 name = "ruff"
 version = "0.9.3"
@@ -8218,6 +8444,17 @@ files = [
 [package.extras]
 dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming", "types-PyYAML"]

+[[package]]
+name = "uritemplate"
+version = "4.1.1"
+description = "Implementation of RFC 6570 URI Templates"
+optional = true
+python-versions = ">=3.6"
+files = [
+    {file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"},
+    {file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"},
+]
+
 [[package]]
 name = "urllib3"
 version = "2.3.0"
@@ -8801,6 +9038,7 @@ deepeval = ["deepeval"]
 docs = ["unstructured"]
 falkordb = ["falkordb"]
 filesystem = ["botocore"]
+gemini = ["google-generativeai"]
 groq = ["groq"]
 langchain = ["langchain_text_splitters", "langsmith"]
 llama-index = ["llama-index-core"]
@@ -8815,4 +9053,4 @@ weaviate = ["weaviate-client"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10.0,<3.13"
-content-hash = "1cc352109264d0e3add524cdc15c9b2e6153e1bab20d968b40e42a4d5138967f"
+content-hash = "480675c274cd85a76a95bf03af865b1a0b462f25bbc21d7427b0a0b8e21c13db"

View file

@@ -77,6 +77,7 @@ pre-commit = "^4.0.1"
 httpx = "0.27.0"
 bokeh="^3.6.2"
 nltk = "3.9.1"
+google-generativeai = {version = "^0.8.4", optional = true}
 parso = {version = "^0.8.4", optional = true}
 jedi = {version = "^0.19.2", optional = true}
@@ -90,6 +91,7 @@ postgres = ["psycopg2", "pgvector", "asyncpg"]
 notebook = ["notebook", "ipykernel", "overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
 langchain = ["langsmith", "langchain_text_splitters"]
 llama-index = ["llama-index-core"]
+gemini = ["google-generativeai"]
 deepeval = ["deepeval"]
 posthog = ["posthog"]
 falkordb = ["falkordb"]