Merge pull request #13 from topoteretes/dynamic_metadata_loading_base_splitters
Dynamic metadata loading base splitters
This commit is contained in:
commit
fd7d05e705
10 changed files with 374 additions and 229 deletions
85
level_2/chunkers/chunkers.py
Normal file
85
level_2/chunkers/chunkers.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
from langchain.document_loaders import PyPDFLoader
|
||||
|
||||
from level_2.shared.chunk_strategy import ChunkStrategy
|
||||
import re
|
||||
def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):
|
||||
|
||||
if chunk_strategy == ChunkStrategy.VANILLA:
|
||||
chunked_data = vanilla_chunker(source_data, chunk_size, chunk_overlap)
|
||||
|
||||
elif chunk_strategy == ChunkStrategy.PARAGRAPH:
|
||||
chunked_data = chunk_data_by_paragraph(source_data,chunk_size, chunk_overlap)
|
||||
|
||||
elif chunk_strategy == ChunkStrategy.SENTENCE:
|
||||
chunked_data = chunk_by_sentence(source_data, chunk_size, chunk_overlap)
|
||||
elif chunk_strategy == ChunkStrategy.EXACT:
|
||||
chunked_data = chunk_data_exact(source_data, chunk_size, chunk_overlap)
|
||||
else:
|
||||
chunked_data = vanilla_chunker(source_data, chunk_size, chunk_overlap)
|
||||
return chunked_data
|
||||
|
||||
|
||||
def vanilla_chunker(source_data, chunk_size, chunk_overlap):
|
||||
# loader = PyPDFLoader(source_data)
|
||||
# adapt this for different chunking strategies
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
# Set a really small chunk size, just to show.
|
||||
chunk_size=100,
|
||||
chunk_overlap=20,
|
||||
length_function=len
|
||||
)
|
||||
pages = text_splitter.create_documents([source_data])
|
||||
# pages = source_data.load_and_split()
|
||||
return pages
|
||||
def chunk_data_exact(data_chunks, chunk_size, chunk_overlap):
|
||||
data = "".join(data_chunks)
|
||||
chunks = []
|
||||
for i in range(0, len(data), chunk_size - chunk_overlap):
|
||||
chunks.append(data[i:i + chunk_size])
|
||||
return chunks
|
||||
|
||||
|
||||
def chunk_by_sentence(data_chunks, chunk_size, overlap):
|
||||
# Split by periods, question marks, exclamation marks, and ellipses
|
||||
data = "".join(data_chunks)
|
||||
|
||||
# The regular expression is used to find series of charaters that end with one the following chaacters (. ! ? ...)
|
||||
sentence_endings = r'(?<=[.!?…]) +'
|
||||
sentences = re.split(sentence_endings, data)
|
||||
|
||||
sentence_chunks = []
|
||||
for sentence in sentences:
|
||||
if len(sentence) > chunk_size:
|
||||
chunks = chunk_data_exact([sentence], chunk_size, overlap)
|
||||
sentence_chunks.extend(chunks)
|
||||
else:
|
||||
sentence_chunks.append(sentence)
|
||||
return sentence_chunks
|
||||
|
||||
|
||||
def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound=0.75):
|
||||
data = "".join(data_chunks)
|
||||
total_length = len(data)
|
||||
chunks = []
|
||||
check_bound = int(bound * chunk_size)
|
||||
start_idx = 0
|
||||
|
||||
while start_idx < total_length:
|
||||
# Set the end index to the minimum of start_idx + default_chunk_size or total_length
|
||||
end_idx = min(start_idx + chunk_size, total_length)
|
||||
|
||||
# Find the next paragraph index within the current chunk and bound
|
||||
next_paragraph_index = data.find('\n\n', start_idx + check_bound, end_idx)
|
||||
|
||||
# If a next paragraph index is found within the current chunk
|
||||
if next_paragraph_index != -1:
|
||||
# Update end_idx to include the paragraph delimiter
|
||||
end_idx = next_paragraph_index + 2
|
||||
|
||||
chunks.append(data[start_idx:end_idx + overlap])
|
||||
|
||||
# Update start_idx to be the current end_idx
|
||||
start_idx = end_idx
|
||||
|
||||
return chunks
|
||||
|
|
@ -362,9 +362,9 @@ class EpisodicBuffer(BaseMemory):
|
|||
"""Determines what operations are available for the user to process PDFs"""
|
||||
|
||||
return [
|
||||
"translate",
|
||||
"structure",
|
||||
"fetch from vector store"
|
||||
"retrieve over time",
|
||||
"save to personal notes",
|
||||
"translate to german"
|
||||
# "load to semantic memory",
|
||||
# "load to episodic memory",
|
||||
# "load to buffer",
|
||||
|
|
@ -594,6 +594,8 @@ class EpisodicBuffer(BaseMemory):
|
|||
episodic_context = await chain.arun(input=prompt_filter_chunk, verbose=True)
|
||||
print(cb)
|
||||
|
||||
print("HERE IS THE EPISODIC CONTEXT", episodic_context)
|
||||
|
||||
class BufferModulators(BaseModel):
|
||||
attention_modulators: Dict[str, float] = Field(... , description="Attention modulators")
|
||||
|
||||
|
|
@ -728,55 +730,6 @@ class EpisodicBuffer(BaseMemory):
|
|||
complete_agent_prompt= f" Document context is: {document_from_vectorstore} \n Task is : {task['task_order']} {task['task_name']} {task['operation']} "
|
||||
|
||||
# task['vector_store_context_results']=document_context_result_parsed.dict()
|
||||
class PromptWrapper(BaseModel):
|
||||
observation: str = Field(
|
||||
description="observation we want to fetch from vectordb"
|
||||
)
|
||||
|
||||
@tool(
|
||||
"convert_to_structured", args_schema=PromptWrapper, return_direct=True
|
||||
)
|
||||
def convert_to_structured(observation=None, json_schema=None):
|
||||
"""Convert unstructured data to structured data"""
|
||||
BASE_DIR = os.getcwd()
|
||||
json_path = os.path.join(
|
||||
BASE_DIR, "schema_registry", "ticket_schema.json"
|
||||
)
|
||||
|
||||
def load_json_or_infer_schema(file_path, document_path):
|
||||
"""Load JSON schema from file or infer schema from text"""
|
||||
|
||||
# Attempt to load the JSON file
|
||||
with open(file_path, "r") as file:
|
||||
json_schema = json.load(file)
|
||||
return json_schema
|
||||
|
||||
json_schema = load_json_or_infer_schema(json_path, None)
|
||||
|
||||
def run_open_ai_mapper(observation=None, json_schema=None):
|
||||
"""Convert unstructured data to structured data"""
|
||||
|
||||
prompt_msgs = [
|
||||
SystemMessage(
|
||||
content="You are a world class algorithm converting unstructured data into structured data."
|
||||
),
|
||||
HumanMessage(
|
||||
content="Convert unstructured data to structured data:"
|
||||
),
|
||||
HumanMessagePromptTemplate.from_template("{input}"),
|
||||
HumanMessage(
|
||||
content="Tips: Make sure to answer in the correct format"
|
||||
),
|
||||
]
|
||||
prompt_ = ChatPromptTemplate(messages=prompt_msgs)
|
||||
chain_funct = create_structured_output_chain(
|
||||
json_schema, prompt=prompt_, llm=self.llm, verbose=True
|
||||
)
|
||||
output = chain_funct.run(input=observation, llm=self.llm)
|
||||
return output
|
||||
|
||||
result = run_open_ai_mapper(observation, json_schema)
|
||||
return result
|
||||
|
||||
class FetchText(BaseModel):
|
||||
observation: str = Field(description="observation we want to translate")
|
||||
|
|
@ -802,7 +755,7 @@ class EpisodicBuffer(BaseMemory):
|
|||
|
||||
agent = initialize_agent(
|
||||
llm=self.llm,
|
||||
tools=[fetch_from_vector_store,translate_to_de, convert_to_structured],
|
||||
tools=[fetch_from_vector_store,translate_to_de],
|
||||
agent=AgentType.OPENAI_FUNCTIONS,
|
||||
verbose=True,
|
||||
)
|
||||
|
|
@ -1084,18 +1037,21 @@ async def main():
|
|||
"source": "url",
|
||||
"path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
|
||||
}
|
||||
# load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params)
|
||||
# print(load_jack_london)
|
||||
load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params)
|
||||
print(load_jack_london)
|
||||
|
||||
modulator = {"relevance": 0.1, "frequency": 0.1}
|
||||
|
||||
# fdsf = await memory._fetch_semantic_memory(observation="bla", params=None)
|
||||
# print(fdsf)
|
||||
# await memory._delete_episodic_memory()
|
||||
#
|
||||
run_main_buffer = await memory._create_buffer_context(
|
||||
user_input="I want to know how does Buck adapt to life in the wild and then have that info translated to german ",
|
||||
params=params,
|
||||
attention_modulators=modulator,
|
||||
)
|
||||
print(run_main_buffer)
|
||||
# run_main_buffer = await memory._create_buffer_context(
|
||||
# user_input="I want to know how does Buck adapt to life in the wild and then have that info translated to german ",
|
||||
# params=params,
|
||||
# attention_modulators=modulator,
|
||||
# )
|
||||
# print(run_main_buffer)
|
||||
# #
|
||||
# run_main_buffer = await memory._run_main_buffer(
|
||||
# user_input="I want to know how does Buck adapt to life in the wild and then have that info translated to german ",
|
||||
|
|
|
|||
39
level_2/loaders/loaders.py
Normal file
39
level_2/loaders/loaders.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
import os
|
||||
from io import BytesIO
|
||||
|
||||
import fitz
|
||||
from level_2.chunkers.chunkers import chunk_data
|
||||
from langchain.document_loaders import PyPDFLoader
|
||||
|
||||
import requests
|
||||
def _document_loader( observation: str, loader_settings: dict):
|
||||
# Check the format of the document
|
||||
document_format = loader_settings.get("format", "text")
|
||||
|
||||
if document_format == "PDF":
|
||||
if loader_settings.get("source") == "url":
|
||||
pdf_response = requests.get(loader_settings["path"])
|
||||
pdf_stream = BytesIO(pdf_response.content)
|
||||
with fitz.open(stream=pdf_stream, filetype='pdf') as doc:
|
||||
file_content = ""
|
||||
for page in doc:
|
||||
file_content += page.get_text()
|
||||
pages = chunk_data(chunk_strategy= 'VANILLA', source_data=file_content)
|
||||
|
||||
return pages
|
||||
elif loader_settings.get("source") == "file":
|
||||
# Process the PDF using PyPDFLoader
|
||||
# might need adapting for different loaders + OCR
|
||||
# need to test the path
|
||||
loader = PyPDFLoader(loader_settings["path"])
|
||||
pages = loader.load_and_split()
|
||||
return pages
|
||||
|
||||
elif document_format == "text":
|
||||
# Process the text directly
|
||||
return observation
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported document format: {document_format}")
|
||||
|
||||
|
||||
54
level_2/poetry.lock
generated
54
level_2/poetry.lock
generated
|
|
@ -2491,6 +2491,58 @@ files = [
|
|||
[package.extras]
|
||||
plugins = ["importlib-metadata"]
|
||||
|
||||
[[package]]
|
||||
name = "pymupdf"
|
||||
version = "1.23.3"
|
||||
description = "A high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "PyMuPDF-1.23.3-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:52699939b7482c8c566a181e2a980a6801c91959ee96dae5663070fd2b960c6b"},
|
||||
{file = "PyMuPDF-1.23.3-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:95408d57ed77f3c396880a3fc0feae068c4bf577e7e2c761d24a345138062f8d"},
|
||||
{file = "PyMuPDF-1.23.3-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:5eefd674e338ddd82cd9179ad7d4c2160796efd6c0d4cd1098b5314ff78688d7"},
|
||||
{file = "PyMuPDF-1.23.3-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:c7696034f5f5472d1e6d3f3556858cf85e095b66c158a80b527facfa83542aee"},
|
||||
{file = "PyMuPDF-1.23.3-cp310-none-win32.whl", hash = "sha256:f3c6d427381f4ef76bec4e862c8969845e90bc842b3c534800be9cb6fe6b0e3b"},
|
||||
{file = "PyMuPDF-1.23.3-cp310-none-win_amd64.whl", hash = "sha256:0fd19017d4c7791146e38621d878393136e25a2a4fadd0372a98ab2a9aabc0c5"},
|
||||
{file = "PyMuPDF-1.23.3-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:0e88408dea51492431b111a721d88a4f4c2176786734b16374d77a421f410139"},
|
||||
{file = "PyMuPDF-1.23.3-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:c4dbf5e851373f4633b57187b0ae3dcde0efad6ef5969c4de14bb9a52a796261"},
|
||||
{file = "PyMuPDF-1.23.3-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:7218c1099205edb3357cb5713661d11d7c04aaa910645da64e17c2d050d61352"},
|
||||
{file = "PyMuPDF-1.23.3-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:0304d5def03d2bedf951179624ea636470b5ee0a706ea37636f7a3b2b08561a5"},
|
||||
{file = "PyMuPDF-1.23.3-cp311-none-win32.whl", hash = "sha256:35fe66d80cdc948ed55ac70c94b2e7f740fc08309c4ce125228ce0042a2fbba8"},
|
||||
{file = "PyMuPDF-1.23.3-cp311-none-win_amd64.whl", hash = "sha256:e643e4f30d1a5e358a8f65eab66dd0ea33f8170d61eb7549f0d227086c82d315"},
|
||||
{file = "PyMuPDF-1.23.3-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:95065c21c39dc93c4e224a2ac3c903bf31d635cdb569338d79e9befbac9755eb"},
|
||||
{file = "PyMuPDF-1.23.3-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:0c06610d78a86fcbfbcea77320c54f561ac4d568666d621afcf1109e8cfc829b"},
|
||||
{file = "PyMuPDF-1.23.3-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:6e4ef7e65b3fb7f9248f1f2dc530f10d0e00a8080dd5da52808e6638a9868a10"},
|
||||
{file = "PyMuPDF-1.23.3-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:d51b848d45e09e7fedfdeb0880a2a14872e25dd4e0932b9abf6a36a69bf01f6a"},
|
||||
{file = "PyMuPDF-1.23.3-cp38-none-win32.whl", hash = "sha256:42b879913a07fb251251af20e46747abc3d5d0276a48d2c28e128f5f88ef3dcd"},
|
||||
{file = "PyMuPDF-1.23.3-cp38-none-win_amd64.whl", hash = "sha256:a283236e09c056798ecaf6e0872790c63d91edf6d5f72b76504715d6b88da976"},
|
||||
{file = "PyMuPDF-1.23.3-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6329a223ae38641fe4ff081beffd33f5e3be800c0409569b64a33b70f1b544cf"},
|
||||
{file = "PyMuPDF-1.23.3-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:640a5ada4479a2c69b811c91f163a7b55f7fe1c323b861373d6068893cc9e9e0"},
|
||||
{file = "PyMuPDF-1.23.3-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:2f555d264f08e091eaf9fd27c33ba9bfdc39ac8d09aa12195ab529bcca79229d"},
|
||||
{file = "PyMuPDF-1.23.3-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:96dc89254d78bddac8434be7b9f4c354fe57b224b5420614cde9c2f1d2f1355e"},
|
||||
{file = "PyMuPDF-1.23.3-cp39-none-win32.whl", hash = "sha256:f9a1d2f7484bde2ec81f3c88641f7a8b7f52450b807408ae7a340ddecb424659"},
|
||||
{file = "PyMuPDF-1.23.3-cp39-none-win_amd64.whl", hash = "sha256:7cfceb91048665965d826023c4acfc45f61f5cfcf101391b3c1d22f85cef0470"},
|
||||
{file = "PyMuPDF-1.23.3.tar.gz", hash = "sha256:021478ae6c76e8859241dbb970612c9080a8957d8bd697bba0b4531dc1cf4f87"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
PyMuPDFb = "1.23.3"
|
||||
|
||||
[[package]]
|
||||
name = "pymupdfb"
|
||||
version = "1.23.3"
|
||||
description = "MuPDF shared libraries for PyMuPDF."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "PyMuPDFb-1.23.3-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:5b05c643210eae8050d552188efab2cd68595ad75b5879a550e11af88e8bff05"},
|
||||
{file = "PyMuPDFb-1.23.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2a2b81ac348ec123bfd72336a590399f8b0035a3052c1cf5cc2401ca7a4905e9"},
|
||||
{file = "PyMuPDFb-1.23.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:924f3f2229d232c965705d120b3ff38bbc37459af9d0e798b582950f875bee92"},
|
||||
{file = "PyMuPDFb-1.23.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6c287b9ce5ed397043c6e13df19640c94a348e9edc8012d9a7b001c69ba30ca9"},
|
||||
{file = "PyMuPDFb-1.23.3-py3-none-win32.whl", hash = "sha256:8703e3a8efebd83814e124d0fc3a082de2d2def329b63fca1065001e6a2deb49"},
|
||||
{file = "PyMuPDFb-1.23.3-py3-none-win_amd64.whl", hash = "sha256:89d88069cb8deb100ddcf56e1feefc7cff93ff791260325ed84551f96d3abd9f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "3.15.4"
|
||||
|
|
@ -4466,4 +4518,4 @@ multidict = ">=4.0"
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "761b58204631452d77e13bbc2d61034704e8e109619db4addd26ec159b9bb176"
|
||||
content-hash = "bc306ab25967437b68ef5216af4b68bf6bfdf5cb966bb6493cc3ad91e8888110"
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ python-multipart = "^0.0.6"
|
|||
deep-translator = "^1.11.4"
|
||||
humanize = "^4.8.0"
|
||||
deepeval = "^0.10.12"
|
||||
pymupdf = "^1.23.3"
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
51
level_2/schema/semantic/semantic_schema.py
Normal file
51
level_2/schema/semantic/semantic_schema.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
from marshmallow import Schema, fields
|
||||
|
||||
class DocumentMetadataSchemaV1(Schema):
|
||||
user_id = fields.Str(required=True)
|
||||
memory_id = fields.Str(required=True)
|
||||
ltm_memory_id = fields.Str(required=True)
|
||||
st_memory_id = fields.Str(required=True)
|
||||
buffer_id = fields.Str(required=True)
|
||||
version = fields.Str(missing="")
|
||||
agreement_id = fields.Str(missing="")
|
||||
privacy_policy = fields.Str(missing="")
|
||||
terms_of_service = fields.Str(missing="")
|
||||
format = fields.Str(missing="")
|
||||
schema_version = fields.Str(missing="")
|
||||
checksum = fields.Str(missing="")
|
||||
owner = fields.Str(missing="")
|
||||
license = fields.Str(missing="")
|
||||
validity_start = fields.Str(missing="")
|
||||
validity_end = fields.Str(missing="")
|
||||
|
||||
class DocumentMetadataSchemaV2(Schema):
|
||||
user_id = fields.Str(required=True)
|
||||
memory_id = fields.Str(required=True)
|
||||
ltm_memory_id = fields.Str(required=True)
|
||||
st_memory_id = fields.Str(required=True)
|
||||
buffer_id = fields.Str(required=True)
|
||||
version = fields.Str(missing="")
|
||||
agreement_id = fields.Str(missing="")
|
||||
privacy_policy = fields.Str(missing="")
|
||||
terms_of_service = fields.Str(missing="")
|
||||
format = fields.Str(missing="")
|
||||
schema_version = fields.Str(missing="")
|
||||
checksum = fields.Str(missing="")
|
||||
owner = fields.Str(missing="")
|
||||
license = fields.Str(missing="")
|
||||
validity_start = fields.Str(missing="")
|
||||
validity_end = fields.Str(missing="")
|
||||
random = fields.Str(missing="")
|
||||
|
||||
class DocumentSchema(Schema):
|
||||
metadata = fields.Nested(DocumentMetadataSchemaV1, required=True)
|
||||
page_content = fields.Str(required=True)
|
||||
|
||||
|
||||
SCHEMA_VERSIONS = {
|
||||
"1.0": DocumentMetadataSchemaV1,
|
||||
"2.0": DocumentMetadataSchemaV2
|
||||
}
|
||||
|
||||
def get_schema_version(version):
|
||||
return SCHEMA_VERSIONS.get(version, DocumentMetadataSchemaV1)
|
||||
7
level_2/shared/chunk_strategy.py
Normal file
7
level_2/shared/chunk_strategy.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from enum import Enum
|
||||
|
||||
class ChunkStrategy(Enum):
|
||||
EXACT = 'exact'
|
||||
PARAGRAPH = 'paragraph'
|
||||
SENTENCE = 'sentence'
|
||||
VANILLA = 'vanilla'
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -103,11 +103,13 @@ class BaseMemory:
|
|||
loader_settings: dict = None,
|
||||
params: Optional[dict] = None,
|
||||
namespace: Optional[str] = None,
|
||||
custom_fields: Optional[str] = None,
|
||||
|
||||
):
|
||||
|
||||
return await self.vector_db.add_memories(
|
||||
observation=observation, loader_settings=loader_settings,
|
||||
params=params, namespace=namespace
|
||||
params=params, namespace=namespace, custom_fields=custom_fields
|
||||
)
|
||||
# Add other db_type conditions if necessary
|
||||
|
||||
|
|
|
|||
|
|
@ -3,29 +3,28 @@
|
|||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
from marshmallow import Schema, fields
|
||||
from level_2.loaders.loaders import _document_loader
|
||||
# Add the parent directory to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
import marvin
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from langchain.document_loaders import PyPDFLoader
|
||||
from langchain.retrievers import WeaviateHybridSearchRetriever
|
||||
from weaviate.gql.get import HybridFusion
|
||||
|
||||
load_dotenv()
|
||||
from typing import Optional
|
||||
|
||||
import tracemalloc
|
||||
|
||||
tracemalloc.start()
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from dotenv import load_dotenv
|
||||
from level_2.schema.semantic.semantic_schema import DocumentSchema, SCHEMA_VERSIONS, DocumentMetadataSchemaV1
|
||||
from langchain.schema import Document
|
||||
import uuid
|
||||
import weaviate
|
||||
|
||||
load_dotenv()
|
||||
|
|
@ -103,122 +102,114 @@ class WeaviateVectorDB(VectorDB):
|
|||
)
|
||||
return client
|
||||
|
||||
def _document_loader(self, observation: str, loader_settings: dict):
|
||||
# Check the format of the document
|
||||
document_format = loader_settings.get("format", "text")
|
||||
# def _document_loader(self, observation: str, loader_settings: dict):
|
||||
# # Check the format of the document
|
||||
# document_format = loader_settings.get("format", "text")
|
||||
#
|
||||
# if document_format == "PDF":
|
||||
# if loader_settings.get("source") == "url":
|
||||
# pdf_response = requests.get(loader_settings["path"])
|
||||
# pdf_stream = BytesIO(pdf_response.content)
|
||||
# contents = pdf_stream.read()
|
||||
# tmp_location = os.path.join("/tmp", "tmp.pdf")
|
||||
# with open(tmp_location, "wb") as tmp_file:
|
||||
# tmp_file.write(contents)
|
||||
#
|
||||
# # Process the PDF using PyPDFLoader
|
||||
# loader = PyPDFLoader(tmp_location)
|
||||
# # adapt this for different chunking strategies
|
||||
# pages = loader.load_and_split()
|
||||
# return pages
|
||||
# elif loader_settings.get("source") == "file":
|
||||
# # Process the PDF using PyPDFLoader
|
||||
# # might need adapting for different loaders + OCR
|
||||
# # need to test the path
|
||||
# loader = PyPDFLoader(loader_settings["path"])
|
||||
# pages = loader.load_and_split()
|
||||
# return pages
|
||||
#
|
||||
# elif document_format == "text":
|
||||
# # Process the text directly
|
||||
# return observation
|
||||
#
|
||||
# else:
|
||||
# raise ValueError(f"Unsupported document format: {document_format}")
|
||||
def _stuct(self, observation, params, custom_fields=None):
|
||||
"""Utility function to create the document structure with optional custom fields."""
|
||||
# Dynamically construct metadata
|
||||
metadata = {
|
||||
key: str(getattr(self, key, params.get(key, "")))
|
||||
for key in [
|
||||
"user_id", "memory_id", "ltm_memory_id",
|
||||
"st_memory_id", "buffer_id", "version",
|
||||
"agreement_id", "privacy_policy", "terms_of_service",
|
||||
"format", "schema_version", "checksum",
|
||||
"owner", "license", "validity_start", "validity_end"
|
||||
]
|
||||
}
|
||||
# Merge with custom fields if provided
|
||||
if custom_fields:
|
||||
metadata.update(custom_fields)
|
||||
|
||||
if document_format == "PDF":
|
||||
if loader_settings.get("source") == "url":
|
||||
pdf_response = requests.get(loader_settings["path"])
|
||||
pdf_stream = BytesIO(pdf_response.content)
|
||||
contents = pdf_stream.read()
|
||||
tmp_location = os.path.join("/tmp", "tmp.pdf")
|
||||
with open(tmp_location, "wb") as tmp_file:
|
||||
tmp_file.write(contents)
|
||||
# Construct document data
|
||||
document_data = {
|
||||
"metadata": metadata,
|
||||
"page_content": observation
|
||||
}
|
||||
|
||||
# Process the PDF using PyPDFLoader
|
||||
loader = PyPDFLoader(tmp_location)
|
||||
# adapt this for different chunking strategies
|
||||
pages = loader.load_and_split()
|
||||
return pages
|
||||
elif loader_settings.get("source") == "file":
|
||||
# Process the PDF using PyPDFLoader
|
||||
# might need adapting for different loaders + OCR
|
||||
# need to test the path
|
||||
loader = PyPDFLoader(loader_settings["path"])
|
||||
pages = loader.load_and_split()
|
||||
return pages
|
||||
def get_document_schema_based_on_version(version):
|
||||
metadata_schema_class = SCHEMA_VERSIONS.get(version, DocumentMetadataSchemaV1)
|
||||
class DynamicDocumentSchema(Schema):
|
||||
metadata = fields.Nested(metadata_schema_class, required=True)
|
||||
page_content = fields.Str(required=True)
|
||||
|
||||
elif document_format == "text":
|
||||
# Process the text directly
|
||||
return observation
|
||||
return DynamicDocumentSchema
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported document format: {document_format}")
|
||||
# Validate and deserialize
|
||||
schema_version = params.get("schema_version", "1.0") # Default to "1.0" if not provided
|
||||
CurrentDocumentSchema = get_document_schema_based_on_version(schema_version)
|
||||
loaded_document = CurrentDocumentSchema().load(document_data)
|
||||
return [loaded_document]
|
||||
|
||||
async def add_memories(
|
||||
self, observation: str, loader_settings: dict = None, params: dict = None ,namespace:str=None
|
||||
):
|
||||
async def add_memories(self, observation, loader_settings=None, params=None, namespace=None, custom_fields=None):
|
||||
# Update Weaviate memories here
|
||||
print(self.namespace)
|
||||
if namespace is None:
|
||||
namespace = self.namespace
|
||||
retriever = self.init_weaviate(namespace)
|
||||
|
||||
def _stuct(observation, params):
|
||||
"""Utility function to not repeat metadata structure"""
|
||||
# needs smarter solution, like dynamic generation of metadata
|
||||
return [
|
||||
Document(
|
||||
metadata={
|
||||
# "text": observation,
|
||||
"user_id": str(self.user_id),
|
||||
"memory_id": str(self.memory_id),
|
||||
"ltm_memory_id": str(self.ltm_memory_id),
|
||||
"st_memory_id": str(self.st_memory_id),
|
||||
"buffer_id": str(self.buffer_id),
|
||||
"version": params.get("version", None) or "",
|
||||
"agreement_id": params.get("agreement_id", None) or "",
|
||||
"privacy_policy": params.get("privacy_policy", None) or "",
|
||||
"terms_of_service": params.get("terms_of_service", None) or "",
|
||||
"format": params.get("format", None) or "",
|
||||
"schema_version": params.get("schema_version", None) or "",
|
||||
"checksum": params.get("checksum", None) or "",
|
||||
"owner": params.get("owner", None) or "",
|
||||
"license": params.get("license", None) or "",
|
||||
"validity_start": params.get("validity_start", None) or "",
|
||||
"validity_end": params.get("validity_end", None) or ""
|
||||
# **source_metadata,
|
||||
},
|
||||
page_content=observation,
|
||||
)
|
||||
]
|
||||
|
||||
retriever = self.init_weaviate(namespace) # Assuming `init_weaviate` is a method of the class
|
||||
if loader_settings:
|
||||
# Load the document
|
||||
document = self._document_loader(observation, loader_settings)
|
||||
print("DOC LENGTH", len(document))
|
||||
for doc in document:
|
||||
document_to_load = _stuct(doc.page_content, params)
|
||||
retriever.add_documents(
|
||||
document_to_load
|
||||
)
|
||||
|
||||
return retriever.add_documents(
|
||||
_stuct(observation, params)
|
||||
)
|
||||
# Assuming _document_loader returns a list of documents
|
||||
documents = _document_loader(observation, loader_settings)
|
||||
for doc in documents:
|
||||
document_to_load = self._stuct(doc.page_content, params, custom_fields)
|
||||
print("here is the doc to load1", document_to_load)
|
||||
retriever.add_documents([
|
||||
Document(metadata=document_to_load[0]['metadata'], page_content=document_to_load[0]['page_content'])])
|
||||
else:
|
||||
document_to_load = self._stuct(observation, params, custom_fields)
|
||||
retriever.add_documents([
|
||||
Document(metadata=document_to_load[0]['metadata'], page_content=document_to_load[0]['page_content'])])
|
||||
|
||||
async def fetch_memories(
|
||||
self, observation: str, namespace: str, params: dict = None, n_of_observations =int(2)
|
||||
self, observation: str, namespace: str, params: dict = None, n_of_observations: int = 2
|
||||
):
|
||||
"""
|
||||
Get documents from weaviate.
|
||||
Fetch documents from weaviate.
|
||||
|
||||
Parameters:
|
||||
- observation (str): User query.
|
||||
- namespace (str): Type of memory we access.
|
||||
- params (dict, optional):
|
||||
- n_of_observations (int, optional): For weaviate, equals to autocut, defaults to 1. Ranges from 1 to 3. Check weaviate docs for more info.
|
||||
Parameters:
|
||||
- observation (str): User query.
|
||||
- namespace (str): Type of memory accessed.
|
||||
- params (dict, optional): Filtering parameters.
|
||||
- n_of_observations (int, optional): For weaviate, equals to autocut. Defaults to 2. Ranges from 1 to 3.
|
||||
|
||||
Returns:
|
||||
Describe the return type and what the function returns.
|
||||
|
||||
Args a json containing:
|
||||
query (str): The query string.
|
||||
path (list): The path for filtering, e.g., ['year'].
|
||||
operator (str): The operator for filtering, e.g., 'Equal'.
|
||||
valueText (str): The value for filtering, e.g., '2017*'.
|
||||
Returns:
|
||||
List of documents matching the query.
|
||||
|
||||
Example:
|
||||
get_from_weaviate(query="some query", path=['year'], operator='Equal', valueText='2017*')
|
||||
|
||||
fetch_memories(query="some query", path=['year'], operator='Equal', valueText='2017*')
|
||||
"""
|
||||
client = self.init_weaviate_client(self.namespace)
|
||||
|
||||
print(self.namespace)
|
||||
print(str(datetime.now()))
|
||||
print(observation)
|
||||
if namespace is None:
|
||||
if not namespace:
|
||||
namespace = self.namespace
|
||||
|
||||
params_user_id = {
|
||||
|
|
@ -227,78 +218,39 @@ class WeaviateVectorDB(VectorDB):
|
|||
"valueText": self.user_id,
|
||||
}
|
||||
|
||||
def list_objects_of_class(class_name, schema):
|
||||
return [
|
||||
prop["name"]
|
||||
for class_obj in schema["classes"]
|
||||
if class_obj["class"] == class_name
|
||||
for prop in class_obj["properties"]
|
||||
]
|
||||
|
||||
base_query = client.query.get(
|
||||
namespace, list(list_objects_of_class(namespace, client.schema.get()))
|
||||
).with_additional(
|
||||
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
|
||||
).with_where(params_user_id).with_limit(10)
|
||||
|
||||
if params:
|
||||
query_output = (
|
||||
client.query.get(
|
||||
namespace,
|
||||
[
|
||||
# "text",
|
||||
"user_id",
|
||||
"memory_id",
|
||||
"ltm_memory_id",
|
||||
"st_memory_id",
|
||||
"buffer_id",
|
||||
"version",
|
||||
"agreement_id",
|
||||
"privacy_policy",
|
||||
"terms_of_service",
|
||||
"format",
|
||||
"schema_version",
|
||||
"checksum",
|
||||
"owner",
|
||||
"license",
|
||||
"validity_start",
|
||||
"validity_end",
|
||||
],
|
||||
)
|
||||
base_query
|
||||
.with_where(params)
|
||||
.with_near_text({"concepts": [observation]})
|
||||
.with_additional(
|
||||
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score",'distance']
|
||||
)
|
||||
.with_where(params_user_id)
|
||||
.with_limit(10)
|
||||
.do()
|
||||
)
|
||||
return query_output
|
||||
else:
|
||||
query_output = (
|
||||
client.query.get(
|
||||
namespace,
|
||||
|
||||
[
|
||||
"text",
|
||||
"user_id",
|
||||
"memory_id",
|
||||
"ltm_memory_id",
|
||||
"st_memory_id",
|
||||
"buffer_id",
|
||||
"version",
|
||||
"agreement_id",
|
||||
"privacy_policy",
|
||||
"terms_of_service",
|
||||
"format",
|
||||
"schema_version",
|
||||
"checksum",
|
||||
"owner",
|
||||
"license",
|
||||
"validity_start",
|
||||
"validity_end",
|
||||
],
|
||||
)
|
||||
.with_additional(
|
||||
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
|
||||
)
|
||||
base_query
|
||||
.with_hybrid(
|
||||
query=observation,
|
||||
fusion_type=HybridFusion.RELATIVE_SCORE
|
||||
)
|
||||
.with_autocut(n_of_observations)
|
||||
.with_where(params_user_id)
|
||||
.with_limit(10)
|
||||
.do()
|
||||
)
|
||||
return query_output
|
||||
|
||||
return query_output
|
||||
|
||||
async def delete_memories(self, params: dict = None):
|
||||
client = self.init_weaviate_client(self.namespace)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue