rewrote chunking config

This commit is contained in:
Vasilije 2024-06-09 22:46:02 +02:00
parent 00b60a9aef
commit 11231b7ada
11 changed files with 712 additions and 672 deletions

View file

@ -5,7 +5,11 @@ import logging
import nltk import nltk
from asyncio import Lock from asyncio import Lock
from nltk.corpus import stopwords from nltk.corpus import stopwords
from cognee.infrastructure.data.chunking.LangchainChunkingEngine import LangchainChunkEngine
from cognee.infrastructure.data.chunking.get_chunking_engine import get_chunk_engine
from cognee.infrastructure.databases.graph.config import get_graph_config from cognee.infrastructure.databases.graph.config import get_graph_config
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \ from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \
graph_ready_output, connect_nodes_in_graph graph_ready_output, connect_nodes_in_graph
from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks, add_data_chunks_basic_rag from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks, add_data_chunks_basic_rag
@ -23,7 +27,7 @@ from cognee.modules.data.get_content_categories import get_content_categories
from cognee.modules.data.get_content_summary import get_content_summary from cognee.modules.data.get_content_summary import get_content_summary
from cognee.modules.data.get_cognitive_layers import get_cognitive_layers from cognee.modules.data.get_cognitive_layers import get_cognitive_layers
from cognee.modules.data.get_layer_graphs import get_layer_graphs from cognee.modules.data.get_layer_graphs import get_layer_graphs
from cognee.shared.data_models import KnowledgeGraph from cognee.shared.data_models import KnowledgeGraph, ChunkStrategy
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.modules.tasks import create_task_status_table, update_task_status from cognee.modules.tasks import create_task_status_table, update_task_status
from cognee.shared.SourceCodeGraph import SourceCodeGraph from cognee.shared.SourceCodeGraph import SourceCodeGraph
@ -94,7 +98,7 @@ async def cognify(datasets: Union[str, List[str]] = None):
dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset))) dataset_files.append((added_dataset, db_engine.get_files_metadata(added_dataset)))
chunk_config = get_chunk_config() chunk_config = get_chunk_config()
chunk_engine = chunk_config.chunk_engine chunk_engine = get_chunk_engine()
chunk_strategy = chunk_config.chunk_strategy chunk_strategy = chunk_config.chunk_strategy
async def process_batch(files_batch): async def process_batch(files_batch):
@ -245,52 +249,52 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
# if __name__ == "__main__": if __name__ == "__main__":
# async def test(): async def test():
# # await prune.prune_system() # await prune.prune_system()
# # # # #
# # from cognee.api.v1.add import add # from cognee.api.v1.add import add
# # data_directory_path = os.path.abspath("../../../.data") # data_directory_path = os.path.abspath("../../../.data")
# # # print(data_directory_path) # # print(data_directory_path)
# # # config.data_root_directory(data_directory_path) # # config.data_root_directory(data_directory_path)
# # # cognee_directory_path = os.path.abspath("../.cognee_system") # # cognee_directory_path = os.path.abspath("../.cognee_system")
# # # config.system_root_directory(cognee_directory_path) # # config.system_root_directory(cognee_directory_path)
# # #
# # await add("data://" +data_directory_path, "example") # await add("data://" +data_directory_path, "example")
# text = """import subprocess text = """Conservative PP in the lead in Spain, according to estimate
# def show_all_processes(): An estimate has been published for Spain:
# process = subprocess.Popen(['ps', 'aux'], stdout=subprocess.PIPE)
# output, error = process.communicate()
# if error: Opposition leader Alberto Núñez Feijóos conservative Peoples party (PP): 32.4%
# print(f"Error: {error}")
# else:
# print(output.decode())
# show_all_processes()""" Spanish prime minister Pedro Sánchezs Socialist party (PSOE): 30.2%
The far-right Vox party: 10.4%
In Spain, the right has sought to turn the European election into a referendum on Sánchez.
Ahead of the vote, public attention has focused on a saga embroiling the prime ministers wife, Begoña Gómez, who is being investigated over allegations of corruption and influence-peddling, which Sanchez has dismissed as politically-motivated and totally baseless."""
# from cognee.api.v1.add import add from cognee.api.v1.add import add
# await add([text], "example_dataset") await add([text], "example_dataset")
# infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine() }) from cognee.api.v1.config.config import config
# from cognee.shared.SourceCodeGraph import SourceCodeGraph config.set_chunk_engine(LangchainChunkEngine() )
# from cognee.api.v1.config import config config.set_chunk_strategy(ChunkStrategy.LANGCHAIN_CHARACTER)
config.embedding_engine = LiteLLMEmbeddingEngine()
# # config.set_graph_model(SourceCodeGraph) graph = await cognify()
# # config.set_classification_model(CodeContentPrediction) # vector_client = infrastructure_config.get_config("vector_engine")
# # graph = await cognify() #
# vector_client = infrastructure_config.get_config("vector_engine") # out = await vector_client.search(collection_name ="basic_rag", query_text="show_all_processes", limit=10)
#
# print("results", out)
#
# from cognee.shared.utils import render_graph
#
# await render_graph(graph, include_color=True, include_nodes=False, include_size=False)
# out = await vector_client.search(collection_name ="basic_rag", query_text="show_all_processes", limit=10) import asyncio
asyncio.run(test())
# print("results", out)
# #
# # from cognee.shared.utils import render_graph
# #
# # await render_graph(graph, include_color=True, include_nodes=False, include_size=False)
# import asyncio
# asyncio.run(test())

View file

@ -89,3 +89,19 @@ class config():
def set_chunk_strategy(chunk_strategy: object): def set_chunk_strategy(chunk_strategy: object):
chunk_config = get_chunk_config() chunk_config = get_chunk_config()
chunk_config.chunk_strategy = chunk_strategy chunk_config.chunk_strategy = chunk_strategy
@staticmethod
def set_chunk_engine(chunk_engine: object):
    """Install *chunk_engine* as the engine on the global chunk configuration."""
    chunk_config = get_chunk_config()
    chunk_config.chunk_engine = chunk_engine
@staticmethod
def set_chunk_overlap(chunk_overlap: int):
    """Set the chunk overlap (characters shared between consecutive chunks)
    on the global chunk configuration."""
    chunk_config = get_chunk_config()
    chunk_config.chunk_overlap = chunk_overlap
@staticmethod
def set_chunk_size(chunk_size: int):
    """Set the maximum chunk size (in characters) on the global chunk
    configuration."""
    chunk_config = get_chunk_config()
    chunk_config.chunk_size = chunk_size

View file

@ -5,6 +5,12 @@ from cognee.shared.data_models import ChunkStrategy
class DefaultChunkEngine(): class DefaultChunkEngine():
def __init__(self, chunk_strategy=None, chunk_size=None, chunk_overlap=None):
    """Store the chunking configuration for later use by chunk_data().

    Parameters default to None; chunk_data() reads them off the instance.
    """
    self.chunk_strategy = chunk_strategy
    self.chunk_size = chunk_size
    self.chunk_overlap = chunk_overlap
@staticmethod @staticmethod
def _split_text_with_regex( def _split_text_with_regex(
text: str, separator: str, keep_separator: bool text: str, separator: str, keep_separator: bool
@ -25,8 +31,8 @@ class DefaultChunkEngine():
return [s for s in splits if s != ""] return [s for s in splits if s != ""]
@staticmethod
def chunk_data( def chunk_data(self,
chunk_strategy = None, chunk_strategy = None,
source_data = None, source_data = None,
chunk_size = None, chunk_size = None,
@ -45,19 +51,19 @@ class DefaultChunkEngine():
- The chunked data. - The chunked data.
""" """
if chunk_strategy == ChunkStrategy.PARAGRAPH: if self.chunk_strategy == ChunkStrategy.PARAGRAPH:
chunked_data = DefaultChunkEngine.chunk_data_by_paragraph(source_data,chunk_size, chunk_overlap) chunked_data = self.chunk_data_by_paragraph(source_data,chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
elif chunk_strategy == ChunkStrategy.SENTENCE: elif self.chunk_strategy == ChunkStrategy.SENTENCE:
chunked_data = DefaultChunkEngine.chunk_by_sentence(source_data, chunk_size, chunk_overlap) chunked_data = self.chunk_by_sentence(source_data, chunk_size = self.chunk_size, chunk_overlap=self.chunk_overlap)
elif chunk_strategy == ChunkStrategy.EXACT: elif self.chunk_strategy == ChunkStrategy.EXACT:
chunked_data = DefaultChunkEngine.chunk_data_exact(source_data, chunk_size, chunk_overlap) chunked_data = self.chunk_data_exact(source_data, chunk_size = self.chunk_size, chunk_overlap=self.chunk_overlap)
return chunked_data return chunked_data
@staticmethod
def chunk_data_exact(data_chunks, chunk_size, chunk_overlap): def chunk_data_exact(self, data_chunks, chunk_size, chunk_overlap):
data = "".join(data_chunks) data = "".join(data_chunks)
chunks = [] chunks = []
for i in range(0, len(data), chunk_size - chunk_overlap): for i in range(0, len(data), chunk_size - chunk_overlap):
@ -65,8 +71,8 @@ class DefaultChunkEngine():
return chunks return chunks
@staticmethod
def chunk_by_sentence(data_chunks, chunk_size, overlap): def chunk_by_sentence(self, data_chunks, chunk_size, chunk_overlap):
# Split by periods, question marks, exclamation marks, and ellipses # Split by periods, question marks, exclamation marks, and ellipses
data = "".join(data_chunks) data = "".join(data_chunks)
@ -77,15 +83,15 @@ class DefaultChunkEngine():
sentence_chunks = [] sentence_chunks = []
for sentence in sentences: for sentence in sentences:
if len(sentence) > chunk_size: if len(sentence) > chunk_size:
chunks = DefaultChunkEngine.chunk_data_exact([sentence], chunk_size, overlap) chunks = self.chunk_data_exact(data_chunks=[sentence], chunk_size=chunk_size, chunk_overlap=chunk_overlap)
sentence_chunks.extend(chunks) sentence_chunks.extend(chunks)
else: else:
sentence_chunks.append(sentence) sentence_chunks.append(sentence)
return sentence_chunks return sentence_chunks
@staticmethod
def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound = 0.75): def chunk_data_by_paragraph(self, data_chunks, chunk_size, chunk_overlap, bound = 0.75):
data = "".join(data_chunks) data = "".join(data_chunks)
total_length = len(data) total_length = len(data)
chunks = [] chunks = []
@ -108,7 +114,7 @@ class DefaultChunkEngine():
# Update end_idx to include the paragraph delimiter # Update end_idx to include the paragraph delimiter
end_idx = next_paragraph_index + 2 end_idx = next_paragraph_index + 2
end_index = end_idx + overlap end_index = end_idx + chunk_overlap
chunk_text = data[start_idx:end_index] chunk_text = data[start_idx:end_index]
@ -116,7 +122,7 @@ class DefaultChunkEngine():
chunk_text += data[end_index] chunk_text += data[end_index]
end_index += 1 end_index += 1
end_idx = end_index - overlap end_idx = end_index - chunk_overlap
chunks.append(chunk_text.replace("\n", "").strip()) chunks.append(chunk_text.replace("\n", "").strip())

View file

@ -0,0 +1,6 @@
class HaystackChunkEngine:
    """Placeholder chunking engine backed by Haystack.

    Only records the chunking configuration given at construction time;
    no chunking methods are implemented yet.
    """

    def __init__(self, chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
        # Keep every configuration value on the instance for later use.
        self.source_data = source_data
        self.chunk_strategy = chunk_strategy
        self.chunk_overlap = chunk_overlap
        self.chunk_size = chunk_size

View file

@ -6,9 +6,16 @@ from cognee.shared.data_models import ChunkStrategy
class LangchainChunkEngine(): class LangchainChunkEngine:
@staticmethod def __init__(self, chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
def chunk_data( self.chunk_strategy = chunk_strategy
self.source_data = source_data
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
def chunk_data(self,
chunk_strategy = None, chunk_strategy = None,
source_data = None, source_data = None,
chunk_size = None, chunk_size = None,
@ -28,16 +35,16 @@ class LangchainChunkEngine():
""" """
if chunk_strategy == ChunkStrategy.CODE: if chunk_strategy == ChunkStrategy.CODE:
chunked_data = LangchainChunkEngine.chunk_data_by_code(source_data,chunk_size, chunk_overlap) chunked_data = self.chunk_data_by_code(source_data,self.chunk_size, self.chunk_overlap)
elif chunk_strategy == ChunkStrategy.LANGCHAIN_CHARACTER: elif chunk_strategy == ChunkStrategy.LANGCHAIN_CHARACTER:
chunked_data = LangchainChunkEngine.chunk_data_by_character(source_data,chunk_size, chunk_overlap) chunked_data = self.chunk_data_by_character(source_data,self.chunk_size, self.chunk_overlap)
else: else:
chunked_data = DefaultChunkEngine.chunk_data_by_paragraph(source_data,chunk_size, chunk_overlap) chunked_data = "Invalid chunk strategy."
return chunked_data return chunked_data
@staticmethod
def chunk_data_by_code(data_chunks, chunk_size, chunk_overlap, language=None): def chunk_data_by_code(self, data_chunks, chunk_size, chunk_overlap= 10, language=None):
from langchain_text_splitters import ( from langchain_text_splitters import (
Language, Language,
RecursiveCharacterTextSplitter, RecursiveCharacterTextSplitter,
@ -53,10 +60,10 @@ class LangchainChunkEngine():
return only_content return only_content
def chunk_data_by_character(self, data_chunks, chunk_size, chunk_overlap): def chunk_data_by_character(self, data_chunks, chunk_size=1500, chunk_overlap=10):
from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter
splitter = RecursiveCharacterTextSplitter(chunk_size, chunk_overlap) splitter = RecursiveCharacterTextSplitter(chunk_size =chunk_size, chunk_overlap=chunk_overlap)
data = splitter.split(data_chunks) data = splitter.create_documents([data_chunks])
only_content = [chunk.page_content for chunk in data] only_content = [chunk.page_content for chunk in data]

View file

@ -7,9 +7,9 @@ from cognee.shared.data_models import ChunkStrategy
class ChunkConfig(BaseSettings): class ChunkConfig(BaseSettings):
chunk_size: int = 1500 chunk_size: int = 1500
chunk_overlap: int = 0 chunk_overlap: int = 10
chunk_strategy: object = ChunkStrategy.PARAGRAPH chunk_strategy: object = ChunkStrategy.PARAGRAPH
chunk_engine: object = DefaultChunkEngine() chunk_engine = DefaultChunkEngine
model_config = SettingsConfigDict(env_file = ".env", extra = "allow") model_config = SettingsConfigDict(env_file = ".env", extra = "allow")

View file

@ -0,0 +1,36 @@
from typing import Dict
from cognee.infrastructure.data.chunking.LangchainChunkingEngine import LangchainChunkEngine
class ChunkingConfig(Dict):
    """Mapping that describes how to build a chunking engine.

    Keys read by ``create_chunking_engine``:
        chunk_engine: lowercase engine name ("langchainchunkengine",
            "defaultchunkengine" or "haystackchunkengine").
        chunk_strategy: a ChunkStrategy member.
        chunk_size: maximum chunk size in characters.
        chunk_overlap: characters shared between consecutive chunks.
    """
    # Fix: the original annotations (vector_db_url / vector_db_key /
    # vector_db_provider) were copy-pasted from the vector-db config and
    # never read; declare the keys the factory actually consumes.
    chunk_engine: str
    chunk_strategy: object
    chunk_size: int
    chunk_overlap: int
def create_chunking_engine(config: "ChunkingConfig"):
    """Instantiate the chunking engine named by ``config["chunk_engine"]``.

    Parameters:
        config: mapping with keys "chunk_engine" (one of
            "langchainchunkengine", "defaultchunkengine",
            "haystackchunkengine"), "chunk_size", "chunk_overlap" and
            "chunk_strategy".

    Returns:
        A configured chunking engine instance.

    Raises:
        ValueError: if "chunk_engine" does not name a known engine.
    """
    engine_name = config["chunk_engine"]

    if engine_name == "langchainchunkengine":
        return LangchainChunkEngine(
            chunk_size=config["chunk_size"],
            chunk_overlap=config["chunk_overlap"],
            chunk_strategy=config["chunk_strategy"],
        )

    if engine_name == "defaultchunkengine":
        from .DefaultChunkEngine import DefaultChunkEngine
        return DefaultChunkEngine(
            chunk_size=config["chunk_size"],
            chunk_overlap=config["chunk_overlap"],
            chunk_strategy=config["chunk_strategy"],
        )

    if engine_name == "haystackchunkengine":
        from .HaystackChunkEngine import HaystackChunkEngine
        return HaystackChunkEngine(
            chunk_size=config["chunk_size"],
            chunk_overlap=config["chunk_overlap"],
            chunk_strategy=config["chunk_strategy"],
        )

    # Fix: the original fell off the end and silently returned None for an
    # unknown engine name, deferring the failure to the first method call
    # on the "engine". Fail loudly at the point of misconfiguration instead.
    raise ValueError(f"Unknown chunk engine: {engine_name!r}")

View file

@ -0,0 +1,6 @@
from .config import get_chunk_config
from .create_chunking_engine import create_chunking_engine
def get_chunk_engine():
    """Build the chunking engine described by the current chunk configuration.

    Reads the global chunk config, converts it to a plain dict and hands it
    to the ``create_chunking_engine`` factory.
    NOTE(review): relies on the config object exposing ``to_dict()`` —
    confirm get_chunk_config()'s return type actually provides that method.
    """
    return create_chunking_engine(get_chunk_config().to_dict())

View file

@ -1,3 +1,5 @@
import string
import random
from typing import BinaryIO, Union from typing import BinaryIO, Union
from cognee.base_config import get_base_config from cognee.base_config import get_base_config
from cognee.infrastructure.files.storage import LocalStorage from cognee.infrastructure.files.storage import LocalStorage
@ -13,6 +15,10 @@ def save_data_to_file(data: Union[str, BinaryIO], dataset_name: str, filename: s
LocalStorage.ensure_directory_exists(storage_path) LocalStorage.ensure_directory_exists(storage_path)
file_metadata = classified_data.get_metadata() file_metadata = classified_data.get_metadata()
if "name" not in file_metadata or file_metadata["name"] is None:
letters = string.ascii_lowercase
random_string = ''.join(random.choice(letters) for _ in range(32))
file_metadata["name"] = "file" + random_string
file_name = file_metadata["name"] file_name = file_metadata["name"]
LocalStorage(storage_path).store(file_name, classified_data.get_data()) LocalStorage(storage_path).store(file_name, classified_data.get_data())

1148
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -38,7 +38,7 @@ greenlet = "^3.0.3"
ruff = "^0.2.2" ruff = "^0.2.2"
filetype = "^1.2.0" filetype = "^1.2.0"
nltk = "^3.8.1" nltk = "^3.8.1"
dlt = "0.4.11" dlt = "0.4.12"
duckdb = {version = "^0.10.0", extras = ["dlt"]} duckdb = {version = "^0.10.0", extras = ["dlt"]}
overrides = "^7.7.0" overrides = "^7.7.0"
aiofiles = "^23.2.1" aiofiles = "^23.2.1"
@ -72,7 +72,8 @@ spacy = "^3.7.4"
protobuf = "<5.0.0" protobuf = "<5.0.0"
pydantic-settings = "^2.2.1" pydantic-settings = "^2.2.1"
anthropic = "^0.26.1" anthropic = "^0.26.1"
langchain-community = "0.0.38" langchain-text-splitters = "^0.2.1"
[tool.poetry.extras] [tool.poetry.extras]