Merge pull request #13 from topoteretes/dynamic_metadata_loading_base_splitters
Dynamic metadata loading base splitters
This commit is contained in:
commit
fd7d05e705
10 changed files with 374 additions and 229 deletions
85
level_2/chunkers/chunkers.py
Normal file
85
level_2/chunkers/chunkers.py
Normal file
|
|
@ -0,0 +1,85 @@
|
||||||
|
from langchain.document_loaders import PyPDFLoader
|
||||||
|
|
||||||
|
from level_2.shared.chunk_strategy import ChunkStrategy
|
||||||
|
import re
|
||||||
|
def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
    """Dispatch *source_data* to the chunker selected by *chunk_strategy*.

    Parameters:
        chunk_strategy: A ``ChunkStrategy`` member, or its name as a string
            (e.g. ``'VANILLA'``, case-insensitive). Unknown strategies fall
            back to the vanilla chunker, matching the original behavior.
        source_data: Text (or iterable of text pieces) to split.
        chunk_size: Target chunk length, forwarded to the chosen chunker.
        chunk_overlap: Overlap between consecutive chunks, forwarded as-is.

    Returns:
        Whatever the selected chunker returns (list of Documents for
        VANILLA, list of strings for the other strategies).
    """
    # Accept plain strings so callers don't have to import the enum
    # (loaders.py passes 'VANILLA'); unknown names keep the old silent
    # fallback to the vanilla strategy instead of raising.
    if isinstance(chunk_strategy, str):
        try:
            chunk_strategy = ChunkStrategy[chunk_strategy.upper()]
        except KeyError:
            chunk_strategy = ChunkStrategy.VANILLA

    if chunk_strategy == ChunkStrategy.PARAGRAPH:
        return chunk_data_by_paragraph(source_data, chunk_size, chunk_overlap)
    if chunk_strategy == ChunkStrategy.SENTENCE:
        return chunk_by_sentence(source_data, chunk_size, chunk_overlap)
    if chunk_strategy == ChunkStrategy.EXACT:
        return chunk_data_exact(source_data, chunk_size, chunk_overlap)
    # VANILLA and anything unrecognized use the default recursive splitter.
    return vanilla_chunker(source_data, chunk_size, chunk_overlap)
|
||||||
|
|
||||||
|
|
||||||
|
def vanilla_chunker(source_data, chunk_size, chunk_overlap):
    """Split *source_data* using langchain's RecursiveCharacterTextSplitter.

    Bug fix: the original ignored its ``chunk_size``/``chunk_overlap``
    arguments and hard-coded 100/20. Those values are now used only as
    defaults when the caller passes ``None`` (as chunk_data does when no
    sizes are supplied), so existing callers see identical behavior.

    Parameters:
        source_data: The raw text to split.
        chunk_size: Maximum chunk length in characters; ``None`` -> 100.
        chunk_overlap: Overlap between chunks in characters; ``None`` -> 20.

    Returns:
        A list of langchain ``Document`` objects produced by the splitter.
    """
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    text_splitter = RecursiveCharacterTextSplitter(
        # Preserve the historical 100/20 behavior for None arguments.
        chunk_size=chunk_size if chunk_size is not None else 100,
        chunk_overlap=chunk_overlap if chunk_overlap is not None else 20,
        length_function=len,
    )
    return text_splitter.create_documents([source_data])
|
||||||
|
def chunk_data_exact(data_chunks, chunk_size, chunk_overlap):
    """Split joined text into fixed-size windows with a fixed overlap.

    Windows start every ``chunk_size - chunk_overlap`` characters and are
    ``chunk_size`` characters long (the last may be shorter).

    Parameters:
        data_chunks: Iterable of strings; joined before splitting.
        chunk_size: Window length in characters.
        chunk_overlap: Characters shared between consecutive windows.

    Returns:
        List of string chunks (empty for empty input).

    Raises:
        ValueError: If ``chunk_overlap >= chunk_size`` — the original code
            would either raise a confusing range() error (step 0) or
            silently return [] (negative step).
    """
    step = chunk_size - chunk_overlap
    if step <= 0:
        raise ValueError(
            f"chunk_overlap ({chunk_overlap}) must be smaller than chunk_size ({chunk_size})"
        )
    data = "".join(data_chunks)
    return [data[i:i + chunk_size] for i in range(0, len(data), step)]
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_by_sentence(data_chunks, chunk_size, overlap):
    """Split joined text into sentences, re-chunking oversized ones.

    Sentences end at '.', '!', '?' or '…' followed by spaces. Any sentence
    longer than *chunk_size* is broken up via chunk_data_exact with the
    given *overlap*; shorter sentences are kept whole.
    """
    text = "".join(data_chunks)

    # Lookbehind keeps the terminator attached to its sentence; the split
    # consumes the trailing run of spaces.
    sentences = re.split(r'(?<=[.!?…]) +', text)

    result = []
    for sentence in sentences:
        if len(sentence) <= chunk_size:
            result.append(sentence)
        else:
            result.extend(chunk_data_exact([sentence], chunk_size, overlap))
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound=0.75):
    """Greedy paragraph-aware chunking of joined text.

    Walks the text in windows of at most *chunk_size* characters. If a
    paragraph break ('\\n\\n') appears after the first ``bound`` fraction of
    the current window, the window is cut just past the break (delimiter
    included). Each emitted chunk carries *overlap* extra characters beyond
    its cut point; the next window starts at the cut point itself.
    """
    text = "".join(data_chunks)
    n = len(text)
    # Earliest offset (within a window) at which a break is honored.
    lower_cut = int(bound * chunk_size)

    pieces = []
    cursor = 0
    while cursor < n:
        # Tentative window end, clipped to the end of the text.
        window_end = min(cursor + chunk_size, n)

        # Prefer a cut right after a paragraph break near the window's tail.
        break_at = text.find('\n\n', cursor + lower_cut, window_end)
        if break_at != -1:
            window_end = break_at + 2  # keep the '\n\n' in this chunk

        pieces.append(text[cursor:window_end + overlap])
        cursor = window_end
    return pieces
|
||||||
|
|
@ -362,9 +362,9 @@ class EpisodicBuffer(BaseMemory):
|
||||||
"""Determines what operations are available for the user to process PDFs"""
|
"""Determines what operations are available for the user to process PDFs"""
|
||||||
|
|
||||||
return [
|
return [
|
||||||
"translate",
|
"retrieve over time",
|
||||||
"structure",
|
"save to personal notes",
|
||||||
"fetch from vector store"
|
"translate to german"
|
||||||
# "load to semantic memory",
|
# "load to semantic memory",
|
||||||
# "load to episodic memory",
|
# "load to episodic memory",
|
||||||
# "load to buffer",
|
# "load to buffer",
|
||||||
|
|
@ -594,6 +594,8 @@ class EpisodicBuffer(BaseMemory):
|
||||||
episodic_context = await chain.arun(input=prompt_filter_chunk, verbose=True)
|
episodic_context = await chain.arun(input=prompt_filter_chunk, verbose=True)
|
||||||
print(cb)
|
print(cb)
|
||||||
|
|
||||||
|
print("HERE IS THE EPISODIC CONTEXT", episodic_context)
|
||||||
|
|
||||||
class BufferModulators(BaseModel):
|
class BufferModulators(BaseModel):
|
||||||
attention_modulators: Dict[str, float] = Field(... , description="Attention modulators")
|
attention_modulators: Dict[str, float] = Field(... , description="Attention modulators")
|
||||||
|
|
||||||
|
|
@ -728,55 +730,6 @@ class EpisodicBuffer(BaseMemory):
|
||||||
complete_agent_prompt= f" Document context is: {document_from_vectorstore} \n Task is : {task['task_order']} {task['task_name']} {task['operation']} "
|
complete_agent_prompt= f" Document context is: {document_from_vectorstore} \n Task is : {task['task_order']} {task['task_name']} {task['operation']} "
|
||||||
|
|
||||||
# task['vector_store_context_results']=document_context_result_parsed.dict()
|
# task['vector_store_context_results']=document_context_result_parsed.dict()
|
||||||
class PromptWrapper(BaseModel):
|
|
||||||
observation: str = Field(
|
|
||||||
description="observation we want to fetch from vectordb"
|
|
||||||
)
|
|
||||||
|
|
||||||
@tool(
|
|
||||||
"convert_to_structured", args_schema=PromptWrapper, return_direct=True
|
|
||||||
)
|
|
||||||
def convert_to_structured(observation=None, json_schema=None):
|
|
||||||
"""Convert unstructured data to structured data"""
|
|
||||||
BASE_DIR = os.getcwd()
|
|
||||||
json_path = os.path.join(
|
|
||||||
BASE_DIR, "schema_registry", "ticket_schema.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_json_or_infer_schema(file_path, document_path):
|
|
||||||
"""Load JSON schema from file or infer schema from text"""
|
|
||||||
|
|
||||||
# Attempt to load the JSON file
|
|
||||||
with open(file_path, "r") as file:
|
|
||||||
json_schema = json.load(file)
|
|
||||||
return json_schema
|
|
||||||
|
|
||||||
json_schema = load_json_or_infer_schema(json_path, None)
|
|
||||||
|
|
||||||
def run_open_ai_mapper(observation=None, json_schema=None):
|
|
||||||
"""Convert unstructured data to structured data"""
|
|
||||||
|
|
||||||
prompt_msgs = [
|
|
||||||
SystemMessage(
|
|
||||||
content="You are a world class algorithm converting unstructured data into structured data."
|
|
||||||
),
|
|
||||||
HumanMessage(
|
|
||||||
content="Convert unstructured data to structured data:"
|
|
||||||
),
|
|
||||||
HumanMessagePromptTemplate.from_template("{input}"),
|
|
||||||
HumanMessage(
|
|
||||||
content="Tips: Make sure to answer in the correct format"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
prompt_ = ChatPromptTemplate(messages=prompt_msgs)
|
|
||||||
chain_funct = create_structured_output_chain(
|
|
||||||
json_schema, prompt=prompt_, llm=self.llm, verbose=True
|
|
||||||
)
|
|
||||||
output = chain_funct.run(input=observation, llm=self.llm)
|
|
||||||
return output
|
|
||||||
|
|
||||||
result = run_open_ai_mapper(observation, json_schema)
|
|
||||||
return result
|
|
||||||
|
|
||||||
class FetchText(BaseModel):
|
class FetchText(BaseModel):
|
||||||
observation: str = Field(description="observation we want to translate")
|
observation: str = Field(description="observation we want to translate")
|
||||||
|
|
@ -802,7 +755,7 @@ class EpisodicBuffer(BaseMemory):
|
||||||
|
|
||||||
agent = initialize_agent(
|
agent = initialize_agent(
|
||||||
llm=self.llm,
|
llm=self.llm,
|
||||||
tools=[fetch_from_vector_store,translate_to_de, convert_to_structured],
|
tools=[fetch_from_vector_store,translate_to_de],
|
||||||
agent=AgentType.OPENAI_FUNCTIONS,
|
agent=AgentType.OPENAI_FUNCTIONS,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
|
|
@ -1084,18 +1037,21 @@ async def main():
|
||||||
"source": "url",
|
"source": "url",
|
||||||
"path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
|
"path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
|
||||||
}
|
}
|
||||||
# load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params)
|
load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params)
|
||||||
# print(load_jack_london)
|
print(load_jack_london)
|
||||||
|
|
||||||
modulator = {"relevance": 0.1, "frequency": 0.1}
|
modulator = {"relevance": 0.1, "frequency": 0.1}
|
||||||
|
|
||||||
|
# fdsf = await memory._fetch_semantic_memory(observation="bla", params=None)
|
||||||
|
# print(fdsf)
|
||||||
# await memory._delete_episodic_memory()
|
# await memory._delete_episodic_memory()
|
||||||
#
|
#
|
||||||
run_main_buffer = await memory._create_buffer_context(
|
# run_main_buffer = await memory._create_buffer_context(
|
||||||
user_input="I want to know how does Buck adapt to life in the wild and then have that info translated to german ",
|
# user_input="I want to know how does Buck adapt to life in the wild and then have that info translated to german ",
|
||||||
params=params,
|
# params=params,
|
||||||
attention_modulators=modulator,
|
# attention_modulators=modulator,
|
||||||
)
|
# )
|
||||||
print(run_main_buffer)
|
# print(run_main_buffer)
|
||||||
# #
|
# #
|
||||||
# run_main_buffer = await memory._run_main_buffer(
|
# run_main_buffer = await memory._run_main_buffer(
|
||||||
# user_input="I want to know how does Buck adapt to life in the wild and then have that info translated to german ",
|
# user_input="I want to know how does Buck adapt to life in the wild and then have that info translated to german ",
|
||||||
|
|
|
||||||
39
level_2/loaders/loaders.py
Normal file
39
level_2/loaders/loaders.py
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
import os
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import fitz
|
||||||
|
from level_2.chunkers.chunkers import chunk_data
|
||||||
|
from langchain.document_loaders import PyPDFLoader
|
||||||
|
|
||||||
|
import requests
|
||||||
|
def _document_loader( observation: str, loader_settings: dict):
|
||||||
|
# Check the format of the document
|
||||||
|
document_format = loader_settings.get("format", "text")
|
||||||
|
|
||||||
|
if document_format == "PDF":
|
||||||
|
if loader_settings.get("source") == "url":
|
||||||
|
pdf_response = requests.get(loader_settings["path"])
|
||||||
|
pdf_stream = BytesIO(pdf_response.content)
|
||||||
|
with fitz.open(stream=pdf_stream, filetype='pdf') as doc:
|
||||||
|
file_content = ""
|
||||||
|
for page in doc:
|
||||||
|
file_content += page.get_text()
|
||||||
|
pages = chunk_data(chunk_strategy= 'VANILLA', source_data=file_content)
|
||||||
|
|
||||||
|
return pages
|
||||||
|
elif loader_settings.get("source") == "file":
|
||||||
|
# Process the PDF using PyPDFLoader
|
||||||
|
# might need adapting for different loaders + OCR
|
||||||
|
# need to test the path
|
||||||
|
loader = PyPDFLoader(loader_settings["path"])
|
||||||
|
pages = loader.load_and_split()
|
||||||
|
return pages
|
||||||
|
|
||||||
|
elif document_format == "text":
|
||||||
|
# Process the text directly
|
||||||
|
return observation
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported document format: {document_format}")
|
||||||
|
|
||||||
|
|
||||||
54
level_2/poetry.lock
generated
54
level_2/poetry.lock
generated
|
|
@ -2491,6 +2491,58 @@ files = [
|
||||||
[package.extras]
|
[package.extras]
|
||||||
plugins = ["importlib-metadata"]
|
plugins = ["importlib-metadata"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pymupdf"
|
||||||
|
version = "1.23.3"
|
||||||
|
description = "A high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "PyMuPDF-1.23.3-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:52699939b7482c8c566a181e2a980a6801c91959ee96dae5663070fd2b960c6b"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:95408d57ed77f3c396880a3fc0feae068c4bf577e7e2c761d24a345138062f8d"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:5eefd674e338ddd82cd9179ad7d4c2160796efd6c0d4cd1098b5314ff78688d7"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:c7696034f5f5472d1e6d3f3556858cf85e095b66c158a80b527facfa83542aee"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp310-none-win32.whl", hash = "sha256:f3c6d427381f4ef76bec4e862c8969845e90bc842b3c534800be9cb6fe6b0e3b"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp310-none-win_amd64.whl", hash = "sha256:0fd19017d4c7791146e38621d878393136e25a2a4fadd0372a98ab2a9aabc0c5"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:0e88408dea51492431b111a721d88a4f4c2176786734b16374d77a421f410139"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:c4dbf5e851373f4633b57187b0ae3dcde0efad6ef5969c4de14bb9a52a796261"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:7218c1099205edb3357cb5713661d11d7c04aaa910645da64e17c2d050d61352"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:0304d5def03d2bedf951179624ea636470b5ee0a706ea37636f7a3b2b08561a5"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp311-none-win32.whl", hash = "sha256:35fe66d80cdc948ed55ac70c94b2e7f740fc08309c4ce125228ce0042a2fbba8"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp311-none-win_amd64.whl", hash = "sha256:e643e4f30d1a5e358a8f65eab66dd0ea33f8170d61eb7549f0d227086c82d315"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:95065c21c39dc93c4e224a2ac3c903bf31d635cdb569338d79e9befbac9755eb"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:0c06610d78a86fcbfbcea77320c54f561ac4d568666d621afcf1109e8cfc829b"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:6e4ef7e65b3fb7f9248f1f2dc530f10d0e00a8080dd5da52808e6638a9868a10"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:d51b848d45e09e7fedfdeb0880a2a14872e25dd4e0932b9abf6a36a69bf01f6a"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp38-none-win32.whl", hash = "sha256:42b879913a07fb251251af20e46747abc3d5d0276a48d2c28e128f5f88ef3dcd"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp38-none-win_amd64.whl", hash = "sha256:a283236e09c056798ecaf6e0872790c63d91edf6d5f72b76504715d6b88da976"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6329a223ae38641fe4ff081beffd33f5e3be800c0409569b64a33b70f1b544cf"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:640a5ada4479a2c69b811c91f163a7b55f7fe1c323b861373d6068893cc9e9e0"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:2f555d264f08e091eaf9fd27c33ba9bfdc39ac8d09aa12195ab529bcca79229d"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:96dc89254d78bddac8434be7b9f4c354fe57b224b5420614cde9c2f1d2f1355e"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp39-none-win32.whl", hash = "sha256:f9a1d2f7484bde2ec81f3c88641f7a8b7f52450b807408ae7a340ddecb424659"},
|
||||||
|
{file = "PyMuPDF-1.23.3-cp39-none-win_amd64.whl", hash = "sha256:7cfceb91048665965d826023c4acfc45f61f5cfcf101391b3c1d22f85cef0470"},
|
||||||
|
{file = "PyMuPDF-1.23.3.tar.gz", hash = "sha256:021478ae6c76e8859241dbb970612c9080a8957d8bd697bba0b4531dc1cf4f87"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
PyMuPDFb = "1.23.3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pymupdfb"
|
||||||
|
version = "1.23.3"
|
||||||
|
description = "MuPDF shared libraries for PyMuPDF."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "PyMuPDFb-1.23.3-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:5b05c643210eae8050d552188efab2cd68595ad75b5879a550e11af88e8bff05"},
|
||||||
|
{file = "PyMuPDFb-1.23.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2a2b81ac348ec123bfd72336a590399f8b0035a3052c1cf5cc2401ca7a4905e9"},
|
||||||
|
{file = "PyMuPDFb-1.23.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:924f3f2229d232c965705d120b3ff38bbc37459af9d0e798b582950f875bee92"},
|
||||||
|
{file = "PyMuPDFb-1.23.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6c287b9ce5ed397043c6e13df19640c94a348e9edc8012d9a7b001c69ba30ca9"},
|
||||||
|
{file = "PyMuPDFb-1.23.3-py3-none-win32.whl", hash = "sha256:8703e3a8efebd83814e124d0fc3a082de2d2def329b63fca1065001e6a2deb49"},
|
||||||
|
{file = "PyMuPDFb-1.23.3-py3-none-win_amd64.whl", hash = "sha256:89d88069cb8deb100ddcf56e1feefc7cff93ff791260325ed84551f96d3abd9f"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pypdf"
|
name = "pypdf"
|
||||||
version = "3.15.4"
|
version = "3.15.4"
|
||||||
|
|
@ -4466,4 +4518,4 @@ multidict = ">=4.0"
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "761b58204631452d77e13bbc2d61034704e8e109619db4addd26ec159b9bb176"
|
content-hash = "bc306ab25967437b68ef5216af4b68bf6bfdf5cb966bb6493cc3ad91e8888110"
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,7 @@ python-multipart = "^0.0.6"
|
||||||
deep-translator = "^1.11.4"
|
deep-translator = "^1.11.4"
|
||||||
humanize = "^4.8.0"
|
humanize = "^4.8.0"
|
||||||
deepeval = "^0.10.12"
|
deepeval = "^0.10.12"
|
||||||
|
pymupdf = "^1.23.3"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
51
level_2/schema/semantic/semantic_schema.py
Normal file
51
level_2/schema/semantic/semantic_schema.py
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
from marshmallow import Schema, fields
|
||||||
|
|
||||||
|
class DocumentMetadataSchemaV1(Schema):
    """Version "1.0" metadata envelope for a stored document.

    The required fields tie a document to its owning user and memory
    hierarchy; every other field is optional provenance/compliance
    metadata that defaults to an empty string when absent from the input.
    """
    # Required linkage to the owning user and memory hierarchy.
    user_id = fields.Str(required=True)
    memory_id = fields.Str(required=True)
    ltm_memory_id = fields.Str(required=True)
    st_memory_id = fields.Str(required=True)
    buffer_id = fields.Str(required=True)
    # Optional provenance / compliance metadata; missing="" means each
    # field deserializes to an empty string when not provided.
    version = fields.Str(missing="")
    agreement_id = fields.Str(missing="")
    privacy_policy = fields.Str(missing="")
    terms_of_service = fields.Str(missing="")
    format = fields.Str(missing="")
    schema_version = fields.Str(missing="")
    checksum = fields.Str(missing="")
    owner = fields.Str(missing="")
    license = fields.Str(missing="")
    validity_start = fields.Str(missing="")
    validity_end = fields.Str(missing="")
|
||||||
|
|
||||||
|
class DocumentMetadataSchemaV2(DocumentMetadataSchemaV1):
    """Version "2.0" metadata envelope: everything in V1 plus ``random``.

    The original duplicated all sixteen V1 field declarations verbatim;
    subclassing inherits them (marshmallow collects declared fields from
    base Schema classes), so the two versions can no longer drift apart.
    """
    # The only field new in schema version 2.0; optional like the other
    # non-identifier fields.
    random = fields.Str(missing="")
|
||||||
|
|
||||||
|
class DocumentSchema(Schema):
    """A complete stored document: metadata envelope plus page text."""
    # NOTE(review): pinned to the V1 metadata schema; version-aware
    # loading selects a schema class via SCHEMA_VERSIONS instead.
    metadata = fields.Nested(DocumentMetadataSchemaV1, required=True)
    page_content = fields.Str(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Registry mapping schema_version strings to metadata schema classes.
SCHEMA_VERSIONS = {
    "1.0": DocumentMetadataSchemaV1,
    "2.0": DocumentMetadataSchemaV2,
}


def get_schema_version(version):
    """Return the metadata schema class registered for *version*.

    Unknown versions fall back to DocumentMetadataSchemaV1.
    """
    try:
        return SCHEMA_VERSIONS[version]
    except KeyError:
        return DocumentMetadataSchemaV1
|
||||||
7
level_2/shared/chunk_strategy.py
Normal file
7
level_2/shared/chunk_strategy.py
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class ChunkStrategy(Enum):
    """Chunking strategies dispatched by chunkers.chunk_data.

    Values are the lowercase strategy names.
    """
    # Fixed-size character windows with overlap (chunk_data_exact).
    EXACT = 'exact'
    # Paragraph-boundary-aware windows (chunk_data_by_paragraph).
    PARAGRAPH = 'paragraph'
    # Sentence-based splitting (chunk_by_sentence).
    SENTENCE = 'sentence'
    # Default: langchain recursive character splitter (vanilla_chunker).
    VANILLA = 'vanilla'
|
||||||
File diff suppressed because one or more lines are too long
|
|
@ -103,11 +103,13 @@ class BaseMemory:
|
||||||
loader_settings: dict = None,
|
loader_settings: dict = None,
|
||||||
params: Optional[dict] = None,
|
params: Optional[dict] = None,
|
||||||
namespace: Optional[str] = None,
|
namespace: Optional[str] = None,
|
||||||
|
custom_fields: Optional[str] = None,
|
||||||
|
|
||||||
):
|
):
|
||||||
|
|
||||||
return await self.vector_db.add_memories(
|
return await self.vector_db.add_memories(
|
||||||
observation=observation, loader_settings=loader_settings,
|
observation=observation, loader_settings=loader_settings,
|
||||||
params=params, namespace=namespace
|
params=params, namespace=namespace, custom_fields=custom_fields
|
||||||
)
|
)
|
||||||
# Add other db_type conditions if necessary
|
# Add other db_type conditions if necessary
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,29 +3,28 @@
|
||||||
import logging
|
import logging
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
from marshmallow import Schema, fields
|
||||||
|
from level_2.loaders.loaders import _document_loader
|
||||||
|
# Add the parent directory to sys.path
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
import marvin
|
import marvin
|
||||||
import requests
|
import requests
|
||||||
from dotenv import load_dotenv
|
|
||||||
from langchain.document_loaders import PyPDFLoader
|
from langchain.document_loaders import PyPDFLoader
|
||||||
from langchain.retrievers import WeaviateHybridSearchRetriever
|
from langchain.retrievers import WeaviateHybridSearchRetriever
|
||||||
from weaviate.gql.get import HybridFusion
|
from weaviate.gql.get import HybridFusion
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import tracemalloc
|
import tracemalloc
|
||||||
|
|
||||||
tracemalloc.start()
|
tracemalloc.start()
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from level_2.schema.semantic.semantic_schema import DocumentSchema, SCHEMA_VERSIONS, DocumentMetadataSchemaV1
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
import uuid
|
|
||||||
import weaviate
|
import weaviate
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
@ -103,122 +102,114 @@ class WeaviateVectorDB(VectorDB):
|
||||||
)
|
)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
def _document_loader(self, observation: str, loader_settings: dict):
|
# def _document_loader(self, observation: str, loader_settings: dict):
|
||||||
# Check the format of the document
|
# # Check the format of the document
|
||||||
document_format = loader_settings.get("format", "text")
|
# document_format = loader_settings.get("format", "text")
|
||||||
|
#
|
||||||
|
# if document_format == "PDF":
|
||||||
|
# if loader_settings.get("source") == "url":
|
||||||
|
# pdf_response = requests.get(loader_settings["path"])
|
||||||
|
# pdf_stream = BytesIO(pdf_response.content)
|
||||||
|
# contents = pdf_stream.read()
|
||||||
|
# tmp_location = os.path.join("/tmp", "tmp.pdf")
|
||||||
|
# with open(tmp_location, "wb") as tmp_file:
|
||||||
|
# tmp_file.write(contents)
|
||||||
|
#
|
||||||
|
# # Process the PDF using PyPDFLoader
|
||||||
|
# loader = PyPDFLoader(tmp_location)
|
||||||
|
# # adapt this for different chunking strategies
|
||||||
|
# pages = loader.load_and_split()
|
||||||
|
# return pages
|
||||||
|
# elif loader_settings.get("source") == "file":
|
||||||
|
# # Process the PDF using PyPDFLoader
|
||||||
|
# # might need adapting for different loaders + OCR
|
||||||
|
# # need to test the path
|
||||||
|
# loader = PyPDFLoader(loader_settings["path"])
|
||||||
|
# pages = loader.load_and_split()
|
||||||
|
# return pages
|
||||||
|
#
|
||||||
|
# elif document_format == "text":
|
||||||
|
# # Process the text directly
|
||||||
|
# return observation
|
||||||
|
#
|
||||||
|
# else:
|
||||||
|
# raise ValueError(f"Unsupported document format: {document_format}")
|
||||||
|
def _stuct(self, observation, params, custom_fields=None):
|
||||||
|
"""Utility function to create the document structure with optional custom fields."""
|
||||||
|
# Dynamically construct metadata
|
||||||
|
metadata = {
|
||||||
|
key: str(getattr(self, key, params.get(key, "")))
|
||||||
|
for key in [
|
||||||
|
"user_id", "memory_id", "ltm_memory_id",
|
||||||
|
"st_memory_id", "buffer_id", "version",
|
||||||
|
"agreement_id", "privacy_policy", "terms_of_service",
|
||||||
|
"format", "schema_version", "checksum",
|
||||||
|
"owner", "license", "validity_start", "validity_end"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
# Merge with custom fields if provided
|
||||||
|
if custom_fields:
|
||||||
|
metadata.update(custom_fields)
|
||||||
|
|
||||||
if document_format == "PDF":
|
# Construct document data
|
||||||
if loader_settings.get("source") == "url":
|
document_data = {
|
||||||
pdf_response = requests.get(loader_settings["path"])
|
"metadata": metadata,
|
||||||
pdf_stream = BytesIO(pdf_response.content)
|
"page_content": observation
|
||||||
contents = pdf_stream.read()
|
}
|
||||||
tmp_location = os.path.join("/tmp", "tmp.pdf")
|
|
||||||
with open(tmp_location, "wb") as tmp_file:
|
|
||||||
tmp_file.write(contents)
|
|
||||||
|
|
||||||
# Process the PDF using PyPDFLoader
|
def get_document_schema_based_on_version(version):
|
||||||
loader = PyPDFLoader(tmp_location)
|
metadata_schema_class = SCHEMA_VERSIONS.get(version, DocumentMetadataSchemaV1)
|
||||||
# adapt this for different chunking strategies
|
class DynamicDocumentSchema(Schema):
|
||||||
pages = loader.load_and_split()
|
metadata = fields.Nested(metadata_schema_class, required=True)
|
||||||
return pages
|
page_content = fields.Str(required=True)
|
||||||
elif loader_settings.get("source") == "file":
|
|
||||||
# Process the PDF using PyPDFLoader
|
|
||||||
# might need adapting for different loaders + OCR
|
|
||||||
# need to test the path
|
|
||||||
loader = PyPDFLoader(loader_settings["path"])
|
|
||||||
pages = loader.load_and_split()
|
|
||||||
return pages
|
|
||||||
|
|
||||||
elif document_format == "text":
|
return DynamicDocumentSchema
|
||||||
# Process the text directly
|
|
||||||
return observation
|
|
||||||
|
|
||||||
else:
|
# Validate and deserialize
|
||||||
raise ValueError(f"Unsupported document format: {document_format}")
|
schema_version = params.get("schema_version", "1.0") # Default to "1.0" if not provided
|
||||||
|
CurrentDocumentSchema = get_document_schema_based_on_version(schema_version)
|
||||||
|
loaded_document = CurrentDocumentSchema().load(document_data)
|
||||||
|
return [loaded_document]
|
||||||
|
|
||||||
async def add_memories(
|
async def add_memories(self, observation, loader_settings=None, params=None, namespace=None, custom_fields=None):
|
||||||
self, observation: str, loader_settings: dict = None, params: dict = None ,namespace:str=None
|
|
||||||
):
|
|
||||||
# Update Weaviate memories here
|
# Update Weaviate memories here
|
||||||
print(self.namespace)
|
|
||||||
if namespace is None:
|
if namespace is None:
|
||||||
namespace = self.namespace
|
namespace = self.namespace
|
||||||
retriever = self.init_weaviate(namespace)
|
retriever = self.init_weaviate(namespace) # Assuming `init_weaviate` is a method of the class
|
||||||
|
|
||||||
def _stuct(observation, params):
|
|
||||||
"""Utility function to not repeat metadata structure"""
|
|
||||||
# needs smarter solution, like dynamic generation of metadata
|
|
||||||
return [
|
|
||||||
Document(
|
|
||||||
metadata={
|
|
||||||
# "text": observation,
|
|
||||||
"user_id": str(self.user_id),
|
|
||||||
"memory_id": str(self.memory_id),
|
|
||||||
"ltm_memory_id": str(self.ltm_memory_id),
|
|
||||||
"st_memory_id": str(self.st_memory_id),
|
|
||||||
"buffer_id": str(self.buffer_id),
|
|
||||||
"version": params.get("version", None) or "",
|
|
||||||
"agreement_id": params.get("agreement_id", None) or "",
|
|
||||||
"privacy_policy": params.get("privacy_policy", None) or "",
|
|
||||||
"terms_of_service": params.get("terms_of_service", None) or "",
|
|
||||||
"format": params.get("format", None) or "",
|
|
||||||
"schema_version": params.get("schema_version", None) or "",
|
|
||||||
"checksum": params.get("checksum", None) or "",
|
|
||||||
"owner": params.get("owner", None) or "",
|
|
||||||
"license": params.get("license", None) or "",
|
|
||||||
"validity_start": params.get("validity_start", None) or "",
|
|
||||||
"validity_end": params.get("validity_end", None) or ""
|
|
||||||
# **source_metadata,
|
|
||||||
},
|
|
||||||
page_content=observation,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
if loader_settings:
|
if loader_settings:
|
||||||
# Load the document
|
# Assuming _document_loader returns a list of documents
|
||||||
document = self._document_loader(observation, loader_settings)
|
documents = _document_loader(observation, loader_settings)
|
||||||
print("DOC LENGTH", len(document))
|
for doc in documents:
|
||||||
for doc in document:
|
document_to_load = self._stuct(doc.page_content, params, custom_fields)
|
||||||
document_to_load = _stuct(doc.page_content, params)
|
print("here is the doc to load1", document_to_load)
|
||||||
retriever.add_documents(
|
retriever.add_documents([
|
||||||
document_to_load
|
Document(metadata=document_to_load[0]['metadata'], page_content=document_to_load[0]['page_content'])])
|
||||||
)
|
else:
|
||||||
|
document_to_load = self._stuct(observation, params, custom_fields)
|
||||||
return retriever.add_documents(
|
retriever.add_documents([
|
||||||
_stuct(observation, params)
|
Document(metadata=document_to_load[0]['metadata'], page_content=document_to_load[0]['page_content'])])
|
||||||
)
|
|
||||||
|
|
||||||
async def fetch_memories(
|
async def fetch_memories(
|
||||||
self, observation: str, namespace: str, params: dict = None, n_of_observations =int(2)
|
self, observation: str, namespace: str, params: dict = None, n_of_observations: int = 2
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get documents from weaviate.
|
Fetch documents from weaviate.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- observation (str): User query.
|
- observation (str): User query.
|
||||||
- namespace (str): Type of memory we access.
|
- namespace (str): Type of memory accessed.
|
||||||
- params (dict, optional):
|
- params (dict, optional): Filtering parameters.
|
||||||
- n_of_observations (int, optional): For weaviate, equals to autocut, defaults to 1. Ranges from 1 to 3. Check weaviate docs for more info.
|
- n_of_observations (int, optional): For weaviate, equals to autocut. Defaults to 2. Ranges from 1 to 3.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Describe the return type and what the function returns.
|
List of documents matching the query.
|
||||||
|
|
||||||
Args a json containing:
|
|
||||||
query (str): The query string.
|
|
||||||
path (list): The path for filtering, e.g., ['year'].
|
|
||||||
operator (str): The operator for filtering, e.g., 'Equal'.
|
|
||||||
valueText (str): The value for filtering, e.g., '2017*'.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
get_from_weaviate(query="some query", path=['year'], operator='Equal', valueText='2017*')
|
fetch_memories(query="some query", path=['year'], operator='Equal', valueText='2017*')
|
||||||
|
|
||||||
"""
|
"""
|
||||||
client = self.init_weaviate_client(self.namespace)
|
client = self.init_weaviate_client(self.namespace)
|
||||||
|
|
||||||
print(self.namespace)
|
if not namespace:
|
||||||
print(str(datetime.now()))
|
|
||||||
print(observation)
|
|
||||||
if namespace is None:
|
|
||||||
namespace = self.namespace
|
namespace = self.namespace
|
||||||
|
|
||||||
params_user_id = {
|
params_user_id = {
|
||||||
|
|
@ -227,78 +218,39 @@ class WeaviateVectorDB(VectorDB):
|
||||||
"valueText": self.user_id,
|
"valueText": self.user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def list_objects_of_class(class_name, schema):
|
||||||
|
return [
|
||||||
|
prop["name"]
|
||||||
|
for class_obj in schema["classes"]
|
||||||
|
if class_obj["class"] == class_name
|
||||||
|
for prop in class_obj["properties"]
|
||||||
|
]
|
||||||
|
|
||||||
|
base_query = client.query.get(
|
||||||
|
namespace, list(list_objects_of_class(namespace, client.schema.get()))
|
||||||
|
).with_additional(
|
||||||
|
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
|
||||||
|
).with_where(params_user_id).with_limit(10)
|
||||||
|
|
||||||
if params:
|
if params:
|
||||||
query_output = (
|
query_output = (
|
||||||
client.query.get(
|
base_query
|
||||||
namespace,
|
|
||||||
[
|
|
||||||
# "text",
|
|
||||||
"user_id",
|
|
||||||
"memory_id",
|
|
||||||
"ltm_memory_id",
|
|
||||||
"st_memory_id",
|
|
||||||
"buffer_id",
|
|
||||||
"version",
|
|
||||||
"agreement_id",
|
|
||||||
"privacy_policy",
|
|
||||||
"terms_of_service",
|
|
||||||
"format",
|
|
||||||
"schema_version",
|
|
||||||
"checksum",
|
|
||||||
"owner",
|
|
||||||
"license",
|
|
||||||
"validity_start",
|
|
||||||
"validity_end",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
.with_where(params)
|
.with_where(params)
|
||||||
.with_near_text({"concepts": [observation]})
|
.with_near_text({"concepts": [observation]})
|
||||||
.with_additional(
|
|
||||||
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score",'distance']
|
|
||||||
)
|
|
||||||
.with_where(params_user_id)
|
|
||||||
.with_limit(10)
|
|
||||||
.do()
|
.do()
|
||||||
)
|
)
|
||||||
return query_output
|
|
||||||
else:
|
else:
|
||||||
query_output = (
|
query_output = (
|
||||||
client.query.get(
|
base_query
|
||||||
namespace,
|
|
||||||
|
|
||||||
[
|
|
||||||
"text",
|
|
||||||
"user_id",
|
|
||||||
"memory_id",
|
|
||||||
"ltm_memory_id",
|
|
||||||
"st_memory_id",
|
|
||||||
"buffer_id",
|
|
||||||
"version",
|
|
||||||
"agreement_id",
|
|
||||||
"privacy_policy",
|
|
||||||
"terms_of_service",
|
|
||||||
"format",
|
|
||||||
"schema_version",
|
|
||||||
"checksum",
|
|
||||||
"owner",
|
|
||||||
"license",
|
|
||||||
"validity_start",
|
|
||||||
"validity_end",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
.with_additional(
|
|
||||||
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
|
|
||||||
)
|
|
||||||
.with_hybrid(
|
.with_hybrid(
|
||||||
query=observation,
|
query=observation,
|
||||||
fusion_type=HybridFusion.RELATIVE_SCORE
|
fusion_type=HybridFusion.RELATIVE_SCORE
|
||||||
)
|
)
|
||||||
.with_autocut(n_of_observations)
|
.with_autocut(n_of_observations)
|
||||||
.with_where(params_user_id)
|
|
||||||
.with_limit(10)
|
|
||||||
.do()
|
.do()
|
||||||
)
|
)
|
||||||
return query_output
|
|
||||||
|
return query_output
|
||||||
|
|
||||||
async def delete_memories(self, params: dict = None):
|
async def delete_memories(self, params: dict = None):
|
||||||
client = self.init_weaviate_client(self.namespace)
|
client = self.init_weaviate_client(self.namespace)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue