fix: move chunker to adapter
This commit is contained in:
parent 1c4caa9ee8
commit 9cbf450849

8 changed files with 179 additions and 327 deletions
@@ -55,7 +55,7 @@ async def add_files(file_paths: List[str], dataset_name: str):

    if data_directory_path not in file_path:
        file_name = file_path.split("/")[-1]
        file_directory_path = data_directory_path + "/" + (dataset_name.replace('.', "/") + "/" if dataset_name != "root" else "")
        file_directory_path = data_directory_path + "/" + (dataset_name.replace(".", "/") + "/" if dataset_name != "root" else "")
        dataset_file_path = path.join(file_directory_path, file_name)

        LocalStorage.ensure_directory_exists(file_directory_path)
@@ -1,10 +1,11 @@
import asyncio
import uuid
from os import path
from uuid import uuid4
from typing import List, Union
import logging
import instructor
from openai import OpenAI
from nltk.corpus import stopwords
from cognee.config import Config
from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks
from cognee.modules.cognify.graph.add_document_node import add_document_node
from cognee.modules.cognify.graph.add_classification_nodes import add_classification_nodes
@@ -13,9 +14,6 @@ from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \
    graph_ready_output, connect_nodes_in_graph
from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references

from cognee.config import Config

from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes
from cognee.modules.cognify.graph.add_cognitive_layers import add_cognitive_layers
@@ -27,8 +25,7 @@ from cognee.modules.data.get_content_categories import get_content_categories
from cognee.modules.data.get_content_summary import get_content_summary
from cognee.modules.data.get_cognitive_layers import get_cognitive_layers
from cognee.modules.data.get_layer_graphs import get_layer_graphs
from cognee.modules.ingestion.chunkers import chunk_data
from cognee.shared.data_models import ChunkStrategy


config = Config()
config.load()
@@ -37,10 +34,12 @@ aclient = instructor.patch(OpenAI())

USER_ID = "default_user"

logger = logging.getLogger(__name__)
logger = logging.getLogger("cognify")

async def cognify(datasets: Union[str, List[str]] = None):
    """This function is responsible for the cognitive processing of the content."""
    # Has to be loaded in advance, multithreading doesn't work without it.
    stopwords.ensure_loaded()

    db_engine = infrastructure_config.get_config()["database_engine"]
@@ -57,10 +56,10 @@ async def cognify(datasets: Union[str, List[str]] = None):
        graphs = await asyncio.gather(*awaitables)
        return graphs[0]

    # datasets is a dataset name string
    added_datasets = db_engine.get_datasets()

    dataset_files = []
    # datasets is a dataset name string
    dataset_name = datasets.replace(".", "_").replace(" ", "_")

    for added_dataset in added_datasets:
@@ -77,29 +76,27 @@ async def cognify(datasets: Union[str, List[str]] = None):

    data_chunks = {}

    chunk_engine = infrastructure_config.get_config()["chunk_engine"]
    chunk_strategy = infrastructure_config.get_config()["chunk_strategy"]

    for (dataset_name, files) in dataset_files:
        for file_metadata in files[:3]:
        for file_metadata in files:
            with open(file_metadata["file_path"], "rb") as file:
                try:
                    file_type = guess_file_type(file)
                    text = extract_text_from_file(file, file_type)
                    subchunks = chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap)
                    subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap)

                    if dataset_name not in data_chunks:
                        data_chunks[dataset_name] = []

                    for subchunk in subchunks:
                        data_chunks[dataset_name].append(dict(text = subchunk, chunk_id=str(uuid.uuid4()), file_metadata = file_metadata))
                        data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
                except FileTypeException:
                    logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])

    print("Added chunks are: ", data_chunks)

    added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)

    await asyncio.gather(
        *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
    )
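Note: cognify() now resolves the chunker through infrastructure_config instead of importing chunk_data directly. A minimal sketch (not part of this commit) of how a different engine could be plugged in; CustomChunkEngine is a hypothetical class, while set_config and ChunkStrategy come from the changes shown further down in this diff.

# Hypothetical example: swapping in a custom chunk engine via the shared config.
from cognee.infrastructure import infrastructure_config
from cognee.shared.data_models import ChunkStrategy

class CustomChunkEngine:
    def chunk_data(self, chunk_strategy, source_data, chunk_size, chunk_overlap):
        # Naive fixed-size windows, standing in for a real strategy.
        step = max(chunk_size - chunk_overlap, 1)
        return [source_data[i:i + chunk_size] for i in range(0, len(source_data), step)]

infrastructure_config.set_config({
    "chunk_engine": CustomChunkEngine(),
    "chunk_strategy": ChunkStrategy.PARAGRAPH,
})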
@@ -107,7 +104,7 @@ async def cognify(datasets: Union[str, List[str]] = None):
    return graph_client.graph

async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict):
    print(f"Processing document ({file_metadata['id']}).")
    print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")

    graph_client = await get_graph_client(infrastructure_config.get_config()["graph_engine"])
@@ -126,12 +123,12 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
        categories = classified_categories,
    )

    print(f"Document ({document_id}) classified.")
    print(f"Chunk ({chunk_id}) classified.")

    content_summary = await get_content_summary(input_text)
    await add_summary_nodes(graph_client, document_id, content_summary)

    print(f"Document ({document_id}) summarized.")
    print(f"Chunk ({chunk_id}) summarized.")

    cognitive_layers = await get_cognitive_layers(input_text, classified_categories)
    cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2]
@@ -143,8 +140,6 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
    db_engine = infrastructure_config.get_config()["database_engine"]
    relevant_documents_to_connect = db_engine.fetch_cognify_data(excluded_document_id = file_metadata["id"])

    print("Relevant documents to connect are: ", relevant_documents_to_connect)

    list_of_nodes = []

    relevant_documents_to_connect.append({
@@ -155,13 +150,9 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
        node_descriptions_to_match = await graph_client.extract_node_description(document["layer_id"])
        list_of_nodes.extend(node_descriptions_to_match)

    print("List of nodes are: ", len(list_of_nodes))

    nodes_by_layer = await group_nodes_by_layer(list_of_nodes)
    print("Nodes by layer are: ", str(nodes_by_layer)[:5000])

    results = await resolve_cross_graph_references(nodes_by_layer)
    print("Results are: ", str(results)[:3000])

    relationships = graph_ready_output(results)
@@ -171,35 +162,4 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
        score_threshold = infrastructure_config.get_config()["intra_layer_score_treshold"]
    )

    print(f"Document ({document_id}) cognified.")


if __name__ == "__main__":
    text = """Natural language processing (NLP) is an interdisciplinary
        subfield of computer science and information retrieval"""

    from cognee.api.v1.add.add import add

    data_path = path.abspath(".data")
    async def add_(text):
        await add("data://" + "/Users/vasa/Projects/cognee/cognee/.data", "explanations")


    asyncio.run(add_(text))
    asyncio.run(cognify("explanations"))

    import cognee

    # datasets = cognee.datasets.list_datasets()
    # print(datasets)
    # # print(vv)
    # for dataset in datasets:
    # print(dataset)
    # data_from_dataset = cognee.datasets.query_data(dataset)
    # for file_info in data_from_dataset:
    # print(file_info)



    print(f"Chunk ({chunk_id}) cognified.")
@@ -6,6 +6,7 @@ from .databases.vector.embeddings.DefaultEmbeddingEngine import DefaultEmbedding
from .llm.llm_interface import LLMInterface
from .llm.openai.adapter import OpenAIAdapter
from .files.storage import LocalStorage
from .data.chunking.DefaultChunkEngine import DefaultChunkEngine
from ..shared.data_models import GraphDBType, DefaultContentPrediction, KnowledgeGraph, SummarizedContent, \
    LabeledContent, DefaultCognitiveLayer

@@ -31,6 +32,7 @@ class InfrastructureConfig():
    database_directory_path: str = None
    database_file_path: str = None
    chunk_strategy = config.chunk_strategy
    chunk_engine = None

    def get_config(self, config_entity: str = None) -> dict:
        if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:

@@ -73,6 +75,9 @@ class InfrastructureConfig():
        if self.chunk_strategy is None:
            self.chunk_strategy = config.chunk_strategy

        if self.chunk_engine is None:
            self.chunk_engine = DefaultChunkEngine()

        if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
            self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)

@@ -125,16 +130,17 @@ class InfrastructureConfig():
            "connect_documents": self.connect_documents,
            "database_directory_path": self.database_directory_path,
            "database_path": self.database_file_path,
            "chunk_strategy": self.chunk_strategy
            "chunk_strategy": self.chunk_strategy,
            "chunk_engine": self.chunk_engine,
        }

    def set_config(self, new_config: dict):
        if "system_root_directory" in new_config:
            self.system_root_directory = new_config["system_root_directory"]

        if "data_root_directory" in new_config:
            self.data_root_directory = new_config["data_root_directory"]

        if "database_engine" in new_config:
            self.database_engine = new_config["database_engine"]

@@ -177,4 +183,7 @@ class InfrastructureConfig():
        if "chunk_strategy" in new_config:
            self.chunk_strategy = new_config["chunk_strategy"]

        if "chunk_engine" in new_config:
            self.chunk_engine = new_config["chunk_engine"]

infrastructure_config = InfrastructureConfig()
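With the defaults above, the resolved chunking settings can be read back from the shared config. A small sketch (assumes config.load() has already run, as it does in cognify.py):

# Sketch: reading the resolved chunking defaults from the shared infrastructure config.
from cognee.infrastructure import infrastructure_config

resolved = infrastructure_config.get_config()
chunk_engine = resolved["chunk_engine"]      # DefaultChunkEngine() unless overridden via set_config
chunk_strategy = resolved["chunk_strategy"]  # falls back to config.chunk_strategy
print(type(chunk_engine).__name__, chunk_strategy)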
125 cognee/infrastructure/data/chunking/DefaultChunkEngine.py Normal file
@@ -0,0 +1,125 @@
""" Chunking strategies for splitting text into smaller parts."""
from __future__ import annotations
import re
from cognee.shared.data_models import ChunkStrategy


class DefaultChunkEngine():
    @staticmethod
    def _split_text_with_regex(
        text: str, separator: str, keep_separator: bool
    ) -> list[str]:
        # Now that we have the separator, split the text
        if separator:
            if keep_separator:
                # The parentheses in the pattern keep the delimiters in the result.
                _splits = re.split(f"({separator})", text)
                splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
                if len(_splits) % 2 == 0:
                    splits += _splits[-1:]
                splits = [_splits[0]] + splits
            else:
                splits = re.split(separator, text)
        else:
            splits = list(text)
        return [s for s in splits if s != ""]


    @staticmethod
    def chunk_data(
        chunk_strategy = None,
        source_data = None,
        chunk_size = None,
        chunk_overlap = None,
    ):
        """
        Chunk data based on the specified strategy.

        Parameters:
        - chunk_strategy: The strategy to use for chunking.
        - source_data: The data to be chunked.
        - chunk_size: The size of each chunk.
        - chunk_overlap: The overlap between chunks.

        Returns:
        - The chunked data.
        """

        if chunk_strategy == ChunkStrategy.PARAGRAPH:
            chunked_data = DefaultChunkEngine.chunk_data_by_paragraph(source_data,chunk_size, chunk_overlap)
        elif chunk_strategy == ChunkStrategy.SENTENCE:
            chunked_data = DefaultChunkEngine.chunk_by_sentence(source_data, chunk_size, chunk_overlap)
        elif chunk_strategy == ChunkStrategy.EXACT:
            chunked_data = DefaultChunkEngine.chunk_data_exact(source_data, chunk_size, chunk_overlap)

        return chunked_data


    @staticmethod
    def chunk_data_exact(data_chunks, chunk_size, chunk_overlap):
        data = "".join(data_chunks)
        chunks = []
        for i in range(0, len(data), chunk_size - chunk_overlap):
            chunks.append(data[i:i + chunk_size])
        return chunks


    @staticmethod
    def chunk_by_sentence(data_chunks, chunk_size, overlap):
        # Split by periods, question marks, exclamation marks, and ellipses
        data = "".join(data_chunks)

        # The regular expression is used to find series of charaters that end with one the following chaacters (. ! ? ...)
        sentence_endings = r'(?<=[.!?…]) +'
        sentences = re.split(sentence_endings, data)

        sentence_chunks = []
        for sentence in sentences:
            if len(sentence) > chunk_size:
                chunks = DefaultChunkEngine.chunk_data_exact([sentence], chunk_size, overlap)
                sentence_chunks.extend(chunks)
            else:
                sentence_chunks.append(sentence)
        return sentence_chunks


    @staticmethod
    def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound = 0.75):
        data = "".join(data_chunks)
        total_length = len(data)
        chunks = []
        check_bound = int(bound * chunk_size)
        start_idx = 0
        chunk_splitter = "\n\n"

        if data.find("\n\n") == -1:
            chunk_splitter = "\n"

        while start_idx < total_length:
            # Set the end index to the minimum of start_idx + default_chunk_size or total_length
            end_idx = min(start_idx + chunk_size, total_length)

            # Find the next paragraph index within the current chunk and bound
            next_paragraph_index = data.find(chunk_splitter, start_idx + check_bound, end_idx)

            # If a next paragraph index is found within the current chunk
            if next_paragraph_index != -1:
                # Update end_idx to include the paragraph delimiter
                end_idx = next_paragraph_index + 2

            end_index = end_idx + overlap

            chunk_text = data[start_idx:end_index]

            while chunk_text[-1] != "." and end_index < total_length:
                chunk_text += data[end_index]
                end_index += 1

            end_idx = end_index - overlap

            chunks.append(chunk_text.replace("\n", "").strip())

            # Update start_idx to be the current end_idx
            start_idx = end_idx

        return chunks
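A quick usage sketch for the new engine; the sample text and sizes below are made up for illustration.

# Example, not part of this commit: exercising DefaultChunkEngine directly.
from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkEngine
from cognee.shared.data_models import ChunkStrategy

engine = DefaultChunkEngine()
sample = "First sentence. Second sentence! Third sentence? A fourth, much longer sentence follows here."

# SENTENCE splits on ., !, ? and ellipses, falling back to exact windows for oversized sentences.
print(engine.chunk_data(ChunkStrategy.SENTENCE, sample, chunk_size = 60, chunk_overlap = 5))
# EXACT slides a fixed-size window with overlap over the joined text.
print(engine.chunk_data(ChunkStrategy.EXACT, sample, chunk_size = 60, chunk_overlap = 5))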
@@ -1,10 +1,10 @@
from typing import TypedDict
from uuid import uuid4
from cognee.infrastructure import infrastructure_config
from cognee.infrastructure.databases.vector import DataPoint

class TextChunk(TypedDict):
    text: str
    chunk_id: str
    file_metadata: dict

async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):

@@ -20,13 +20,13 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):

        dataset_chunks = [
            dict(
                id = str(uuid4()),
                chunk_id = chunk["chunk_id"],
                collection = dataset_name,
                text = chunk["text"],
                file_metadata = chunk["file_metadata"],
            ) for chunk in chunks
        ]

        identified_chunks.extend(dataset_chunks)

        await vector_client.create_data_points(
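For reference, the per-dataset payload add_data_chunks() expects, matching the TextChunk TypedDict above. The dataset name, ids, and metadata values here are placeholders, not values from this commit.

# Illustrative payload for add_data_chunks(); ids and metadata are placeholders.
from uuid import uuid4

dataset_data_chunks = {
    "explanations": [
        {
            "text": "Natural language processing (NLP) is an interdisciplinary subfield of computer science.",
            "chunk_id": str(uuid4()),
            "file_metadata": {"id": "file-1", "file_path": "/tmp/example.txt"},
        },
    ],
}

# added_chunks = await add_data_chunks(dataset_data_chunks)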
@@ -1,243 +0,0 @@
""" Chunking strategies for splitting text into smaller parts."""
from __future__ import annotations
from cognee.shared.data_models import ChunkStrategy
import re

from typing import Any, List, Optional


class CharacterTextSplitter():
    """Splitting text that looks at characters."""

    def __init__(
        self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator
        self._is_separator_regex = is_separator_regex

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        separator = (
            self._separator if self._is_separator_regex else re.escape(self._separator)
        )
        splits = _split_text_with_regex(text, separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        return self._merge_splits(splits, _separator)


def _split_text_with_regex(
    text: str, separator: str, keep_separator: bool
) -> List[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({separator})", text)
            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            splits = [_splits[0]] + splits
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]


class RecursiveCharacterTextSplitter():
    """Splitting text by recursively look at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex(text, _separator, self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> List[str]:
        return self._split_text(text, self._separators)

def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
    """
    Chunk data based on the specified strategy.

    Parameters:
    - chunk_strategy: The strategy to use for chunking.
    - source_data: The data to be chunked.
    - chunk_size: The size of each chunk.
    - chunk_overlap: The overlap between chunks.

    Returns:
    - The chunked data.
    """

    if chunk_strategy == ChunkStrategy.VANILLA:
        chunked_data = vanilla_chunker(source_data, chunk_size, chunk_overlap)

    elif chunk_strategy == ChunkStrategy.PARAGRAPH:
        chunked_data = chunk_data_by_paragraph(source_data,chunk_size, chunk_overlap)

    elif chunk_strategy == ChunkStrategy.SENTENCE:
        chunked_data = chunk_by_sentence(source_data, chunk_size, chunk_overlap)
    elif chunk_strategy == ChunkStrategy.EXACT:
        chunked_data = chunk_data_exact(source_data, chunk_size, chunk_overlap)
    elif chunk_strategy == ChunkStrategy.SUMMARY:
        chunked_data = summary_chunker(source_data, chunk_size, chunk_overlap)
    else:
        chunked_data = vanilla_chunker(source_data, chunk_size, chunk_overlap)

    return chunked_data


def vanilla_chunker(source_data, chunk_size=100, chunk_overlap=20):
    # adapt this for different chunking strategies

    text_splitter = RecursiveCharacterTextSplitter(
        # Set a really small chunk size, just to show.
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len
    )
    # try:
    #     pages = text_splitter.create_documents([source_data])
    # except:
    #     try:
    pages = text_splitter.create_documents([source_data])
    # except:
    #     pages = text_splitter.create_documents(source_data.content)
    # pages = source_data.load_and_split()
    return pages

def summary_chunker(source_data, chunk_size=400, chunk_overlap=20):
    """
    Chunk the given source data into smaller parts, returning the first five and last five chunks.

    Parameters:
    - source_data (str): The source data to be chunked.
    - chunk_size (int): The size of each chunk.
    - chunk_overlap (int): The overlap between consecutive chunks.

    Returns:
    - List: A list containing the first five and last five chunks of the chunked source data.
    """

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len
    )

    try:
        pages = text_splitter.create_documents([source_data])
    except:
        pages = text_splitter.create_documents(source_data.content)

    # Return the first 5 and last 5 chunks
    if len(pages) > 10:
        return pages[:5] + pages[-5:]
    else:
        return pages  # Return all chunks if there are 10 or fewer

def chunk_data_exact(data_chunks, chunk_size, chunk_overlap):
    data = "".join(data_chunks)
    chunks = []
    for i in range(0, len(data), chunk_size - chunk_overlap):
        chunks.append(data[i:i + chunk_size])
    return chunks


def chunk_by_sentence(data_chunks, chunk_size, overlap):
    # Split by periods, question marks, exclamation marks, and ellipses
    data = "".join(data_chunks)

    # The regular expression is used to find series of charaters that end with one the following chaacters (. ! ? ...)
    sentence_endings = r'(?<=[.!?…]) +'
    sentences = re.split(sentence_endings, data)

    sentence_chunks = []
    for sentence in sentences:
        if len(sentence) > chunk_size:
            chunks = chunk_data_exact([sentence], chunk_size, overlap)
            sentence_chunks.extend(chunks)
        else:
            sentence_chunks.append(sentence)
    return sentence_chunks


def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound=0.75):
    data = "".join(data_chunks)
    total_length = len(data)
    chunks = []
    check_bound = int(bound * chunk_size)
    start_idx = 0

    while start_idx < total_length:
        # Set the end index to the minimum of start_idx + default_chunk_size or total_length
        end_idx = min(start_idx + chunk_size, total_length)

        # Find the next paragraph index within the current chunk and bound
        next_paragraph_index = data.find('\n\n', start_idx + check_bound, end_idx)

        # If a next paragraph index is found within the current chunk
        if next_paragraph_index != -1:
            # Update end_idx to include the paragraph delimiter
            end_idx = next_paragraph_index + 2

        chunks.append(data[start_idx:end_idx + overlap])

        # Update start_idx to be the current end_idx
        start_idx = end_idx

    return chunks
@@ -36,8 +36,6 @@ class ChunkStrategy(Enum):
    EXACT = "exact"
    PARAGRAPH = "paragraph"
    SENTENCE = "sentence"
    VANILLA = "vanilla"
    SUMMARY = "summary"

class MemorySummary(BaseModel):
    """ Memory summary. """
@@ -23,24 +23,24 @@
"colbertv2_wiki17_abstracts = dspy.ColBERTv2(url = \"http://20.102.90.50:2017/wiki17_abstracts\")\n",
"dspy.configure(rm = colbertv2_wiki17_abstracts)\n",
"\n",
"dataset = HotPotQA(\n",
" train_seed = 1,\n",
" train_size = 10,\n",
" eval_seed = 2023,\n",
" dev_size = 0,\n",
" test_size = 0,\n",
" keep_details = True,\n",
")\n",
"# dataset = HotPotQA(\n",
"# train_seed = 1,\n",
"# train_size = 10,\n",
"# eval_seed = 2023,\n",
"# dev_size = 0,\n",
"# test_size = 0,\n",
"# keep_details = True,\n",
"# )\n",
"\n",
"texts_to_add = []\n",
"# texts_to_add = []\n",
"\n",
"for train_case in dataset.train:\n",
" train_case_text = \"\\r\\n\".join(\" \".join(context_sentences) for context_sentences in train_case.get(\"context\")[\"sentences\"])\n",
"# for train_case in dataset.train:\n",
"# train_case_text = \"\\r\\n\".join(\" \".join(context_sentences) for context_sentences in train_case.get(\"context\")[\"sentences\"])\n",
"\n",
" texts_to_add.append(train_case_text)\n",
"# texts_to_add.append(train_case_text)\n",
"\n",
"dataset_name = \"train_dataset\"\n",
"await cognee.add(texts_to_add, dataset_name)\n"
"dataset_name = \"short_stories\"\n",
"await cognee.add(\"data://\" + data_directory_path, dataset_name)\n"
]
},
{

@@ -61,7 +61,7 @@
"\n",
"print(cognee.datasets.list_datasets())\n",
"\n",
"train_dataset = cognee.datasets.query_data('train_dataset')\n",
"train_dataset = cognee.datasets.query_data(\"short_stories\")\n",
"print(len(train_dataset))"
]
},

@@ -73,8 +73,11 @@
"outputs": [],
"source": [
"from os import path\n",
"import logging\n",
"import cognee\n",
"\n",
"logging.basicConfig(level = logging.INFO)\n",
"\n",
"await cognee.prune.prune_system()\n",
"\n",
"data_directory_path = path.abspath(\"../.data\")\n",

@@ -83,7 +86,7 @@
"cognee_directory_path = path.abspath(\"../.cognee_system\")\n",
"cognee.config.system_root_directory(cognee_directory_path)\n",
"\n",
"await cognee.cognify('train_dataset')"
"await cognee.cognify('short_stories')"
]
},
{

@@ -151,10 +154,10 @@
"graph_client = await get_graph_client(GraphDBType.NETWORKX)\n",
"graph = graph_client.graph\n",
"\n",
"results = await search_similarity(\"Who is Ernie Grunwald?\", graph)\n",
"results = await search_similarity(\"Who are French girls?\", graph)\n",
"\n",
"for result in results:\n",
" print(\"Ernie Grunwald\" in result)\n",
" print(\"French girls\" in result)\n",
" print(result)"
]
}
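The updated notebook flow, sketched as a standalone script for reference. The paths and dataset name are illustrative, and the cognee calls are assumed to behave as they do in the notebook cells above; this is not part of the commit.

# Standalone sketch of the notebook flow; values are placeholders.
import asyncio
from os import path

import cognee

async def main():
    await cognee.prune.prune_system()

    data_directory_path = path.abspath("./.data")
    cognee_directory_path = path.abspath("./.cognee_system")
    cognee.config.system_root_directory(cognee_directory_path)

    await cognee.add("data://" + data_directory_path, "short_stories")
    await cognee.cognify("short_stories")

if __name__ == "__main__":
    asyncio.run(main())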