cognee/cognee/infrastructure/data/chunking/DefaultChunkEngine.py
2024-08-08 19:59:26 +02:00

150 lines
5.4 KiB
Python

""" Chunking strategies for splitting text into smaller parts."""
from __future__ import annotations
import re
from cognee.shared.data_models import ChunkStrategy
# /Users/vasa/Projects/cognee/cognee/infrastructure/data/chunking/DefaultChunkEngine.py
class DefaultChunkEngine():
def __init__(self, chunk_strategy=None, chunk_size=None, chunk_overlap=None):
self.chunk_strategy = chunk_strategy
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
@staticmethod
def _split_text_with_regex(
text: str, separator: str, keep_separator: bool
) -> list[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = [_splits[0]] + splits
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if s != ""]
def chunk_data(self,
chunk_strategy = None,
source_data = None,
chunk_size = None,
chunk_overlap = None,
):
"""
Chunk data based on the specified strategy.
Parameters:
- chunk_strategy: The strategy to use for chunking.
- source_data: The data to be chunked.
- chunk_size: The size of each chunk.
- chunk_overlap: The overlap between chunks.
Returns:
- The chunked data.
"""
if self.chunk_strategy == ChunkStrategy.PARAGRAPH:
chunked_data, chunk_number = self.chunk_data_by_paragraph(source_data,chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
elif self.chunk_strategy == ChunkStrategy.SENTENCE:
chunked_data, chunk_number = self.chunk_by_sentence(source_data, chunk_size = self.chunk_size, chunk_overlap=self.chunk_overlap)
elif self.chunk_strategy == ChunkStrategy.EXACT:
chunked_data, chunk_number = self.chunk_data_exact(source_data, chunk_size = self.chunk_size, chunk_overlap=self.chunk_overlap)
else:
chunked_data, chunk_number = "Invalid chunk strategy.", [0, "Invalid chunk strategy."]
return chunked_data, chunk_number
def chunk_data_exact(self, data_chunks, chunk_size, chunk_overlap):
data = "".join(data_chunks)
chunks = []
for i in range(0, len(data), chunk_size - chunk_overlap):
chunks.append(data[i:i + chunk_size])
numbered_chunks = []
for i, chunk in enumerate(chunks):
numbered_chunk = [i + 1, chunk]
numbered_chunks.append(numbered_chunk)
return chunks, numbered_chunks
def chunk_by_sentence(self, data_chunks, chunk_size, chunk_overlap):
# Split by periods, question marks, exclamation marks, and ellipses
data = "".join(data_chunks)
# The regular expression is used to find series of charaters that end with one the following chaacters (. ! ? ...)
sentence_endings = r'(?<=[.!?…]) +'
sentences = re.split(sentence_endings, data)
sentence_chunks = []
for sentence in sentences:
if len(sentence) > chunk_size:
chunks = self.chunk_data_exact(data_chunks=[sentence], chunk_size=chunk_size, chunk_overlap=chunk_overlap)
sentence_chunks.extend(chunks)
else:
sentence_chunks.append(sentence)
numbered_chunks = []
for i, chunk in enumerate(sentence_chunks):
numbered_chunk = [i + 1, chunk]
numbered_chunks.append(numbered_chunk)
return sentence_chunks, numbered_chunks
def chunk_data_by_paragraph(self, data_chunks, chunk_size, chunk_overlap, bound = 0.75):
data = "".join(data_chunks)
total_length = len(data)
chunks = []
check_bound = int(bound * chunk_size)
start_idx = 0
chunk_splitter = "\n\n"
if data.find("\n\n") == -1:
chunk_splitter = "\n"
while start_idx < total_length:
# Set the end index to the minimum of start_idx + default_chunk_size or total_length
end_idx = min(start_idx + chunk_size, total_length)
# Find the next paragraph index within the current chunk and bound
next_paragraph_index = data.find(chunk_splitter, start_idx + check_bound, end_idx)
# If a next paragraph index is found within the current chunk
if next_paragraph_index != -1:
# Update end_idx to include the paragraph delimiter
end_idx = next_paragraph_index + 2
end_index = end_idx + chunk_overlap
chunk_text = data[start_idx:end_index]
while chunk_text[-1] != "." and end_index < total_length:
chunk_text += data[end_index]
end_index += 1
end_idx = end_index - chunk_overlap
chunks.append(chunk_text.replace("\n", "").strip())
# Update start_idx to be the current end_idx
start_idx = end_idx
numbered_chunks = []
for i, chunk in enumerate(chunks):
numbered_chunk = [i + 1, chunk]
numbered_chunks.append(numbered_chunk)
return chunks, numbered_chunks