commit
0aed5f27be
10 changed files with 202 additions and 291 deletions
|
|
@ -55,7 +55,7 @@ async def add_files(file_paths: List[str], dataset_name: str):
|
||||||
|
|
||||||
if data_directory_path not in file_path:
|
if data_directory_path not in file_path:
|
||||||
file_name = file_path.split("/")[-1]
|
file_name = file_path.split("/")[-1]
|
||||||
file_directory_path = data_directory_path + "/" + (dataset_name.replace('.', "/") + "/" if dataset_name != "root" else "")
|
file_directory_path = data_directory_path + "/" + (dataset_name.replace(".", "/") + "/" if dataset_name != "root" else "")
|
||||||
dataset_file_path = path.join(file_directory_path, file_name)
|
dataset_file_path = path.join(file_directory_path, file_name)
|
||||||
|
|
||||||
LocalStorage.ensure_directory_exists(file_directory_path)
|
LocalStorage.ensure_directory_exists(file_directory_path)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,11 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from uuid import uuid4
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
import logging
|
import logging
|
||||||
import instructor
|
import instructor
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
from nltk.corpus import stopwords
|
||||||
|
from cognee.config import Config
|
||||||
from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks
|
from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks
|
||||||
from cognee.modules.cognify.graph.add_document_node import add_document_node
|
from cognee.modules.cognify.graph.add_document_node import add_document_node
|
||||||
from cognee.modules.cognify.graph.add_classification_nodes import add_classification_nodes
|
from cognee.modules.cognify.graph.add_classification_nodes import add_classification_nodes
|
||||||
|
|
@ -11,9 +14,6 @@ from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
|
||||||
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \
|
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \
|
||||||
graph_ready_output, connect_nodes_in_graph
|
graph_ready_output, connect_nodes_in_graph
|
||||||
from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references
|
from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references
|
||||||
|
|
||||||
from cognee.config import Config
|
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||||
from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes
|
from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes
|
||||||
from cognee.modules.cognify.graph.add_cognitive_layers import add_cognitive_layers
|
from cognee.modules.cognify.graph.add_cognitive_layers import add_cognitive_layers
|
||||||
|
|
@ -26,6 +26,7 @@ from cognee.modules.data.get_content_summary import get_content_summary
|
||||||
from cognee.modules.data.get_cognitive_layers import get_cognitive_layers
|
from cognee.modules.data.get_cognitive_layers import get_cognitive_layers
|
||||||
from cognee.modules.data.get_layer_graphs import get_layer_graphs
|
from cognee.modules.data.get_layer_graphs import get_layer_graphs
|
||||||
|
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
|
|
||||||
|
|
@ -33,10 +34,12 @@ aclient = instructor.patch(OpenAI())
|
||||||
|
|
||||||
USER_ID = "default_user"
|
USER_ID = "default_user"
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger("cognify")
|
||||||
|
|
||||||
async def cognify(datasets: Union[str, List[str]] = None):
|
async def cognify(datasets: Union[str, List[str]] = None):
|
||||||
"""This function is responsible for the cognitive processing of the content."""
|
"""This function is responsible for the cognitive processing of the content."""
|
||||||
|
# Has to be loaded in advance, multithreading doesn't work without it.
|
||||||
|
stopwords.ensure_loaded()
|
||||||
|
|
||||||
db_engine = infrastructure_config.get_config()["database_engine"]
|
db_engine = infrastructure_config.get_config()["database_engine"]
|
||||||
|
|
||||||
|
|
@ -53,10 +56,10 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
||||||
graphs = await asyncio.gather(*awaitables)
|
graphs = await asyncio.gather(*awaitables)
|
||||||
return graphs[0]
|
return graphs[0]
|
||||||
|
|
||||||
# datasets is a dataset name string
|
|
||||||
added_datasets = db_engine.get_datasets()
|
added_datasets = db_engine.get_datasets()
|
||||||
|
|
||||||
dataset_files = []
|
dataset_files = []
|
||||||
|
# datasets is a dataset name string
|
||||||
dataset_name = datasets.replace(".", "_").replace(" ", "_")
|
dataset_name = datasets.replace(".", "_").replace(" ", "_")
|
||||||
|
|
||||||
for added_dataset in added_datasets:
|
for added_dataset in added_datasets:
|
||||||
|
|
@ -73,30 +76,35 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
||||||
|
|
||||||
data_chunks = {}
|
data_chunks = {}
|
||||||
|
|
||||||
|
chunk_engine = infrastructure_config.get_config()["chunk_engine"]
|
||||||
|
chunk_strategy = infrastructure_config.get_config()["chunk_strategy"]
|
||||||
|
|
||||||
for (dataset_name, files) in dataset_files:
|
for (dataset_name, files) in dataset_files:
|
||||||
for file_metadata in files[:3]:
|
for file_metadata in files:
|
||||||
with open(file_metadata["file_path"], "rb") as file:
|
with open(file_metadata["file_path"], "rb") as file:
|
||||||
try:
|
try:
|
||||||
file_type = guess_file_type(file)
|
file_type = guess_file_type(file)
|
||||||
text = extract_text_from_file(file, file_type)
|
text = extract_text_from_file(file, file_type)
|
||||||
|
subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap)
|
||||||
|
|
||||||
if dataset_name not in data_chunks:
|
if dataset_name not in data_chunks:
|
||||||
data_chunks[dataset_name] = []
|
data_chunks[dataset_name] = []
|
||||||
|
|
||||||
data_chunks[dataset_name].append(dict(text = text, file_metadata = file_metadata))
|
for subchunk in subchunks:
|
||||||
|
data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
|
||||||
except FileTypeException:
|
except FileTypeException:
|
||||||
logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
|
logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
|
||||||
|
|
||||||
added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)
|
added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)
|
||||||
|
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[process_text(chunk["collection"], chunk["id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
|
*[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
|
||||||
)
|
)
|
||||||
|
|
||||||
return graph_client.graph
|
return graph_client.graph
|
||||||
|
|
||||||
async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict):
|
async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict):
|
||||||
print(f"Processing document ({file_metadata['id']}).")
|
print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
|
||||||
|
|
||||||
graph_client = await get_graph_client(infrastructure_config.get_config()["graph_engine"])
|
graph_client = await get_graph_client(infrastructure_config.get_config()["graph_engine"])
|
||||||
|
|
||||||
|
|
@ -115,12 +123,12 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
|
||||||
categories = classified_categories,
|
categories = classified_categories,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Document ({document_id}) classified.")
|
print(f"Chunk ({chunk_id}) classified.")
|
||||||
|
|
||||||
content_summary = await get_content_summary(input_text)
|
content_summary = await get_content_summary(input_text)
|
||||||
await add_summary_nodes(graph_client, document_id, content_summary)
|
await add_summary_nodes(graph_client, document_id, content_summary)
|
||||||
|
|
||||||
print(f"Document ({document_id}) summarized.")
|
print(f"Chunk ({chunk_id}) summarized.")
|
||||||
|
|
||||||
cognitive_layers = await get_cognitive_layers(input_text, classified_categories)
|
cognitive_layers = await get_cognitive_layers(input_text, classified_categories)
|
||||||
cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2]
|
cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2]
|
||||||
|
|
@ -132,8 +140,6 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
|
||||||
db_engine = infrastructure_config.get_config()["database_engine"]
|
db_engine = infrastructure_config.get_config()["database_engine"]
|
||||||
relevant_documents_to_connect = db_engine.fetch_cognify_data(excluded_document_id = file_metadata["id"])
|
relevant_documents_to_connect = db_engine.fetch_cognify_data(excluded_document_id = file_metadata["id"])
|
||||||
|
|
||||||
print("Relevant documents to connect are: ", relevant_documents_to_connect)
|
|
||||||
|
|
||||||
list_of_nodes = []
|
list_of_nodes = []
|
||||||
|
|
||||||
relevant_documents_to_connect.append({
|
relevant_documents_to_connect.append({
|
||||||
|
|
@ -144,13 +150,9 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
|
||||||
node_descriptions_to_match = await graph_client.extract_node_description(document["layer_id"])
|
node_descriptions_to_match = await graph_client.extract_node_description(document["layer_id"])
|
||||||
list_of_nodes.extend(node_descriptions_to_match)
|
list_of_nodes.extend(node_descriptions_to_match)
|
||||||
|
|
||||||
print("List of nodes are: ", len(list_of_nodes))
|
|
||||||
|
|
||||||
nodes_by_layer = await group_nodes_by_layer(list_of_nodes)
|
nodes_by_layer = await group_nodes_by_layer(list_of_nodes)
|
||||||
print("Nodes by layer are: ", str(nodes_by_layer)[:5000])
|
|
||||||
|
|
||||||
results = await resolve_cross_graph_references(nodes_by_layer)
|
results = await resolve_cross_graph_references(nodes_by_layer)
|
||||||
print("Results are: ", str(results)[:3000])
|
|
||||||
|
|
||||||
relationships = graph_ready_output(results)
|
relationships = graph_ready_output(results)
|
||||||
|
|
||||||
|
|
@ -160,4 +162,4 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
|
||||||
score_threshold = infrastructure_config.get_config()["intra_layer_score_treshold"]
|
score_threshold = infrastructure_config.get_config()["intra_layer_score_treshold"]
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Document ({document_id}) cognified.")
|
print(f"Chunk ({chunk_id}) cognified.")
|
||||||
|
|
|
||||||
|
|
@ -67,4 +67,8 @@ class config():
|
||||||
infrastructure_config.set_config({
|
infrastructure_config.set_config({
|
||||||
"connect_documents": connect_documents
|
"connect_documents": connect_documents
|
||||||
})
|
})
|
||||||
|
@staticmethod
|
||||||
|
def set_chunk_strategy(chunk_strategy: object):
|
||||||
|
infrastructure_config.set_config({
|
||||||
|
"chunk_strategy": chunk_strategy
|
||||||
|
})
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from cognee.root_dir import get_absolute_path
|
from cognee.root_dir import get_absolute_path
|
||||||
|
from cognee.shared.data_models import ChunkStrategy
|
||||||
|
|
||||||
base_dir = Path(__file__).resolve().parent.parent
|
base_dir = Path(__file__).resolve().parent.parent
|
||||||
# Load the .env file from the base directory
|
# Load the .env file from the base directory
|
||||||
|
|
@ -116,6 +116,11 @@ class Config:
|
||||||
# Client ID
|
# Client ID
|
||||||
anon_clientid: Optional[str] = field(default_factory=lambda: uuid.uuid4().hex)
|
anon_clientid: Optional[str] = field(default_factory=lambda: uuid.uuid4().hex)
|
||||||
|
|
||||||
|
#Chunking parameters
|
||||||
|
chunk_size: int = 1500
|
||||||
|
chunk_overlap: int = 0
|
||||||
|
chunk_strategy: str = ChunkStrategy.PARAGRAPH
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
"""Loads the configuration from a file or environment variables."""
|
"""Loads the configuration from a file or environment variables."""
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from .databases.vector.embeddings.DefaultEmbeddingEngine import DefaultEmbedding
|
||||||
from .llm.llm_interface import LLMInterface
|
from .llm.llm_interface import LLMInterface
|
||||||
from .llm.openai.adapter import OpenAIAdapter
|
from .llm.openai.adapter import OpenAIAdapter
|
||||||
from .files.storage import LocalStorage
|
from .files.storage import LocalStorage
|
||||||
|
from .data.chunking.DefaultChunkEngine import DefaultChunkEngine
|
||||||
from ..shared.data_models import GraphDBType, DefaultContentPrediction, KnowledgeGraph, SummarizedContent, \
|
from ..shared.data_models import GraphDBType, DefaultContentPrediction, KnowledgeGraph, SummarizedContent, \
|
||||||
LabeledContent, DefaultCognitiveLayer
|
LabeledContent, DefaultCognitiveLayer
|
||||||
|
|
||||||
|
|
@ -30,6 +31,8 @@ class InfrastructureConfig():
|
||||||
connect_documents = config.connect_documents
|
connect_documents = config.connect_documents
|
||||||
database_directory_path: str = None
|
database_directory_path: str = None
|
||||||
database_file_path: str = None
|
database_file_path: str = None
|
||||||
|
chunk_strategy = config.chunk_strategy
|
||||||
|
chunk_engine = None
|
||||||
|
|
||||||
def get_config(self, config_entity: str = None) -> dict:
|
def get_config(self, config_entity: str = None) -> dict:
|
||||||
if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:
|
if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:
|
||||||
|
|
@ -69,6 +72,12 @@ class InfrastructureConfig():
|
||||||
if self.connect_documents is None:
|
if self.connect_documents is None:
|
||||||
self.connect_documents = config.connect_documents
|
self.connect_documents = config.connect_documents
|
||||||
|
|
||||||
|
if self.chunk_strategy is None:
|
||||||
|
self.chunk_strategy = config.chunk_strategy
|
||||||
|
|
||||||
|
if self.chunk_engine is None:
|
||||||
|
self.chunk_engine = DefaultChunkEngine()
|
||||||
|
|
||||||
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
|
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
|
||||||
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
|
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
|
||||||
|
|
||||||
|
|
@ -120,16 +129,18 @@ class InfrastructureConfig():
|
||||||
"embedding_engine": self.embedding_engine,
|
"embedding_engine": self.embedding_engine,
|
||||||
"connect_documents": self.connect_documents,
|
"connect_documents": self.connect_documents,
|
||||||
"database_directory_path": self.database_directory_path,
|
"database_directory_path": self.database_directory_path,
|
||||||
"database_path": self.database_file_path
|
"database_path": self.database_file_path,
|
||||||
|
"chunk_strategy": self.chunk_strategy,
|
||||||
|
"chunk_engine": self.chunk_engine,
|
||||||
}
|
}
|
||||||
|
|
||||||
def set_config(self, new_config: dict):
|
def set_config(self, new_config: dict):
|
||||||
if "system_root_directory" in new_config:
|
if "system_root_directory" in new_config:
|
||||||
self.system_root_directory = new_config["system_root_directory"]
|
self.system_root_directory = new_config["system_root_directory"]
|
||||||
|
|
||||||
if "data_root_directory" in new_config:
|
if "data_root_directory" in new_config:
|
||||||
self.data_root_directory = new_config["data_root_directory"]
|
self.data_root_directory = new_config["data_root_directory"]
|
||||||
|
|
||||||
if "database_engine" in new_config:
|
if "database_engine" in new_config:
|
||||||
self.database_engine = new_config["database_engine"]
|
self.database_engine = new_config["database_engine"]
|
||||||
|
|
||||||
|
|
@ -169,4 +180,10 @@ class InfrastructureConfig():
|
||||||
if "connect_documents" in new_config:
|
if "connect_documents" in new_config:
|
||||||
self.connect_documents = new_config["connect_documents"]
|
self.connect_documents = new_config["connect_documents"]
|
||||||
|
|
||||||
|
if "chunk_strategy" in new_config:
|
||||||
|
self.chunk_strategy = new_config["chunk_strategy"]
|
||||||
|
|
||||||
|
if "chunk_engine" in new_config:
|
||||||
|
self.chunk_engine = new_config["chunk_engine"]
|
||||||
|
|
||||||
infrastructure_config = InfrastructureConfig()
|
infrastructure_config = InfrastructureConfig()
|
||||||
|
|
|
||||||
125
cognee/infrastructure/data/chunking/DefaultChunkEngine.py
Normal file
125
cognee/infrastructure/data/chunking/DefaultChunkEngine.py
Normal file
|
|
@ -0,0 +1,125 @@
|
||||||
|
""" Chunking strategies for splitting text into smaller parts."""
|
||||||
|
from __future__ import annotations
|
||||||
|
import re
|
||||||
|
from cognee.shared.data_models import ChunkStrategy
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultChunkEngine():
|
||||||
|
@staticmethod
|
||||||
|
def _split_text_with_regex(
|
||||||
|
text: str, separator: str, keep_separator: bool
|
||||||
|
) -> list[str]:
|
||||||
|
# Now that we have the separator, split the text
|
||||||
|
if separator:
|
||||||
|
if keep_separator:
|
||||||
|
# The parentheses in the pattern keep the delimiters in the result.
|
||||||
|
_splits = re.split(f"({separator})", text)
|
||||||
|
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
|
||||||
|
if len(_splits) % 2 == 0:
|
||||||
|
splits += _splits[-1:]
|
||||||
|
splits = [_splits[0]] + splits
|
||||||
|
else:
|
||||||
|
splits = re.split(separator, text)
|
||||||
|
else:
|
||||||
|
splits = list(text)
|
||||||
|
return [s for s in splits if s != ""]
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def chunk_data(
|
||||||
|
chunk_strategy = None,
|
||||||
|
source_data = None,
|
||||||
|
chunk_size = None,
|
||||||
|
chunk_overlap = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Chunk data based on the specified strategy.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- chunk_strategy: The strategy to use for chunking.
|
||||||
|
- source_data: The data to be chunked.
|
||||||
|
- chunk_size: The size of each chunk.
|
||||||
|
- chunk_overlap: The overlap between chunks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- The chunked data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if chunk_strategy == ChunkStrategy.PARAGRAPH:
|
||||||
|
chunked_data = DefaultChunkEngine.chunk_data_by_paragraph(source_data,chunk_size, chunk_overlap)
|
||||||
|
elif chunk_strategy == ChunkStrategy.SENTENCE:
|
||||||
|
chunked_data = DefaultChunkEngine.chunk_by_sentence(source_data, chunk_size, chunk_overlap)
|
||||||
|
elif chunk_strategy == ChunkStrategy.EXACT:
|
||||||
|
chunked_data = DefaultChunkEngine.chunk_data_exact(source_data, chunk_size, chunk_overlap)
|
||||||
|
|
||||||
|
return chunked_data
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def chunk_data_exact(data_chunks, chunk_size, chunk_overlap):
|
||||||
|
data = "".join(data_chunks)
|
||||||
|
chunks = []
|
||||||
|
for i in range(0, len(data), chunk_size - chunk_overlap):
|
||||||
|
chunks.append(data[i:i + chunk_size])
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def chunk_by_sentence(data_chunks, chunk_size, overlap):
|
||||||
|
# Split by periods, question marks, exclamation marks, and ellipses
|
||||||
|
data = "".join(data_chunks)
|
||||||
|
|
||||||
|
# The regular expression is used to find series of charaters that end with one the following chaacters (. ! ? ...)
|
||||||
|
sentence_endings = r'(?<=[.!?…]) +'
|
||||||
|
sentences = re.split(sentence_endings, data)
|
||||||
|
|
||||||
|
sentence_chunks = []
|
||||||
|
for sentence in sentences:
|
||||||
|
if len(sentence) > chunk_size:
|
||||||
|
chunks = DefaultChunkEngine.chunk_data_exact([sentence], chunk_size, overlap)
|
||||||
|
sentence_chunks.extend(chunks)
|
||||||
|
else:
|
||||||
|
sentence_chunks.append(sentence)
|
||||||
|
return sentence_chunks
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound = 0.75):
|
||||||
|
data = "".join(data_chunks)
|
||||||
|
total_length = len(data)
|
||||||
|
chunks = []
|
||||||
|
check_bound = int(bound * chunk_size)
|
||||||
|
start_idx = 0
|
||||||
|
chunk_splitter = "\n\n"
|
||||||
|
|
||||||
|
if data.find("\n\n") == -1:
|
||||||
|
chunk_splitter = "\n"
|
||||||
|
|
||||||
|
while start_idx < total_length:
|
||||||
|
# Set the end index to the minimum of start_idx + default_chunk_size or total_length
|
||||||
|
end_idx = min(start_idx + chunk_size, total_length)
|
||||||
|
|
||||||
|
# Find the next paragraph index within the current chunk and bound
|
||||||
|
next_paragraph_index = data.find(chunk_splitter, start_idx + check_bound, end_idx)
|
||||||
|
|
||||||
|
# If a next paragraph index is found within the current chunk
|
||||||
|
if next_paragraph_index != -1:
|
||||||
|
# Update end_idx to include the paragraph delimiter
|
||||||
|
end_idx = next_paragraph_index + 2
|
||||||
|
|
||||||
|
end_index = end_idx + overlap
|
||||||
|
|
||||||
|
chunk_text = data[start_idx:end_index]
|
||||||
|
|
||||||
|
while chunk_text[-1] != "." and end_index < total_length:
|
||||||
|
chunk_text += data[end_index]
|
||||||
|
end_index += 1
|
||||||
|
|
||||||
|
end_idx = end_index - overlap
|
||||||
|
|
||||||
|
chunks.append(chunk_text.replace("\n", "").strip())
|
||||||
|
|
||||||
|
# Update start_idx to be the current end_idx
|
||||||
|
start_idx = end_idx
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
from uuid import uuid4
|
|
||||||
from cognee.infrastructure import infrastructure_config
|
from cognee.infrastructure import infrastructure_config
|
||||||
from cognee.infrastructure.databases.vector import DataPoint
|
from cognee.infrastructure.databases.vector import DataPoint
|
||||||
|
|
||||||
class TextChunk(TypedDict):
|
class TextChunk(TypedDict):
|
||||||
text: str
|
text: str
|
||||||
|
chunk_id: str
|
||||||
file_metadata: dict
|
file_metadata: dict
|
||||||
|
|
||||||
async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
|
async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
|
||||||
|
|
@ -20,20 +20,20 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
|
||||||
|
|
||||||
dataset_chunks = [
|
dataset_chunks = [
|
||||||
dict(
|
dict(
|
||||||
id = str(uuid4()),
|
chunk_id = chunk["chunk_id"],
|
||||||
collection = dataset_name,
|
collection = dataset_name,
|
||||||
text = chunk["text"],
|
text = chunk["text"],
|
||||||
file_metadata = chunk["file_metadata"],
|
file_metadata = chunk["file_metadata"],
|
||||||
) for chunk in chunks
|
) for chunk in chunks
|
||||||
]
|
]
|
||||||
|
|
||||||
identified_chunks.extend(dataset_chunks)
|
identified_chunks.extend(dataset_chunks)
|
||||||
|
|
||||||
await vector_client.create_data_points(
|
await vector_client.create_data_points(
|
||||||
dataset_name,
|
dataset_name,
|
||||||
[
|
[
|
||||||
DataPoint(
|
DataPoint(
|
||||||
id = chunk["id"],
|
id = chunk["chunk_id"],
|
||||||
payload = dict(text = chunk["text"]),
|
payload = dict(text = chunk["text"]),
|
||||||
embed_field = "text"
|
embed_field = "text"
|
||||||
) for chunk in dataset_chunks
|
) for chunk in dataset_chunks
|
||||||
|
|
|
||||||
|
|
@ -1,243 +0,0 @@
|
||||||
""" Chunking strategies for splitting text into smaller parts."""
|
|
||||||
from __future__ import annotations
|
|
||||||
from cognee.shared.data_models import ChunkStrategy
|
|
||||||
import re
|
|
||||||
|
|
||||||
from typing import Any, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class CharacterTextSplitter():
|
|
||||||
"""Splitting text that looks at characters."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Create a new TextSplitter."""
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._separator = separator
|
|
||||||
self._is_separator_regex = is_separator_regex
|
|
||||||
|
|
||||||
def split_text(self, text: str) -> List[str]:
|
|
||||||
"""Split incoming text and return chunks."""
|
|
||||||
# First we naively split the large input into a bunch of smaller ones.
|
|
||||||
separator = (
|
|
||||||
self._separator if self._is_separator_regex else re.escape(self._separator)
|
|
||||||
)
|
|
||||||
splits = _split_text_with_regex(text, separator, self._keep_separator)
|
|
||||||
_separator = "" if self._keep_separator else self._separator
|
|
||||||
return self._merge_splits(splits, _separator)
|
|
||||||
|
|
||||||
|
|
||||||
def _split_text_with_regex(
|
|
||||||
text: str, separator: str, keep_separator: bool
|
|
||||||
) -> List[str]:
|
|
||||||
# Now that we have the separator, split the text
|
|
||||||
if separator:
|
|
||||||
if keep_separator:
|
|
||||||
# The parentheses in the pattern keep the delimiters in the result.
|
|
||||||
_splits = re.split(f"({separator})", text)
|
|
||||||
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
|
|
||||||
if len(_splits) % 2 == 0:
|
|
||||||
splits += _splits[-1:]
|
|
||||||
splits = [_splits[0]] + splits
|
|
||||||
else:
|
|
||||||
splits = re.split(separator, text)
|
|
||||||
else:
|
|
||||||
splits = list(text)
|
|
||||||
return [s for s in splits if s != ""]
|
|
||||||
|
|
||||||
|
|
||||||
class RecursiveCharacterTextSplitter():
|
|
||||||
"""Splitting text by recursively look at characters.
|
|
||||||
|
|
||||||
Recursively tries to split by different characters to find one
|
|
||||||
that works.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
separators: Optional[List[str]] = None,
|
|
||||||
keep_separator: bool = True,
|
|
||||||
is_separator_regex: bool = False,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
"""Create a new TextSplitter."""
|
|
||||||
super().__init__(keep_separator=keep_separator, **kwargs)
|
|
||||||
self._separators = separators or ["\n\n", "\n", " ", ""]
|
|
||||||
self._is_separator_regex = is_separator_regex
|
|
||||||
|
|
||||||
def _split_text(self, text: str, separators: List[str]) -> List[str]:
|
|
||||||
"""Split incoming text and return chunks."""
|
|
||||||
final_chunks = []
|
|
||||||
# Get appropriate separator to use
|
|
||||||
separator = separators[-1]
|
|
||||||
new_separators = []
|
|
||||||
for i, _s in enumerate(separators):
|
|
||||||
_separator = _s if self._is_separator_regex else re.escape(_s)
|
|
||||||
if _s == "":
|
|
||||||
separator = _s
|
|
||||||
break
|
|
||||||
if re.search(_separator, text):
|
|
||||||
separator = _s
|
|
||||||
new_separators = separators[i + 1 :]
|
|
||||||
break
|
|
||||||
|
|
||||||
_separator = separator if self._is_separator_regex else re.escape(separator)
|
|
||||||
splits = _split_text_with_regex(text, _separator, self._keep_separator)
|
|
||||||
|
|
||||||
# Now go merging things, recursively splitting longer texts.
|
|
||||||
_good_splits = []
|
|
||||||
_separator = "" if self._keep_separator else separator
|
|
||||||
for s in splits:
|
|
||||||
if self._length_function(s) < self._chunk_size:
|
|
||||||
_good_splits.append(s)
|
|
||||||
else:
|
|
||||||
if _good_splits:
|
|
||||||
merged_text = self._merge_splits(_good_splits, _separator)
|
|
||||||
final_chunks.extend(merged_text)
|
|
||||||
_good_splits = []
|
|
||||||
if not new_separators:
|
|
||||||
final_chunks.append(s)
|
|
||||||
else:
|
|
||||||
other_info = self._split_text(s, new_separators)
|
|
||||||
final_chunks.extend(other_info)
|
|
||||||
if _good_splits:
|
|
||||||
merged_text = self._merge_splits(_good_splits, _separator)
|
|
||||||
final_chunks.extend(merged_text)
|
|
||||||
return final_chunks
|
|
||||||
|
|
||||||
def split_text(self, text: str) -> List[str]:
|
|
||||||
return self._split_text(text, self._separators)
|
|
||||||
|
|
||||||
def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
|
|
||||||
"""
|
|
||||||
Chunk data based on the specified strategy.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- chunk_strategy: The strategy to use for chunking.
|
|
||||||
- source_data: The data to be chunked.
|
|
||||||
- chunk_size: The size of each chunk.
|
|
||||||
- chunk_overlap: The overlap between chunks.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- The chunked data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if chunk_strategy == ChunkStrategy.VANILLA:
|
|
||||||
chunked_data = vanilla_chunker(source_data, chunk_size, chunk_overlap)
|
|
||||||
|
|
||||||
elif chunk_strategy == ChunkStrategy.PARAGRAPH:
|
|
||||||
chunked_data = chunk_data_by_paragraph(source_data,chunk_size, chunk_overlap)
|
|
||||||
|
|
||||||
elif chunk_strategy == ChunkStrategy.SENTENCE:
|
|
||||||
chunked_data = chunk_by_sentence(source_data, chunk_size, chunk_overlap)
|
|
||||||
elif chunk_strategy == ChunkStrategy.EXACT:
|
|
||||||
chunked_data = chunk_data_exact(source_data, chunk_size, chunk_overlap)
|
|
||||||
elif chunk_strategy == ChunkStrategy.SUMMARY:
|
|
||||||
chunked_data = summary_chunker(source_data, chunk_size, chunk_overlap)
|
|
||||||
else:
|
|
||||||
chunked_data = vanilla_chunker(source_data, chunk_size, chunk_overlap)
|
|
||||||
|
|
||||||
return chunked_data
|
|
||||||
|
|
||||||
|
|
||||||
def vanilla_chunker(source_data, chunk_size=100, chunk_overlap=20):
|
|
||||||
# adapt this for different chunking strategies
|
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
|
||||||
# Set a really small chunk size, just to show.
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap,
|
|
||||||
length_function=len
|
|
||||||
)
|
|
||||||
# try:
|
|
||||||
# pages = text_splitter.create_documents([source_data])
|
|
||||||
# except:
|
|
||||||
# try:
|
|
||||||
pages = text_splitter.create_documents([source_data])
|
|
||||||
# except:
|
|
||||||
# pages = text_splitter.create_documents(source_data.content)
|
|
||||||
# pages = source_data.load_and_split()
|
|
||||||
return pages
|
|
||||||
|
|
||||||
def summary_chunker(source_data, chunk_size=400, chunk_overlap=20):
|
|
||||||
"""
|
|
||||||
Chunk the given source data into smaller parts, returning the first five and last five chunks.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- source_data (str): The source data to be chunked.
|
|
||||||
- chunk_size (int): The size of each chunk.
|
|
||||||
- chunk_overlap (int): The overlap between consecutive chunks.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- List: A list containing the first five and last five chunks of the chunked source data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap,
|
|
||||||
length_function=len
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
pages = text_splitter.create_documents([source_data])
|
|
||||||
except:
|
|
||||||
pages = text_splitter.create_documents(source_data.content)
|
|
||||||
|
|
||||||
# Return the first 5 and last 5 chunks
|
|
||||||
if len(pages) > 10:
|
|
||||||
return pages[:5] + pages[-5:]
|
|
||||||
else:
|
|
||||||
return pages # Return all chunks if there are 10 or fewer
|
|
||||||
|
|
||||||
def chunk_data_exact(data_chunks, chunk_size, chunk_overlap):
    """Split concatenated text into fixed-size character chunks with overlap.

    Parameters:
    - data_chunks (List[str]): Text fragments that are joined before chunking.
    - chunk_size (int): Number of characters per chunk.
    - chunk_overlap (int): Characters shared between consecutive chunks;
      must be smaller than chunk_size.

    Returns:
    - List[str]: The chunks; the last one may be shorter than chunk_size.

    Raises:
    - ValueError: If chunk_overlap >= chunk_size (the stride would be
      zero or negative, previously causing a cryptic range() error or a
      silently empty result).
    """
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    data = "".join(data_chunks)
    step = chunk_size - chunk_overlap
    return [data[i:i + chunk_size] for i in range(0, len(data), step)]
|
|
||||||
|
|
||||||
|
|
||||||
def chunk_by_sentence(data_chunks, chunk_size, overlap):
    """Split concatenated text at sentence boundaries.

    Sentences longer than chunk_size are further split into exact-size
    pieces via chunk_data_exact; shorter ones are kept whole.
    """
    text = "".join(data_chunks)

    # A sentence ends with '.', '!', '?' or an ellipsis, followed by spaces.
    sentences = re.split(r'(?<=[.!?…]) +', text)

    result = []
    for sentence in sentences:
        if len(sentence) > chunk_size:
            result.extend(chunk_data_exact([sentence], chunk_size, overlap))
        else:
            result.append(sentence)
    return result
|
|
||||||
|
|
||||||
|
|
||||||
def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound=0.75):
    """Chunk concatenated text, preferring to cut after a paragraph break.

    A paragraph break (two consecutive newlines) is honoured only when it
    falls in the last (1 - bound) fraction of the tentative chunk, so chunks
    stay close to chunk_size. Each emitted chunk carries `overlap` extra
    characters from the following text.
    """
    text = "".join(data_chunks)
    length = len(text)
    min_break = int(bound * chunk_size)

    pieces = []
    position = 0
    while position < length:
        # Tentative cut at chunk_size characters (or the end of the text).
        cut = min(position + chunk_size, length)

        # Look for a paragraph break between the bound and the tentative cut.
        paragraph_break = text.find('\n\n', position + min_break, cut)
        if paragraph_break != -1:
            # Keep the paragraph delimiter with the current chunk.
            cut = paragraph_break + 2

        pieces.append(text[position:cut + overlap])
        position = cut

    return pieces
|
|
||||||
|
|
@ -36,8 +36,6 @@ class ChunkStrategy(Enum):
|
||||||
EXACT = "exact"
|
EXACT = "exact"
|
||||||
PARAGRAPH = "paragraph"
|
PARAGRAPH = "paragraph"
|
||||||
SENTENCE = "sentence"
|
SENTENCE = "sentence"
|
||||||
VANILLA = "vanilla"
|
|
||||||
SUMMARY = "summary"
|
|
||||||
|
|
||||||
class MemorySummary(BaseModel):
|
class MemorySummary(BaseModel):
|
||||||
""" Memory summary. """
|
""" Memory summary. """
|
||||||
|
|
|
||||||
|
|
@ -23,24 +23,24 @@
|
||||||
"colbertv2_wiki17_abstracts = dspy.ColBERTv2(url = \"http://20.102.90.50:2017/wiki17_abstracts\")\n",
|
"colbertv2_wiki17_abstracts = dspy.ColBERTv2(url = \"http://20.102.90.50:2017/wiki17_abstracts\")\n",
|
||||||
"dspy.configure(rm = colbertv2_wiki17_abstracts)\n",
|
"dspy.configure(rm = colbertv2_wiki17_abstracts)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"dataset = HotPotQA(\n",
|
"# dataset = HotPotQA(\n",
|
||||||
" train_seed = 1,\n",
|
"# train_seed = 1,\n",
|
||||||
" train_size = 10,\n",
|
"# train_size = 10,\n",
|
||||||
" eval_seed = 2023,\n",
|
"# eval_seed = 2023,\n",
|
||||||
" dev_size = 0,\n",
|
"# dev_size = 0,\n",
|
||||||
" test_size = 0,\n",
|
"# test_size = 0,\n",
|
||||||
" keep_details = True,\n",
|
"# keep_details = True,\n",
|
||||||
")\n",
|
"# )\n",
|
||||||
"\n",
|
"\n",
|
||||||
"texts_to_add = []\n",
|
"# texts_to_add = []\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for train_case in dataset.train:\n",
|
"# for train_case in dataset.train:\n",
|
||||||
" train_case_text = \"\\r\\n\".join(\" \".join(context_sentences) for context_sentences in train_case.get(\"context\")[\"sentences\"])\n",
|
"# train_case_text = \"\\r\\n\".join(\" \".join(context_sentences) for context_sentences in train_case.get(\"context\")[\"sentences\"])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" texts_to_add.append(train_case_text)\n",
|
"# texts_to_add.append(train_case_text)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"dataset_name = \"train_dataset\"\n",
|
"dataset_name = \"short_stories\"\n",
|
||||||
"await cognee.add(texts_to_add, dataset_name)\n"
|
"await cognee.add(\"data://\" + data_directory_path, dataset_name)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -61,7 +61,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"print(cognee.datasets.list_datasets())\n",
|
"print(cognee.datasets.list_datasets())\n",
|
||||||
"\n",
|
"\n",
|
||||||
"train_dataset = cognee.datasets.query_data('train_dataset')\n",
|
"train_dataset = cognee.datasets.query_data(\"short_stories\")\n",
|
||||||
"print(len(train_dataset))"
|
"print(len(train_dataset))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|
@ -73,8 +73,11 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from os import path\n",
|
"from os import path\n",
|
||||||
|
"import logging\n",
|
||||||
"import cognee\n",
|
"import cognee\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"logging.basicConfig(level = logging.INFO)\n",
|
||||||
|
"\n",
|
||||||
"await cognee.prune.prune_system()\n",
|
"await cognee.prune.prune_system()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"data_directory_path = path.abspath(\"../.data\")\n",
|
"data_directory_path = path.abspath(\"../.data\")\n",
|
||||||
|
|
@ -83,7 +86,7 @@
|
||||||
"cognee_directory_path = path.abspath(\"../.cognee_system\")\n",
|
"cognee_directory_path = path.abspath(\"../.cognee_system\")\n",
|
||||||
"cognee.config.system_root_directory(cognee_directory_path)\n",
|
"cognee.config.system_root_directory(cognee_directory_path)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"await cognee.cognify('train_dataset')"
|
"await cognee.cognify('short_stories')"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -151,10 +154,10 @@
|
||||||
"graph_client = await get_graph_client(GraphDBType.NETWORKX)\n",
|
"graph_client = await get_graph_client(GraphDBType.NETWORKX)\n",
|
||||||
"graph = graph_client.graph\n",
|
"graph = graph_client.graph\n",
|
||||||
"\n",
|
"\n",
|
||||||
"results = await search_similarity(\"Who is Ernie Grunwald?\", graph)\n",
|
"results = await search_similarity(\"Who are French girls?\", graph)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for result in results:\n",
|
"for result in results:\n",
|
||||||
" print(\"Ernie Grunwald\" in result)\n",
|
" print(\"French girls\" in result)\n",
|
||||||
" print(result)"
|
" print(result)"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue