Merge pull request #72 from topoteretes/COG-165

feat: Add chunking
Vasilije 2024-04-23 15:56:34 +02:00 committed by GitHub
commit 0aed5f27be
10 changed files with 202 additions and 291 deletions

View file

@@ -55,7 +55,7 @@ async def add_files(file_paths: List[str], dataset_name: str):
if data_directory_path not in file_path:
file_name = file_path.split("/")[-1]
file_directory_path = data_directory_path + "/" + (dataset_name.replace('.', "/") + "/" if dataset_name != "root" else "")
file_directory_path = data_directory_path + "/" + (dataset_name.replace(".", "/") + "/" if dataset_name != "root" else "")
dataset_file_path = path.join(file_directory_path, file_name)
LocalStorage.ensure_directory_exists(file_directory_path)

View file

@@ -1,8 +1,11 @@
import asyncio
from uuid import uuid4
from typing import List, Union
import logging
import instructor
from openai import OpenAI
from nltk.corpus import stopwords
from cognee.config import Config
from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks
from cognee.modules.cognify.graph.add_document_node import add_document_node
from cognee.modules.cognify.graph.add_classification_nodes import add_classification_nodes
@@ -11,9 +14,6 @@ from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, \
graph_ready_output, connect_nodes_in_graph
from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references
from cognee.config import Config
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes
from cognee.modules.cognify.graph.add_cognitive_layers import add_cognitive_layers
@@ -26,6 +26,7 @@ from cognee.modules.data.get_content_summary import get_content_summary
from cognee.modules.data.get_cognitive_layers import get_cognitive_layers
from cognee.modules.data.get_layer_graphs import get_layer_graphs
config = Config()
config.load()
@@ -33,10 +34,12 @@ aclient = instructor.patch(OpenAI())
USER_ID = "default_user"
logger = logging.getLogger(__name__)
logger = logging.getLogger("cognify")
async def cognify(datasets: Union[str, List[str]] = None):
"""This function is responsible for the cognitive processing of the content."""
# Has to be loaded in advance, multithreading doesn't work without it.
stopwords.ensure_loaded()
db_engine = infrastructure_config.get_config()["database_engine"]
@@ -53,10 +56,10 @@ async def cognify(datasets: Union[str, List[str]] = None):
graphs = await asyncio.gather(*awaitables)
return graphs[0]
# datasets is a dataset name string
added_datasets = db_engine.get_datasets()
dataset_files = []
# datasets is a dataset name string
dataset_name = datasets.replace(".", "_").replace(" ", "_")
for added_dataset in added_datasets:
@@ -73,30 +76,35 @@ async def cognify(datasets: Union[str, List[str]] = None):
data_chunks = {}
chunk_engine = infrastructure_config.get_config()["chunk_engine"]
chunk_strategy = infrastructure_config.get_config()["chunk_strategy"]
for (dataset_name, files) in dataset_files:
for file_metadata in files[:3]:
for file_metadata in files:
with open(file_metadata["file_path"], "rb") as file:
try:
file_type = guess_file_type(file)
text = extract_text_from_file(file, file_type)
subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap)
if dataset_name not in data_chunks:
data_chunks[dataset_name] = []
data_chunks[dataset_name].append(dict(text = text, file_metadata = file_metadata))
for subchunk in subchunks:
data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
except FileTypeException:
logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)
await asyncio.gather(
*[process_text(chunk["collection"], chunk["id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
*[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
)
return graph_client.graph
async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict):
print(f"Processing document ({file_metadata['id']}).")
print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
graph_client = await get_graph_client(infrastructure_config.get_config()["graph_engine"])
@@ -115,12 +123,12 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
categories = classified_categories,
)
print(f"Document ({document_id}) classified.")
print(f"Chunk ({chunk_id}) classified.")
content_summary = await get_content_summary(input_text)
await add_summary_nodes(graph_client, document_id, content_summary)
print(f"Document ({document_id}) summarized.")
print(f"Chunk ({chunk_id}) summarized.")
cognitive_layers = await get_cognitive_layers(input_text, classified_categories)
cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2]
@@ -132,8 +140,6 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
db_engine = infrastructure_config.get_config()["database_engine"]
relevant_documents_to_connect = db_engine.fetch_cognify_data(excluded_document_id = file_metadata["id"])
print("Relevant documents to connect are: ", relevant_documents_to_connect)
list_of_nodes = []
relevant_documents_to_connect.append({
@@ -144,13 +150,9 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
node_descriptions_to_match = await graph_client.extract_node_description(document["layer_id"])
list_of_nodes.extend(node_descriptions_to_match)
print("List of nodes are: ", len(list_of_nodes))
nodes_by_layer = await group_nodes_by_layer(list_of_nodes)
print("Nodes by layer are: ", str(nodes_by_layer)[:5000])
results = await resolve_cross_graph_references(nodes_by_layer)
print("Results are: ", str(results)[:3000])
relationships = graph_ready_output(results)
@@ -160,4 +162,4 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
score_threshold = infrastructure_config.get_config()["intra_layer_score_treshold"]
)
print(f"Document ({document_id}) cognified.")
print(f"Chunk ({chunk_id}) cognified.")

View file

@@ -67,4 +67,8 @@ class config():
infrastructure_config.set_config({
"connect_documents": connect_documents
})
@staticmethod
def set_chunk_strategy(chunk_strategy: object):
infrastructure_config.set_config({
"chunk_strategy": chunk_strategy
})
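Callers can switch the strategy through the same cognee.config entry point used for system_root_directory in the notebook further down. A minimal sketch, assuming cognee exposes this config class at the package level as those calls suggest:

import cognee
from cognee.shared.data_models import ChunkStrategy

# Store the chosen strategy in infrastructure_config before running cognee.cognify(...).
cognee.config.set_chunk_strategy(ChunkStrategy.SENTENCE)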

View file

@@ -8,7 +8,7 @@ from dataclasses import dataclass, field
from pathlib import Path
from dotenv import load_dotenv
from cognee.root_dir import get_absolute_path
from cognee.shared.data_models import ChunkStrategy
base_dir = Path(__file__).resolve().parent.parent
# Load the .env file from the base directory
@@ -116,6 +116,11 @@ class Config:
# Client ID
anon_clientid: Optional[str] = field(default_factory=lambda: uuid.uuid4().hex)
# Chunking parameters
chunk_size: int = 1500
chunk_overlap: int = 0
chunk_strategy: str = ChunkStrategy.PARAGRAPH
def load(self):
"""Loads the configuration from a file or environment variables."""
config = configparser.ConfigParser()
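A quick, illustrative way to inspect the new defaults, following the Config() / load() pattern used at the top of the cognify module:

from cognee.config import Config

config = Config()
config.load()

# Defaults added above: 1500-character chunks, no overlap, paragraph-based splitting.
print(config.chunk_size, config.chunk_overlap, config.chunk_strategy)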

View file

@@ -6,6 +6,7 @@ from .databases.vector.embeddings.DefaultEmbeddingEngine import DefaultEmbedding
from .llm.llm_interface import LLMInterface
from .llm.openai.adapter import OpenAIAdapter
from .files.storage import LocalStorage
from .data.chunking.DefaultChunkEngine import DefaultChunkEngine
from ..shared.data_models import GraphDBType, DefaultContentPrediction, KnowledgeGraph, SummarizedContent, \
LabeledContent, DefaultCognitiveLayer
@@ -30,6 +31,8 @@ class InfrastructureConfig():
connect_documents = config.connect_documents
database_directory_path: str = None
database_file_path: str = None
chunk_strategy = config.chunk_strategy
chunk_engine = None
def get_config(self, config_entity: str = None) -> dict:
if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:
@@ -69,6 +72,12 @@ class InfrastructureConfig():
if self.connect_documents is None:
self.connect_documents = config.connect_documents
if self.chunk_strategy is None:
self.chunk_strategy = config.chunk_strategy
if self.chunk_engine is None:
self.chunk_engine = DefaultChunkEngine()
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
@@ -120,16 +129,18 @@ class InfrastructureConfig():
"embedding_engine": self.embedding_engine,
"connect_documents": self.connect_documents,
"database_directory_path": self.database_directory_path,
"database_path": self.database_file_path
"database_path": self.database_file_path,
"chunk_strategy": self.chunk_strategy,
"chunk_engine": self.chunk_engine,
}
def set_config(self, new_config: dict):
if "system_root_directory" in new_config:
self.system_root_directory = new_config["system_root_directory"]
if "data_root_directory" in new_config:
self.data_root_directory = new_config["data_root_directory"]
if "database_engine" in new_config:
self.database_engine = new_config["database_engine"]
@@ -169,4 +180,10 @@ class InfrastructureConfig():
if "connect_documents" in new_config:
self.connect_documents = new_config["connect_documents"]
if "chunk_strategy" in new_config:
self.chunk_strategy = new_config["chunk_strategy"]
if "chunk_engine" in new_config:
self.chunk_engine = new_config["chunk_engine"]
infrastructure_config = InfrastructureConfig()
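To override the chunking behaviour at runtime instead of relying on the Config defaults, something like the following should work; a sketch only, with the DefaultChunkEngine import path inferred from the relative import above. cognify later reads these values back via infrastructure_config.get_config()["chunk_strategy"] and ["chunk_engine"]:

from cognee.infrastructure import infrastructure_config
from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkEngine
from cognee.shared.data_models import ChunkStrategy

# Only the keys passed here are overridden; anything left unset falls back to Config.
infrastructure_config.set_config({
    "chunk_strategy": ChunkStrategy.SENTENCE,
    "chunk_engine": DefaultChunkEngine(),
})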

View file

@@ -0,0 +1,125 @@
""" Chunking strategies for splitting text into smaller parts."""
from __future__ import annotations
import re
from cognee.shared.data_models import ChunkStrategy
class DefaultChunkEngine():
@staticmethod
def _split_text_with_regex(
text: str, separator: str, keep_separator: bool
) -> list[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = [_splits[0]] + splits
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if s != ""]
@staticmethod
def chunk_data(
chunk_strategy = None,
source_data = None,
chunk_size = None,
chunk_overlap = None,
):
"""
Chunk data based on the specified strategy.
Parameters:
- chunk_strategy: The strategy to use for chunking.
- source_data: The data to be chunked.
- chunk_size: The size of each chunk.
- chunk_overlap: The overlap between chunks.
Returns:
- The chunked data.
"""
if chunk_strategy == ChunkStrategy.PARAGRAPH:
chunked_data = DefaultChunkEngine.chunk_data_by_paragraph(source_data,chunk_size, chunk_overlap)
elif chunk_strategy == ChunkStrategy.SENTENCE:
chunked_data = DefaultChunkEngine.chunk_by_sentence(source_data, chunk_size, chunk_overlap)
elif chunk_strategy == ChunkStrategy.EXACT:
chunked_data = DefaultChunkEngine.chunk_data_exact(source_data, chunk_size, chunk_overlap)
return chunked_data
@staticmethod
def chunk_data_exact(data_chunks, chunk_size, chunk_overlap):
data = "".join(data_chunks)
chunks = []
for i in range(0, len(data), chunk_size - chunk_overlap):
chunks.append(data[i:i + chunk_size])
return chunks
@staticmethod
def chunk_by_sentence(data_chunks, chunk_size, overlap):
# Split by periods, question marks, exclamation marks, and ellipses
data = "".join(data_chunks)
# The regular expression is used to find series of characters that end with one of the following characters (. ! ? ...)
sentence_endings = r'(?<=[.!?…]) +'
sentences = re.split(sentence_endings, data)
sentence_chunks = []
for sentence in sentences:
if len(sentence) > chunk_size:
chunks = DefaultChunkEngine.chunk_data_exact([sentence], chunk_size, overlap)
sentence_chunks.extend(chunks)
else:
sentence_chunks.append(sentence)
return sentence_chunks
@staticmethod
def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound = 0.75):
data = "".join(data_chunks)
total_length = len(data)
chunks = []
check_bound = int(bound * chunk_size)
start_idx = 0
chunk_splitter = "\n\n"
if data.find("\n\n") == -1:
chunk_splitter = "\n"
while start_idx < total_length:
# Set the end index to the minimum of start_idx + default_chunk_size or total_length
end_idx = min(start_idx + chunk_size, total_length)
# Find the next paragraph index within the current chunk and bound
next_paragraph_index = data.find(chunk_splitter, start_idx + check_bound, end_idx)
# If a next paragraph index is found within the current chunk
if next_paragraph_index != -1:
# Update end_idx to include the paragraph delimiter
end_idx = next_paragraph_index + 2
end_index = end_idx + overlap
chunk_text = data[start_idx:end_index]
while chunk_text[-1] != "." and end_index < total_length:
chunk_text += data[end_index]
end_index += 1
end_idx = end_index - overlap
chunks.append(chunk_text.replace("\n", "").strip())
# Update start_idx to be the current end_idx
start_idx = end_idx
return chunks
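A minimal usage sketch for the new engine, assuming it is importable as cognee.infrastructure.data.chunking.DefaultChunkEngine (inferred from the relative import added in the infrastructure module above); the sample text and sizes are placeholders:

from cognee.infrastructure.data.chunking.DefaultChunkEngine import DefaultChunkEngine
from cognee.shared.data_models import ChunkStrategy

sample_text = "First sentence. Second sentence! Third sentence? " * 20

# Sentence strategy: split after ., !, ? and ellipses, re-splitting any oversized sentence exactly.
sentence_chunks = DefaultChunkEngine.chunk_data(ChunkStrategy.SENTENCE, sample_text, 300, 0)

# Exact strategy: fixed windows of chunk_size characters, each sharing chunk_overlap characters with the next.
exact_chunks = DefaultChunkEngine.chunk_data(ChunkStrategy.EXACT, sample_text, 300, 30)

print(len(sentence_chunks), len(exact_chunks))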

View file

@@ -1,10 +1,10 @@
from typing import TypedDict
from uuid import uuid4
from cognee.infrastructure import infrastructure_config
from cognee.infrastructure.databases.vector import DataPoint
class TextChunk(TypedDict):
text: str
chunk_id: str
file_metadata: dict
async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
@@ -20,20 +20,20 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
dataset_chunks = [
dict(
id = str(uuid4()),
chunk_id = chunk["chunk_id"],
collection = dataset_name,
text = chunk["text"],
file_metadata = chunk["file_metadata"],
) for chunk in chunks
]
identified_chunks.extend(dataset_chunks)
await vector_client.create_data_points(
dataset_name,
[
DataPoint(
id = chunk["id"],
id = chunk["chunk_id"],
payload = dict(text = chunk["text"]),
embed_field = "text"
) for chunk in dataset_chunks
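The caller (cognify above) now supplies chunk_id itself, and that id becomes the vector data point id. The expected input shape looks roughly like this; the dataset name, texts and file metadata are placeholders, and the await is left commented out because it needs a configured vector client:

from uuid import uuid4
from cognee.modules.cognify.graph.add_data_chunks import add_data_chunks

# One list of TextChunk dicts per dataset / collection name.
dataset_data_chunks = {
    "example_dataset": [
        dict(text = "First chunk of text.", chunk_id = str(uuid4()), file_metadata = {"id": "example-file-id"}),
        dict(text = "Second chunk of text.", chunk_id = str(uuid4()), file_metadata = {"id": "example-file-id"}),
    ],
}

# added_chunks = await add_data_chunks(dataset_data_chunks)
# Each returned dict carries collection, chunk_id, text and file_metadata for process_text.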

View file

@@ -1,243 +0,0 @@
""" Chunking strategies for splitting text into smaller parts."""
from __future__ import annotations
from cognee.shared.data_models import ChunkStrategy
import re
from typing import Any, List, Optional
class CharacterTextSplitter():
"""Splitting text that looks at characters."""
def __init__(
self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any
) -> None:
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._separator = separator
self._is_separator_regex = is_separator_regex
def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
# First we naively split the large input into a bunch of smaller ones.
separator = (
self._separator if self._is_separator_regex else re.escape(self._separator)
)
splits = _split_text_with_regex(text, separator, self._keep_separator)
_separator = "" if self._keep_separator else self._separator
return self._merge_splits(splits, _separator)
def _split_text_with_regex(
text: str, separator: str, keep_separator: bool
) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = [_splits[0]] + splits
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if s != ""]
class RecursiveCharacterTextSplitter():
"""Splitting text by recursively look at characters.
Recursively tries to split by different characters to find one
that works.
"""
def __init__(
self,
separators: Optional[List[str]] = None,
keep_separator: bool = True,
is_separator_regex: bool = False,
**kwargs: Any,
) -> None:
"""Create a new TextSplitter."""
super().__init__(keep_separator=keep_separator, **kwargs)
self._separators = separators or ["\n\n", "\n", " ", ""]
self._is_separator_regex = is_separator_regex
def _split_text(self, text: str, separators: List[str]) -> List[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = separators[-1]
new_separators = []
for i, _s in enumerate(separators):
_separator = _s if self._is_separator_regex else re.escape(_s)
if _s == "":
separator = _s
break
if re.search(_separator, text):
separator = _s
new_separators = separators[i + 1 :]
break
_separator = separator if self._is_separator_regex else re.escape(separator)
splits = _split_text_with_regex(text, _separator, self._keep_separator)
# Now go merging things, recursively splitting longer texts.
_good_splits = []
_separator = "" if self._keep_separator else separator
for s in splits:
if self._length_function(s) < self._chunk_size:
_good_splits.append(s)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
_good_splits = []
if not new_separators:
final_chunks.append(s)
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
return final_chunks
def split_text(self, text: str) -> List[str]:
return self._split_text(text, self._separators)
def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
"""
Chunk data based on the specified strategy.
Parameters:
- chunk_strategy: The strategy to use for chunking.
- source_data: The data to be chunked.
- chunk_size: The size of each chunk.
- chunk_overlap: The overlap between chunks.
Returns:
- The chunked data.
"""
if chunk_strategy == ChunkStrategy.VANILLA:
chunked_data = vanilla_chunker(source_data, chunk_size, chunk_overlap)
elif chunk_strategy == ChunkStrategy.PARAGRAPH:
chunked_data = chunk_data_by_paragraph(source_data,chunk_size, chunk_overlap)
elif chunk_strategy == ChunkStrategy.SENTENCE:
chunked_data = chunk_by_sentence(source_data, chunk_size, chunk_overlap)
elif chunk_strategy == ChunkStrategy.EXACT:
chunked_data = chunk_data_exact(source_data, chunk_size, chunk_overlap)
elif chunk_strategy == ChunkStrategy.SUMMARY:
chunked_data = summary_chunker(source_data, chunk_size, chunk_overlap)
else:
chunked_data = vanilla_chunker(source_data, chunk_size, chunk_overlap)
return chunked_data
def vanilla_chunker(source_data, chunk_size=100, chunk_overlap=20):
# adapt this for different chunking strategies
text_splitter = RecursiveCharacterTextSplitter(
# Set a really small chunk size, just to show.
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len
)
# try:
# pages = text_splitter.create_documents([source_data])
# except:
# try:
pages = text_splitter.create_documents([source_data])
# except:
# pages = text_splitter.create_documents(source_data.content)
# pages = source_data.load_and_split()
return pages
def summary_chunker(source_data, chunk_size=400, chunk_overlap=20):
"""
Chunk the given source data into smaller parts, returning the first five and last five chunks.
Parameters:
- source_data (str): The source data to be chunked.
- chunk_size (int): The size of each chunk.
- chunk_overlap (int): The overlap between consecutive chunks.
Returns:
- List: A list containing the first five and last five chunks of the chunked source data.
"""
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len
)
try:
pages = text_splitter.create_documents([source_data])
except:
pages = text_splitter.create_documents(source_data.content)
# Return the first 5 and last 5 chunks
if len(pages) > 10:
return pages[:5] + pages[-5:]
else:
return pages # Return all chunks if there are 10 or fewer
def chunk_data_exact(data_chunks, chunk_size, chunk_overlap):
data = "".join(data_chunks)
chunks = []
for i in range(0, len(data), chunk_size - chunk_overlap):
chunks.append(data[i:i + chunk_size])
return chunks
def chunk_by_sentence(data_chunks, chunk_size, overlap):
# Split by periods, question marks, exclamation marks, and ellipses
data = "".join(data_chunks)
# The regular expression is used to find series of characters that end with one of the following characters (. ! ? ...)
sentence_endings = r'(?<=[.!?…]) +'
sentences = re.split(sentence_endings, data)
sentence_chunks = []
for sentence in sentences:
if len(sentence) > chunk_size:
chunks = chunk_data_exact([sentence], chunk_size, overlap)
sentence_chunks.extend(chunks)
else:
sentence_chunks.append(sentence)
return sentence_chunks
def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound=0.75):
data = "".join(data_chunks)
total_length = len(data)
chunks = []
check_bound = int(bound * chunk_size)
start_idx = 0
while start_idx < total_length:
# Set the end index to the minimum of start_idx + default_chunk_size or total_length
end_idx = min(start_idx + chunk_size, total_length)
# Find the next paragraph index within the current chunk and bound
next_paragraph_index = data.find('\n\n', start_idx + check_bound, end_idx)
# If a next paragraph index is found within the current chunk
if next_paragraph_index != -1:
# Update end_idx to include the paragraph delimiter
end_idx = next_paragraph_index + 2
chunks.append(data[start_idx:end_idx + overlap])
# Update start_idx to be the current end_idx
start_idx = end_idx
return chunks

View file

@@ -36,8 +36,6 @@ class ChunkStrategy(Enum):
EXACT = "exact"
PARAGRAPH = "paragraph"
SENTENCE = "sentence"
VANILLA = "vanilla"
SUMMARY = "summary"
class MemorySummary(BaseModel):
""" Memory summary. """

View file

@@ -23,24 +23,24 @@
"colbertv2_wiki17_abstracts = dspy.ColBERTv2(url = \"http://20.102.90.50:2017/wiki17_abstracts\")\n",
"dspy.configure(rm = colbertv2_wiki17_abstracts)\n",
"\n",
"dataset = HotPotQA(\n",
" train_seed = 1,\n",
" train_size = 10,\n",
" eval_seed = 2023,\n",
" dev_size = 0,\n",
" test_size = 0,\n",
" keep_details = True,\n",
")\n",
"# dataset = HotPotQA(\n",
"# train_seed = 1,\n",
"# train_size = 10,\n",
"# eval_seed = 2023,\n",
"# dev_size = 0,\n",
"# test_size = 0,\n",
"# keep_details = True,\n",
"# )\n",
"\n",
"texts_to_add = []\n",
"# texts_to_add = []\n",
"\n",
"for train_case in dataset.train:\n",
" train_case_text = \"\\r\\n\".join(\" \".join(context_sentences) for context_sentences in train_case.get(\"context\")[\"sentences\"])\n",
"# for train_case in dataset.train:\n",
"# train_case_text = \"\\r\\n\".join(\" \".join(context_sentences) for context_sentences in train_case.get(\"context\")[\"sentences\"])\n",
"\n",
" texts_to_add.append(train_case_text)\n",
"# texts_to_add.append(train_case_text)\n",
"\n",
"dataset_name = \"train_dataset\"\n",
"await cognee.add(texts_to_add, dataset_name)\n"
"dataset_name = \"short_stories\"\n",
"await cognee.add(\"data://\" + data_directory_path, dataset_name)\n"
]
},
{
@@ -61,7 +61,7 @@
"\n",
"print(cognee.datasets.list_datasets())\n",
"\n",
"train_dataset = cognee.datasets.query_data('train_dataset')\n",
"train_dataset = cognee.datasets.query_data(\"short_stories\")\n",
"print(len(train_dataset))"
]
},
@@ -73,8 +73,11 @@
"outputs": [],
"source": [
"from os import path\n",
"import logging\n",
"import cognee\n",
"\n",
"logging.basicConfig(level = logging.INFO)\n",
"\n",
"await cognee.prune.prune_system()\n",
"\n",
"data_directory_path = path.abspath(\"../.data\")\n",
@@ -83,7 +86,7 @@
"cognee_directory_path = path.abspath(\"../.cognee_system\")\n",
"cognee.config.system_root_directory(cognee_directory_path)\n",
"\n",
"await cognee.cognify('train_dataset')"
"await cognee.cognify('short_stories')"
]
},
{
@@ -151,10 +154,10 @@
"graph_client = await get_graph_client(GraphDBType.NETWORKX)\n",
"graph = graph_client.graph\n",
"\n",
"results = await search_similarity(\"Who is Ernie Grunwald?\", graph)\n",
"results = await search_similarity(\"Who are French girls?\", graph)\n",
"\n",
"for result in results:\n",
" print(\"Ernie Grunwald\" in result)\n",
" print(\"French girls\" in result)\n",
" print(result)"
]
}