Merge pull request #72 from topoteretes/COG-165

feat: Add chunking
This commit is contained in:
Vasilije 2024-04-23 15:56:34 +02:00 committed by GitHub
commit 0aed5f27be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 202 additions and 291 deletions

View file

@ -55,7 +55,7 @@ async def add_files(file_paths: List[str], dataset_name: str):
if data_directory_path not in file_path: if data_directory_path not in file_path:
file_name = file_path.split("/")[-1] file_name = file_path.split("/")[-1]
file_directory_path = data_directory_path + "/" + (dataset_name.replace('.', "/") + "/" if dataset_name != "root" else "") file_directory_path = data_directory_path + "/" + (dataset_name.replace(".", "/") + "/" if dataset_name != "root" else "")
dataset_file_path = path.join(file_directory_path, file_name) dataset_file_path = path.join(file_directory_path, file_name)
LocalStorage.ensure_directory_exists(file_directory_path) LocalStorage.ensure_directory_exists(file_directory_path)

View file

@ -1,8 +1,11 @@
import asyncio import asyncio
from uuid import uuid4
from typing import List, Union from typing import List, Union
import logging import logging
import instructor import instructor
from openai import OpenAI from openai import OpenAI
from nltk.corpus import stopwords
from cognee.config import Config
from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks
from cognee.modules.cognify.graph.add_document_node import add_document_node from cognee.modules.cognify.graph.add_document_node import add_document_node
from cognee.modules.cognify.graph.add_classification_nodes import add_classification_nodes from cognee.modules.cognify.graph.add_classification_nodes import add_classification_nodes
@ -11,9 +14,6 @@ from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \ from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \
graph_ready_output, connect_nodes_in_graph graph_ready_output, connect_nodes_in_graph
from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references
from cognee.config import Config
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes
from cognee.modules.cognify.graph.add_cognitive_layers import add_cognitive_layers from cognee.modules.cognify.graph.add_cognitive_layers import add_cognitive_layers
@ -26,6 +26,7 @@ from cognee.modules.data.get_content_summary import get_content_summary
from cognee.modules.data.get_cognitive_layers import get_cognitive_layers from cognee.modules.data.get_cognitive_layers import get_cognitive_layers
from cognee.modules.data.get_layer_graphs import get_layer_graphs from cognee.modules.data.get_layer_graphs import get_layer_graphs
config = Config() config = Config()
config.load() config.load()
@ -33,10 +34,12 @@ aclient = instructor.patch(OpenAI())
USER_ID = "default_user" USER_ID = "default_user"
logger = logging.getLogger(__name__) logger = logging.getLogger("cognify")
async def cognify(datasets: Union[str, List[str]] = None): async def cognify(datasets: Union[str, List[str]] = None):
"""This function is responsible for the cognitive processing of the content.""" """This function is responsible for the cognitive processing of the content."""
# Has to be loaded in advance, multithreading doesn't work without it.
stopwords.ensure_loaded()
db_engine = infrastructure_config.get_config()["database_engine"] db_engine = infrastructure_config.get_config()["database_engine"]
@ -53,10 +56,10 @@ async def cognify(datasets: Union[str, List[str]] = None):
graphs = await asyncio.gather(*awaitables) graphs = await asyncio.gather(*awaitables)
return graphs[0] return graphs[0]
# datasets is a dataset name string
added_datasets = db_engine.get_datasets() added_datasets = db_engine.get_datasets()
dataset_files = [] dataset_files = []
# datasets is a dataset name string
dataset_name = datasets.replace(".", "_").replace(" ", "_") dataset_name = datasets.replace(".", "_").replace(" ", "_")
for added_dataset in added_datasets: for added_dataset in added_datasets:
@ -73,30 +76,35 @@ async def cognify(datasets: Union[str, List[str]] = None):
data_chunks = {} data_chunks = {}
chunk_engine = infrastructure_config.get_config()["chunk_engine"]
chunk_strategy = infrastructure_config.get_config()["chunk_strategy"]
for (dataset_name, files) in dataset_files: for (dataset_name, files) in dataset_files:
for file_metadata in files[:3]: for file_metadata in files:
with open(file_metadata["file_path"], "rb") as file: with open(file_metadata["file_path"], "rb") as file:
try: try:
file_type = guess_file_type(file) file_type = guess_file_type(file)
text = extract_text_from_file(file, file_type) text = extract_text_from_file(file, file_type)
subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap)
if dataset_name not in data_chunks: if dataset_name not in data_chunks:
data_chunks[dataset_name] = [] data_chunks[dataset_name] = []
data_chunks[dataset_name].append(dict(text = text, file_metadata = file_metadata)) for subchunk in subchunks:
data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
except FileTypeException: except FileTypeException:
logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"]) logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks) added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)
await asyncio.gather( await asyncio.gather(
*[process_text(chunk["collection"], chunk["id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks] *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
) )
return graph_client.graph return graph_client.graph
async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict): async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict):
print(f"Processing document ({file_metadata['id']}).") print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
graph_client = await get_graph_client(infrastructure_config.get_config()["graph_engine"]) graph_client = await get_graph_client(infrastructure_config.get_config()["graph_engine"])
@ -115,12 +123,12 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
categories = classified_categories, categories = classified_categories,
) )
print(f"Document ({document_id}) classified.") print(f"Chunk ({chunk_id}) classified.")
content_summary = await get_content_summary(input_text) content_summary = await get_content_summary(input_text)
await add_summary_nodes(graph_client, document_id, content_summary) await add_summary_nodes(graph_client, document_id, content_summary)
print(f"Document ({document_id}) summarized.") print(f"Chunk ({chunk_id}) summarized.")
cognitive_layers = await get_cognitive_layers(input_text, classified_categories) cognitive_layers = await get_cognitive_layers(input_text, classified_categories)
cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2] cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2]
@ -132,8 +140,6 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
db_engine = infrastructure_config.get_config()["database_engine"] db_engine = infrastructure_config.get_config()["database_engine"]
relevant_documents_to_connect = db_engine.fetch_cognify_data(excluded_document_id = file_metadata["id"]) relevant_documents_to_connect = db_engine.fetch_cognify_data(excluded_document_id = file_metadata["id"])
print("Relevant documents to connect are: ", relevant_documents_to_connect)
list_of_nodes = [] list_of_nodes = []
relevant_documents_to_connect.append({ relevant_documents_to_connect.append({
@ -144,13 +150,9 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
node_descriptions_to_match = await graph_client.extract_node_description(document["layer_id"]) node_descriptions_to_match = await graph_client.extract_node_description(document["layer_id"])
list_of_nodes.extend(node_descriptions_to_match) list_of_nodes.extend(node_descriptions_to_match)
print("List of nodes are: ", len(list_of_nodes))
nodes_by_layer = await group_nodes_by_layer(list_of_nodes) nodes_by_layer = await group_nodes_by_layer(list_of_nodes)
print("Nodes by layer are: ", str(nodes_by_layer)[:5000])
results = await resolve_cross_graph_references(nodes_by_layer) results = await resolve_cross_graph_references(nodes_by_layer)
print("Results are: ", str(results)[:3000])
relationships = graph_ready_output(results) relationships = graph_ready_output(results)
@ -160,4 +162,4 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
score_threshold = infrastructure_config.get_config()["intra_layer_score_treshold"] score_threshold = infrastructure_config.get_config()["intra_layer_score_treshold"]
) )
print(f"Document ({document_id}) cognified.") print(f"Chunk ({chunk_id}) cognified.")

View file

@ -67,4 +67,8 @@ class config():
infrastructure_config.set_config({ infrastructure_config.set_config({
"connect_documents": connect_documents "connect_documents": connect_documents
}) })
@staticmethod
def set_chunk_strategy(chunk_strategy: object):
infrastructure_config.set_config({
"chunk_strategy": chunk_strategy
})

View file

@ -8,7 +8,7 @@ from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from dotenv import load_dotenv from dotenv import load_dotenv
from cognee.root_dir import get_absolute_path from cognee.root_dir import get_absolute_path
from cognee.shared.data_models import ChunkStrategy
base_dir = Path(__file__).resolve().parent.parent base_dir = Path(__file__).resolve().parent.parent
# Load the .env file from the base directory # Load the .env file from the base directory
@ -116,6 +116,11 @@ class Config:
# Client ID # Client ID
anon_clientid: Optional[str] = field(default_factory=lambda: uuid.uuid4().hex) anon_clientid: Optional[str] = field(default_factory=lambda: uuid.uuid4().hex)
#Chunking parameters
chunk_size: int = 1500
chunk_overlap: int = 0
chunk_strategy: str = ChunkStrategy.PARAGRAPH
def load(self): def load(self):
"""Loads the configuration from a file or environment variables.""" """Loads the configuration from a file or environment variables."""
config = configparser.ConfigParser() config = configparser.ConfigParser()

View file

@ -6,6 +6,7 @@ from .databases.vector.embeddings.DefaultEmbeddingEngine import DefaultEmbedding
from .llm.llm_interface import LLMInterface from .llm.llm_interface import LLMInterface
from .llm.openai.adapter import OpenAIAdapter from .llm.openai.adapter import OpenAIAdapter
from .files.storage import LocalStorage from .files.storage import LocalStorage
from .data.chunking.DefaultChunkEngine import DefaultChunkEngine
from ..shared.data_models import GraphDBType, DefaultContentPrediction, KnowledgeGraph, SummarizedContent, \ from ..shared.data_models import GraphDBType, DefaultContentPrediction, KnowledgeGraph, SummarizedContent, \
LabeledContent, DefaultCognitiveLayer LabeledContent, DefaultCognitiveLayer
@ -30,6 +31,8 @@ class InfrastructureConfig():
connect_documents = config.connect_documents connect_documents = config.connect_documents
database_directory_path: str = None database_directory_path: str = None
database_file_path: str = None database_file_path: str = None
chunk_strategy = config.chunk_strategy
chunk_engine = None
def get_config(self, config_entity: str = None) -> dict: def get_config(self, config_entity: str = None) -> dict:
if (config_entity is None or config_entity == "database_engine") and self.database_engine is None: if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:
@ -69,6 +72,12 @@ class InfrastructureConfig():
if self.connect_documents is None: if self.connect_documents is None:
self.connect_documents = config.connect_documents self.connect_documents = config.connect_documents
if self.chunk_strategy is None:
self.chunk_strategy = config.chunk_strategy
if self.chunk_engine is None:
self.chunk_engine = DefaultChunkEngine()
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None: if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model) self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
@ -120,7 +129,9 @@ class InfrastructureConfig():
"embedding_engine": self.embedding_engine, "embedding_engine": self.embedding_engine,
"connect_documents": self.connect_documents, "connect_documents": self.connect_documents,
"database_directory_path": self.database_directory_path, "database_directory_path": self.database_directory_path,
"database_path": self.database_file_path "database_path": self.database_file_path,
"chunk_strategy": self.chunk_strategy,
"chunk_engine": self.chunk_engine,
} }
def set_config(self, new_config: dict): def set_config(self, new_config: dict):
@ -169,4 +180,10 @@ class InfrastructureConfig():
if "connect_documents" in new_config: if "connect_documents" in new_config:
self.connect_documents = new_config["connect_documents"] self.connect_documents = new_config["connect_documents"]
if "chunk_strategy" in new_config:
self.chunk_strategy = new_config["chunk_strategy"]
if "chunk_engine" in new_config:
self.chunk_engine = new_config["chunk_engine"]
infrastructure_config = InfrastructureConfig() infrastructure_config = InfrastructureConfig()

View file

@ -0,0 +1,125 @@
""" Chunking strategies for splitting text into smaller parts."""
from __future__ import annotations
import re
from cognee.shared.data_models import ChunkStrategy
class DefaultChunkEngine():
    """Default text-chunking engine.

    Splits raw text into smaller chunks according to a ``ChunkStrategy``
    (PARAGRAPH, SENTENCE or EXACT). All methods are stateless static
    methods, so a single engine instance can be shared freely.
    """

    @staticmethod
    def _split_text_with_regex(
        text: str, separator: str, keep_separator: bool
    ) -> list[str]:
        """Split ``text`` on the regex pattern ``separator``.

        When ``keep_separator`` is True, each delimiter is glued onto the
        fragment that precedes it instead of being discarded. Empty
        fragments are dropped from the result.
        """
        if separator:
            if keep_separator:
                # The parentheses in the pattern keep the delimiters in the result.
                _splits = re.split(f"({separator})", text)
                # Re-attach each delimiter to the fragment before it.
                splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
                if len(_splits) % 2 == 0:
                    splits += _splits[-1:]
                splits = [_splits[0]] + splits
            else:
                splits = re.split(separator, text)
        else:
            # No separator given: degrade to per-character splitting.
            splits = list(text)
        return [s for s in splits if s != ""]

    @staticmethod
    def chunk_data(
        chunk_strategy = None,
        source_data = None,
        chunk_size = None,
        chunk_overlap = None,
    ):
        """
        Chunk data based on the specified strategy.

        Parameters:
        - chunk_strategy: The ChunkStrategy member selecting the algorithm.
        - source_data: The data to be chunked.
        - chunk_size: The size of each chunk.
        - chunk_overlap: The overlap between chunks.

        Returns:
        - The chunked data.

        Raises:
        - ValueError: if chunk_strategy is not PARAGRAPH, SENTENCE or EXACT.
        """
        if chunk_strategy == ChunkStrategy.PARAGRAPH:
            return DefaultChunkEngine.chunk_data_by_paragraph(source_data, chunk_size, chunk_overlap)
        if chunk_strategy == ChunkStrategy.SENTENCE:
            return DefaultChunkEngine.chunk_by_sentence(source_data, chunk_size, chunk_overlap)
        if chunk_strategy == ChunkStrategy.EXACT:
            return DefaultChunkEngine.chunk_data_exact(source_data, chunk_size, chunk_overlap)
        # Bug fix: the original fell through and raised UnboundLocalError on
        # an unknown strategy; fail with a clear message instead.
        raise ValueError(f"Unknown chunk strategy: {chunk_strategy}")

    @staticmethod
    def chunk_data_exact(data_chunks, chunk_size, chunk_overlap):
        """Cut the text into fixed-width chunks.

        Consecutive chunks share ``chunk_overlap`` characters; the window
        advances ``chunk_size - chunk_overlap`` characters per step.

        Raises:
        - ValueError: if chunk_overlap >= chunk_size (the step would be <= 0,
          which previously made range() raise or produced no progress).
        """
        data = "".join(data_chunks)
        step = chunk_size - chunk_overlap
        if step <= 0:
            raise ValueError("chunk_overlap must be smaller than chunk_size")
        return [data[i:i + chunk_size] for i in range(0, len(data), step)]

    @staticmethod
    def chunk_by_sentence(data_chunks, chunk_size, overlap):
        """Split text at sentence boundaries.

        Sentences longer than ``chunk_size`` are further cut into exact
        fixed-width windows via ``chunk_data_exact``.
        """
        data = "".join(data_chunks)
        # A sentence ends with . ! ? or an ellipsis, followed by spaces.
        sentence_endings = r'(?<=[.!?…]) +'
        sentence_chunks = []
        for sentence in re.split(sentence_endings, data):
            if len(sentence) > chunk_size:
                sentence_chunks.extend(
                    DefaultChunkEngine.chunk_data_exact([sentence], chunk_size, overlap)
                )
            else:
                sentence_chunks.append(sentence)
        return sentence_chunks

    @staticmethod
    def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound = 0.75):
        """Split text into ~chunk_size chunks, preferring paragraph breaks.

        A paragraph break is searched for in the last (1 - bound) fraction of
        each window; when found, the chunk is cut just after the break. Each
        chunk is then extended to the next "." so it ends on a full sentence.
        Newlines are stripped from the returned chunks.
        """
        data = "".join(data_chunks)
        total_length = len(data)
        chunks = []
        check_bound = int(bound * chunk_size)
        start_idx = 0

        # Fall back to single newlines when the text has no blank-line paragraphs.
        chunk_splitter = "\n\n"
        if data.find("\n\n") == -1:
            chunk_splitter = "\n"

        while start_idx < total_length:
            # Tentative chunk end: at most chunk_size characters from the start.
            end_idx = min(start_idx + chunk_size, total_length)

            # Prefer to break at a paragraph boundary near the end of the window.
            next_paragraph_index = data.find(chunk_splitter, start_idx + check_bound, end_idx)

            if next_paragraph_index != -1:
                # Bug fix: include exactly the delimiter that was matched.
                # The old code always skipped 2 characters, which cut one
                # character too many when the splitter was a single "\n".
                end_idx = next_paragraph_index + len(chunk_splitter)

            end_index = end_idx + overlap
            chunk_text = data[start_idx:end_index]

            # Extend the chunk to the next period so it ends on a sentence.
            # Checking emptiness first avoids an IndexError on empty slices.
            while end_index < total_length and chunk_text and chunk_text[-1] != ".":
                chunk_text += data[end_index]
                end_index += 1
            end_idx = end_index - overlap

            chunks.append(chunk_text.replace("\n", "").strip())

            # Continue from where this chunk ended.
            start_idx = end_idx

        return chunks

View file

@ -1,10 +1,10 @@
from typing import TypedDict from typing import TypedDict
from uuid import uuid4
from cognee.infrastructure import infrastructure_config from cognee.infrastructure import infrastructure_config
from cognee.infrastructure.databases.vector import DataPoint from cognee.infrastructure.databases.vector import DataPoint
class TextChunk(TypedDict): class TextChunk(TypedDict):
text: str text: str
chunk_id: str
file_metadata: dict file_metadata: dict
async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]): async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
@ -20,7 +20,7 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
dataset_chunks = [ dataset_chunks = [
dict( dict(
id = str(uuid4()), chunk_id = chunk["chunk_id"],
collection = dataset_name, collection = dataset_name,
text = chunk["text"], text = chunk["text"],
file_metadata = chunk["file_metadata"], file_metadata = chunk["file_metadata"],
@ -33,7 +33,7 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
dataset_name, dataset_name,
[ [
DataPoint( DataPoint(
id = chunk["id"], id = chunk["chunk_id"],
payload = dict(text = chunk["text"]), payload = dict(text = chunk["text"]),
embed_field = "text" embed_field = "text"
) for chunk in dataset_chunks ) for chunk in dataset_chunks

View file

@ -1,243 +0,0 @@
""" Chunking strategies for splitting text into smaller parts."""
from __future__ import annotations
from cognee.shared.data_models import ChunkStrategy
import re
from typing import Any, List, Optional
class CharacterTextSplitter():
    """Splitting text that looks at characters.

    Splits on a single (optionally regex) separator, then merges the pieces
    back into chunks via ``_merge_splits``.

    NOTE(review): this class reads ``self._keep_separator`` and calls
    ``self._merge_splits``, neither of which is defined anywhere in this
    file — presumably they came from a (LangChain-style) TextSplitter base
    class that is no longer inherited. As written, ``split_text`` would
    raise AttributeError; confirm against the original base class.
    """

    def __init__(
        self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any
    ) -> None:
        """Create a new TextSplitter."""
        # NOTE(review): ``object.__init__`` accepts no keyword arguments —
        # without the original base class, any **kwargs would raise TypeError.
        super().__init__(**kwargs)
        self._separator = separator  # literal string or regex pattern to split on
        self._is_separator_regex = is_separator_regex

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        # Escape the separator unless the caller said it is already a regex.
        separator = (
            self._separator if self._is_separator_regex else re.escape(self._separator)
        )
        splits = _split_text_with_regex(text, separator, self._keep_separator)
        # When separators are kept they are already glued onto the fragments,
        # so merge with an empty joiner; otherwise re-join with the separator.
        _separator = "" if self._keep_separator else self._separator
        return self._merge_splits(splits, _separator)
def _split_text_with_regex(
text: str, separator: str, keep_separator: bool
) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = [_splits[0]] + splits
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if s != ""]
class RecursiveCharacterTextSplitter():
    """Splitting text by recursively look at characters.

    Recursively tries to split by different characters to find one
    that works.

    NOTE(review): this class reads ``self._keep_separator``,
    ``self._length_function``, ``self._chunk_size`` and calls
    ``self._merge_splits`` — none of which are defined in this file.
    They presumably belonged to a (LangChain-style) TextSplitter base
    class that is no longer inherited; as written, ``split_text`` would
    raise AttributeError. Confirm against the original base class.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        # NOTE(review): ``object.__init__`` accepts no keyword arguments —
        # without the original base class this call would raise TypeError.
        super().__init__(keep_separator=keep_separator, **kwargs)
        # Default separator ladder: paragraphs, lines, words, characters.
        self._separators = separators or ["\n\n", "\n", " ", ""]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use: the first one in the ladder that
        # actually occurs in the text (falling back to the last entry).
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                # Remaining, finer-grained separators for recursive calls.
                new_separators = separators[i + 1 :]
                break
        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex(text, _separator, self._keep_separator)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                # Small enough: buffer it for merging with its neighbours.
                _good_splits.append(s)
            else:
                if _good_splits:
                    # Flush the buffered small pieces before the big one.
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    # No finer separator left: emit the oversized piece as-is.
                    final_chunks.append(s)
                else:
                    # Recurse with the finer-grained separators.
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` using the configured separator ladder."""
        return self._split_text(text, self._separators)
def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
    """
    Chunk data based on the specified strategy.

    Parameters:
    - chunk_strategy: The strategy to use for chunking.
    - source_data: The data to be chunked.
    - chunk_size: The size of each chunk.
    - chunk_overlap: The overlap between chunks.

    Returns:
    - The chunked data. Unknown strategies fall back to vanilla chunking.
    """
    # Table-driven dispatch instead of an if/elif ladder; the .get() default
    # reproduces the original else branch (vanilla chunking).
    chunkers = {
        ChunkStrategy.VANILLA: vanilla_chunker,
        ChunkStrategy.PARAGRAPH: chunk_data_by_paragraph,
        ChunkStrategy.SENTENCE: chunk_by_sentence,
        ChunkStrategy.EXACT: chunk_data_exact,
        ChunkStrategy.SUMMARY: summary_chunker,
    }
    chunker = chunkers.get(chunk_strategy, vanilla_chunker)
    return chunker(source_data, chunk_size, chunk_overlap)
def vanilla_chunker(source_data, chunk_size=100, chunk_overlap=20):
    """Default chunking: recursive character splitting of ``source_data``.

    Parameters:
    - source_data (str): the text to split.
    - chunk_size (int): target chunk length, measured with len().
    - chunk_overlap (int): overlap between consecutive chunks.

    Returns:
    - The documents produced by the splitter.

    NOTE(review): ``create_documents`` and the ``chunk_size`` /
    ``chunk_overlap`` / ``length_function`` constructor arguments are not
    defined on the ``RecursiveCharacterTextSplitter`` visible in this file —
    they presumably belong to the original LangChain base class; confirm.
    """
    # adapt this for different chunking strategies
    text_splitter = RecursiveCharacterTextSplitter(
        # Set a really small chunk size, just to show.
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len
    )
    # try:
    #     pages = text_splitter.create_documents([source_data])
    # except:
    #     try:
    pages = text_splitter.create_documents([source_data])
    # except:
    #     pages = text_splitter.create_documents(source_data.content)
    # pages = source_data.load_and_split()
    return pages
def summary_chunker(source_data, chunk_size=400, chunk_overlap=20):
    """
    Chunk the given source data into smaller parts, returning the first five and last five chunks.

    Parameters:
    - source_data (str): The source data to be chunked.
    - chunk_size (int): The size of each chunk.
    - chunk_overlap (int): The overlap between consecutive chunks.

    Returns:
    - List: A list containing the first five and last five chunks of the chunked source data.

    NOTE(review): ``create_documents`` is not defined on the
    ``RecursiveCharacterTextSplitter`` visible in this file — presumably
    it came from the original LangChain base class; confirm.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len
    )
    # NOTE(review): the bare ``except`` below swallows every error (including
    # typos); it should be narrowed to the exception create_documents raises.
    try:
        pages = text_splitter.create_documents([source_data])
    except:
        # Fallback assumes source_data is an object exposing ``.content``.
        pages = text_splitter.create_documents(source_data.content)
    # Return the first 5 and last 5 chunks
    if len(pages) > 10:
        return pages[:5] + pages[-5:]
    else:
        return pages  # Return all chunks if there are 10 or fewer
def chunk_data_exact(data_chunks, chunk_size, chunk_overlap):
    """Cut the concatenated text into fixed-width chunks.

    Consecutive chunks share ``chunk_overlap`` characters: the window
    advances ``chunk_size - chunk_overlap`` characters each step.
    """
    text = "".join(data_chunks)
    step = chunk_size - chunk_overlap
    return [text[start:start + chunk_size] for start in range(0, len(text), step)]
def chunk_by_sentence(data_chunks, chunk_size, overlap):
    """Split the concatenated text at sentence boundaries.

    Sentences that exceed ``chunk_size`` are handed to ``chunk_data_exact``
    and returned as fixed-width, overlapping windows instead.
    """
    text = "".join(data_chunks)
    # A sentence ends with . ! ? or an ellipsis, followed by spaces.
    boundary = r'(?<=[.!?…]) +'
    result = []
    for sentence in re.split(boundary, text):
        if len(sentence) <= chunk_size:
            result.append(sentence)
        else:
            # Oversized sentence: fall back to exact fixed-width chunking.
            result.extend(chunk_data_exact([sentence], chunk_size, overlap))
    return result
def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound=0.75):
    """Greedy paragraph-aware chunking of the concatenated text.

    Each chunk spans at most ``chunk_size`` characters, but when a blank-line
    paragraph break ("\\n\\n") occurs in the last ``(1 - bound)`` fraction of
    the window, the chunk is cut just after that break instead. Every emitted
    chunk additionally carries ``overlap`` characters from the following one.
    """
    text = "".join(data_chunks)
    length = len(text)
    min_break_offset = int(bound * chunk_size)
    pieces = []
    cursor = 0
    while cursor < length:
        window_end = min(cursor + chunk_size, length)
        # Prefer to cut right after a paragraph break late in the window.
        break_at = text.find('\n\n', cursor + min_break_offset, window_end)
        if break_at != -1:
            window_end = break_at + 2  # keep the "\n\n" inside this chunk
        pieces.append(text[cursor:window_end + overlap])
        cursor = window_end
    return pieces

View file

@ -36,8 +36,6 @@ class ChunkStrategy(Enum):
EXACT = "exact" EXACT = "exact"
PARAGRAPH = "paragraph" PARAGRAPH = "paragraph"
SENTENCE = "sentence" SENTENCE = "sentence"
VANILLA = "vanilla"
SUMMARY = "summary"
class MemorySummary(BaseModel): class MemorySummary(BaseModel):
""" Memory summary. """ """ Memory summary. """

View file

@ -23,24 +23,24 @@
"colbertv2_wiki17_abstracts = dspy.ColBERTv2(url = \"http://20.102.90.50:2017/wiki17_abstracts\")\n", "colbertv2_wiki17_abstracts = dspy.ColBERTv2(url = \"http://20.102.90.50:2017/wiki17_abstracts\")\n",
"dspy.configure(rm = colbertv2_wiki17_abstracts)\n", "dspy.configure(rm = colbertv2_wiki17_abstracts)\n",
"\n", "\n",
"dataset = HotPotQA(\n", "# dataset = HotPotQA(\n",
" train_seed = 1,\n", "# train_seed = 1,\n",
" train_size = 10,\n", "# train_size = 10,\n",
" eval_seed = 2023,\n", "# eval_seed = 2023,\n",
" dev_size = 0,\n", "# dev_size = 0,\n",
" test_size = 0,\n", "# test_size = 0,\n",
" keep_details = True,\n", "# keep_details = True,\n",
")\n", "# )\n",
"\n", "\n",
"texts_to_add = []\n", "# texts_to_add = []\n",
"\n", "\n",
"for train_case in dataset.train:\n", "# for train_case in dataset.train:\n",
" train_case_text = \"\\r\\n\".join(\" \".join(context_sentences) for context_sentences in train_case.get(\"context\")[\"sentences\"])\n", "# train_case_text = \"\\r\\n\".join(\" \".join(context_sentences) for context_sentences in train_case.get(\"context\")[\"sentences\"])\n",
"\n", "\n",
" texts_to_add.append(train_case_text)\n", "# texts_to_add.append(train_case_text)\n",
"\n", "\n",
"dataset_name = \"train_dataset\"\n", "dataset_name = \"short_stories\"\n",
"await cognee.add(texts_to_add, dataset_name)\n" "await cognee.add(\"data://\" + data_directory_path, dataset_name)\n"
] ]
}, },
{ {
@ -61,7 +61,7 @@
"\n", "\n",
"print(cognee.datasets.list_datasets())\n", "print(cognee.datasets.list_datasets())\n",
"\n", "\n",
"train_dataset = cognee.datasets.query_data('train_dataset')\n", "train_dataset = cognee.datasets.query_data(\"short_stories\")\n",
"print(len(train_dataset))" "print(len(train_dataset))"
] ]
}, },
@ -73,8 +73,11 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from os import path\n", "from os import path\n",
"import logging\n",
"import cognee\n", "import cognee\n",
"\n", "\n",
"logging.basicConfig(level = logging.INFO)\n",
"\n",
"await cognee.prune.prune_system()\n", "await cognee.prune.prune_system()\n",
"\n", "\n",
"data_directory_path = path.abspath(\"../.data\")\n", "data_directory_path = path.abspath(\"../.data\")\n",
@ -83,7 +86,7 @@
"cognee_directory_path = path.abspath(\"../.cognee_system\")\n", "cognee_directory_path = path.abspath(\"../.cognee_system\")\n",
"cognee.config.system_root_directory(cognee_directory_path)\n", "cognee.config.system_root_directory(cognee_directory_path)\n",
"\n", "\n",
"await cognee.cognify('train_dataset')" "await cognee.cognify('short_stories')"
] ]
}, },
{ {
@ -151,10 +154,10 @@
"graph_client = await get_graph_client(GraphDBType.NETWORKX)\n", "graph_client = await get_graph_client(GraphDBType.NETWORKX)\n",
"graph = graph_client.graph\n", "graph = graph_client.graph\n",
"\n", "\n",
"results = await search_similarity(\"Who is Ernie Grunwald?\", graph)\n", "results = await search_similarity(\"Who are French girls?\", graph)\n",
"\n", "\n",
"for result in results:\n", "for result in results:\n",
" print(\"Ernie Grunwald\" in result)\n", " print(\"French girls\" in result)\n",
" print(result)" " print(result)"
] ]
} }