Added the following:
1. Dynamic metadata retrieval (refactored the function)
2. Loading with marshmallow, which now allows dynamic fields
3. Added chunkers in several varieties
4. Fixed PDF loading so it is better standardized
This commit is contained in:
parent
4bfea2e328
commit
59c53f7339
10 changed files with 367 additions and 203 deletions
85  level_2/chunkers/chunkers.py  Normal file

@@ -0,0 +1,85 @@
from langchain.document_loaders import PyPDFLoader
from level_2.shared.chunk_strategy import ChunkStrategy
import re

def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):

    if chunk_strategy == ChunkStrategy.VANILLA:
        chunked_data = vanilla_chunker(source_data, chunk_size, chunk_overlap)

    elif chunk_strategy == ChunkStrategy.PARAGRAPH:
        chunked_data = chunk_data_by_paragraph(source_data, chunk_size, chunk_overlap)

    elif chunk_strategy == ChunkStrategy.SENTENCE:
        chunked_data = chunk_by_sentence(source_data, chunk_size, chunk_overlap)
    elif chunk_strategy == ChunkStrategy.EXACT:
        chunked_data = chunk_data_exact(source_data, chunk_size, chunk_overlap)
    else:
        chunked_data = vanilla_chunker(source_data, chunk_size, chunk_overlap)
    return chunked_data


def vanilla_chunker(source_data, chunk_size, chunk_overlap):
    # loader = PyPDFLoader(source_data)
    # adapt this for different chunking strategies
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    text_splitter = RecursiveCharacterTextSplitter(
        # Fall back to small demo sizes when none are passed in.
        chunk_size=chunk_size or 100,
        chunk_overlap=chunk_overlap or 20,
        length_function=len
    )
    pages = text_splitter.create_documents([source_data])
    # pages = source_data.load_and_split()
    return pages

def chunk_data_exact(data_chunks, chunk_size, chunk_overlap):
    data = "".join(data_chunks)
    chunks = []
    for i in range(0, len(data), chunk_size - chunk_overlap):
        chunks.append(data[i:i + chunk_size])
    return chunks


def chunk_by_sentence(data_chunks, chunk_size, overlap):
    # Split by periods, question marks, exclamation marks, and ellipses
    data = "".join(data_chunks)

    # The regular expression finds runs of characters that end with one of the following characters (. ! ? ...)
    sentence_endings = r'(?<=[.!?…]) +'
    sentences = re.split(sentence_endings, data)

    sentence_chunks = []
    for sentence in sentences:
        if len(sentence) > chunk_size:
            chunks = chunk_data_exact([sentence], chunk_size, overlap)
            sentence_chunks.extend(chunks)
        else:
            sentence_chunks.append(sentence)
    return sentence_chunks


def chunk_data_by_paragraph(data_chunks, chunk_size, overlap, bound=0.75):
    data = "".join(data_chunks)
    total_length = len(data)
    chunks = []
    check_bound = int(bound * chunk_size)
    start_idx = 0

    while start_idx < total_length:
        # Set the end index to the minimum of start_idx + chunk_size or total_length
        end_idx = min(start_idx + chunk_size, total_length)

        # Find the next paragraph index within the current chunk and bound
        next_paragraph_index = data.find('\n\n', start_idx + check_bound, end_idx)

        # If a next paragraph index is found within the current chunk
        if next_paragraph_index != -1:
            # Update end_idx to include the paragraph delimiter
            end_idx = next_paragraph_index + 2

        chunks.append(data[start_idx:end_idx + overlap])

        # Update start_idx to be the current end_idx
        start_idx = end_idx

    return chunks
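Below is a minimal usage sketch for the chunkers above; the module paths assume this commit's repo layout, and the sample text and sizes are made up for illustration.

from level_2.chunkers.chunkers import chunk_data
from level_2.shared.chunk_strategy import ChunkStrategy

text = "First sentence. Second one!\n\nA new paragraph starts here."
# Sentence strategy: split on ., !, ?, and …, falling back to exact
# chunking for any sentence longer than chunk_size.
chunks = chunk_data(chunk_strategy=ChunkStrategy.SENTENCE, source_data=text,
                    chunk_size=40, chunk_overlap=5)
print(chunks)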

@@ -742,34 +742,6 @@ class EpisodicBuffer(BaseMemory):
        out = self.fetch_memories(observation['original_query'], namespace="SEMANTICMEMORY")
        return out

    @tool("retrieve_from_memories", args_schema=FetchText, return_direct=True)
    def retrieve_from_memories(observation, args_schema=FetchText):
        """Retrieve from episodic memory if data doesn't exist in the context"""

        new_observations = []
        observation = self.fetch_memories(observation['original_query'], namespace="EPISODICMEMORY")

        for memory in observation:

            unix_t = memory["data"]["Get"]["EPISODICMEMORY"][0]["_additional"][
                "lastUpdateTimeUnix"
            ]

            # Convert Unix timestamp to datetime
            last_update_datetime = datetime.fromtimestamp(int(unix_t) / 1000)
            time_difference = datetime.now() - last_update_datetime
            time_difference_text = humanize.naturaltime(time_difference)
            # Append the time difference to the memory
            memory["time_difference"] = str(time_difference_text)
            # patch the memory
            # retrieve again then

            # Append the modified memory to the new list
            new_observations.append(memory)


class TranslateText(BaseModel):
    observation: str = Field(description="observation we want to translate")


@@ -1065,18 +1037,21 @@ async def main():
        "source": "url",
        "path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
    }
    # load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params)
    # print(load_jack_london)
    load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params)
    print(load_jack_london)

    modulator = {"relevance": 0.1, "frequency": 0.1}

    # fdsf = await memory._fetch_semantic_memory(observation="bla", params=None)
    # print(fdsf)
    # await memory._delete_episodic_memory()
    #
    run_main_buffer = await memory._create_buffer_context(
        user_input="I want to know how does Buck adapt to life in the wild and then have that info translated to german ",
        params=params,
        attention_modulators=modulator,
    )
    print(run_main_buffer)
    # run_main_buffer = await memory._create_buffer_context(
    #     user_input="I want to know how does Buck adapt to life in the wild and then have that info translated to german ",
    #     params=params,
    #     attention_modulators=modulator,
    # )
    # print(run_main_buffer)
    # #
    # run_main_buffer = await memory._run_main_buffer(
    #     user_input="I want to know how does Buck adapt to life in the wild and then have that info translated to german ",

39  level_2/loaders/loaders.py  Normal file

@@ -0,0 +1,39 @@
import os
from io import BytesIO

import fitz
from level_2.chunkers.chunkers import chunk_data
from level_2.shared.chunk_strategy import ChunkStrategy
from langchain.document_loaders import PyPDFLoader

import requests

def _document_loader(observation: str, loader_settings: dict):
    # Check the format of the document
    document_format = loader_settings.get("format", "text")

    if document_format == "PDF":
        if loader_settings.get("source") == "url":
            pdf_response = requests.get(loader_settings["path"])
            pdf_stream = BytesIO(pdf_response.content)
            with fitz.open(stream=pdf_stream, filetype='pdf') as doc:
                file_content = ""
                for page in doc:
                    file_content += page.get_text()
            # Pass the enum member (not the string 'VANILLA') so chunk_data dispatches correctly.
            pages = chunk_data(chunk_strategy=ChunkStrategy.VANILLA, source_data=file_content)

            return pages
        elif loader_settings.get("source") == "file":
            # Process the PDF using PyPDFLoader
            # might need adapting for different loaders + OCR
            # need to test the path
            loader = PyPDFLoader(loader_settings["path"])
            pages = loader.load_and_split()
            return pages

    elif document_format == "text":
        # Process the text directly
        return observation

    else:
        raise ValueError(f"Unsupported document format: {document_format}")
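A quick sketch of driving the new loader for a remote PDF; it reuses the test URL from main() above and needs network access plus the pymupdf dependency added in this commit.

from level_2.loaders.loaders import _document_loader

loader_settings = {
    "format": "PDF",
    "source": "url",
    "path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf",
}
# Returns LangChain documents produced by the vanilla chunker.
pages = _document_loader(observation="", loader_settings=loader_settings)
print(len(pages), "chunks extracted")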
54  level_2/poetry.lock  generated

@@ -2491,6 +2491,58 @@ files = [
[package.extras]
plugins = ["importlib-metadata"]

[[package]]
name = "pymupdf"
version = "1.23.3"
description = "A high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents."
optional = false
python-versions = ">=3.8"
files = [
    {file = "PyMuPDF-1.23.3-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:52699939b7482c8c566a181e2a980a6801c91959ee96dae5663070fd2b960c6b"},
    {file = "PyMuPDF-1.23.3-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:95408d57ed77f3c396880a3fc0feae068c4bf577e7e2c761d24a345138062f8d"},
    {file = "PyMuPDF-1.23.3-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:5eefd674e338ddd82cd9179ad7d4c2160796efd6c0d4cd1098b5314ff78688d7"},
    {file = "PyMuPDF-1.23.3-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:c7696034f5f5472d1e6d3f3556858cf85e095b66c158a80b527facfa83542aee"},
    {file = "PyMuPDF-1.23.3-cp310-none-win32.whl", hash = "sha256:f3c6d427381f4ef76bec4e862c8969845e90bc842b3c534800be9cb6fe6b0e3b"},
    {file = "PyMuPDF-1.23.3-cp310-none-win_amd64.whl", hash = "sha256:0fd19017d4c7791146e38621d878393136e25a2a4fadd0372a98ab2a9aabc0c5"},
    {file = "PyMuPDF-1.23.3-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:0e88408dea51492431b111a721d88a4f4c2176786734b16374d77a421f410139"},
    {file = "PyMuPDF-1.23.3-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:c4dbf5e851373f4633b57187b0ae3dcde0efad6ef5969c4de14bb9a52a796261"},
    {file = "PyMuPDF-1.23.3-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:7218c1099205edb3357cb5713661d11d7c04aaa910645da64e17c2d050d61352"},
    {file = "PyMuPDF-1.23.3-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:0304d5def03d2bedf951179624ea636470b5ee0a706ea37636f7a3b2b08561a5"},
    {file = "PyMuPDF-1.23.3-cp311-none-win32.whl", hash = "sha256:35fe66d80cdc948ed55ac70c94b2e7f740fc08309c4ce125228ce0042a2fbba8"},
    {file = "PyMuPDF-1.23.3-cp311-none-win_amd64.whl", hash = "sha256:e643e4f30d1a5e358a8f65eab66dd0ea33f8170d61eb7549f0d227086c82d315"},
    {file = "PyMuPDF-1.23.3-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:95065c21c39dc93c4e224a2ac3c903bf31d635cdb569338d79e9befbac9755eb"},
    {file = "PyMuPDF-1.23.3-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:0c06610d78a86fcbfbcea77320c54f561ac4d568666d621afcf1109e8cfc829b"},
    {file = "PyMuPDF-1.23.3-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:6e4ef7e65b3fb7f9248f1f2dc530f10d0e00a8080dd5da52808e6638a9868a10"},
    {file = "PyMuPDF-1.23.3-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:d51b848d45e09e7fedfdeb0880a2a14872e25dd4e0932b9abf6a36a69bf01f6a"},
    {file = "PyMuPDF-1.23.3-cp38-none-win32.whl", hash = "sha256:42b879913a07fb251251af20e46747abc3d5d0276a48d2c28e128f5f88ef3dcd"},
    {file = "PyMuPDF-1.23.3-cp38-none-win_amd64.whl", hash = "sha256:a283236e09c056798ecaf6e0872790c63d91edf6d5f72b76504715d6b88da976"},
    {file = "PyMuPDF-1.23.3-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6329a223ae38641fe4ff081beffd33f5e3be800c0409569b64a33b70f1b544cf"},
    {file = "PyMuPDF-1.23.3-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:640a5ada4479a2c69b811c91f163a7b55f7fe1c323b861373d6068893cc9e9e0"},
    {file = "PyMuPDF-1.23.3-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:2f555d264f08e091eaf9fd27c33ba9bfdc39ac8d09aa12195ab529bcca79229d"},
    {file = "PyMuPDF-1.23.3-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:96dc89254d78bddac8434be7b9f4c354fe57b224b5420614cde9c2f1d2f1355e"},
    {file = "PyMuPDF-1.23.3-cp39-none-win32.whl", hash = "sha256:f9a1d2f7484bde2ec81f3c88641f7a8b7f52450b807408ae7a340ddecb424659"},
    {file = "PyMuPDF-1.23.3-cp39-none-win_amd64.whl", hash = "sha256:7cfceb91048665965d826023c4acfc45f61f5cfcf101391b3c1d22f85cef0470"},
    {file = "PyMuPDF-1.23.3.tar.gz", hash = "sha256:021478ae6c76e8859241dbb970612c9080a8957d8bd697bba0b4531dc1cf4f87"},
]

[package.dependencies]
PyMuPDFb = "1.23.3"

[[package]]
name = "pymupdfb"
version = "1.23.3"
description = "MuPDF shared libraries for PyMuPDF."
optional = false
python-versions = ">=3.8"
files = [
    {file = "PyMuPDFb-1.23.3-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:5b05c643210eae8050d552188efab2cd68595ad75b5879a550e11af88e8bff05"},
    {file = "PyMuPDFb-1.23.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2a2b81ac348ec123bfd72336a590399f8b0035a3052c1cf5cc2401ca7a4905e9"},
    {file = "PyMuPDFb-1.23.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:924f3f2229d232c965705d120b3ff38bbc37459af9d0e798b582950f875bee92"},
    {file = "PyMuPDFb-1.23.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6c287b9ce5ed397043c6e13df19640c94a348e9edc8012d9a7b001c69ba30ca9"},
    {file = "PyMuPDFb-1.23.3-py3-none-win32.whl", hash = "sha256:8703e3a8efebd83814e124d0fc3a082de2d2def329b63fca1065001e6a2deb49"},
    {file = "PyMuPDFb-1.23.3-py3-none-win_amd64.whl", hash = "sha256:89d88069cb8deb100ddcf56e1feefc7cff93ff791260325ed84551f96d3abd9f"},
]

[[package]]
name = "pypdf"
version = "3.15.4"

@@ -4466,4 +4518,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "761b58204631452d77e13bbc2d61034704e8e109619db4addd26ec159b9bb176"
content-hash = "bc306ab25967437b68ef5216af4b68bf6bfdf5cb966bb6493cc3ad91e8888110"

@@ -41,6 +41,7 @@ python-multipart = "^0.0.6"
deep-translator = "^1.11.4"
humanize = "^4.8.0"
deepeval = "^0.10.12"
pymupdf = "^1.23.3"

51  level_2/schema/semantic/semantic_schema.py  Normal file

@@ -0,0 +1,51 @@
from marshmallow import Schema, fields

class DocumentMetadataSchemaV1(Schema):
    user_id = fields.Str(required=True)
    memory_id = fields.Str(required=True)
    ltm_memory_id = fields.Str(required=True)
    st_memory_id = fields.Str(required=True)
    buffer_id = fields.Str(required=True)
    version = fields.Str(missing="")
    agreement_id = fields.Str(missing="")
    privacy_policy = fields.Str(missing="")
    terms_of_service = fields.Str(missing="")
    format = fields.Str(missing="")
    schema_version = fields.Str(missing="")
    checksum = fields.Str(missing="")
    owner = fields.Str(missing="")
    license = fields.Str(missing="")
    validity_start = fields.Str(missing="")
    validity_end = fields.Str(missing="")

class DocumentMetadataSchemaV2(Schema):
    user_id = fields.Str(required=True)
    memory_id = fields.Str(required=True)
    ltm_memory_id = fields.Str(required=True)
    st_memory_id = fields.Str(required=True)
    buffer_id = fields.Str(required=True)
    version = fields.Str(missing="")
    agreement_id = fields.Str(missing="")
    privacy_policy = fields.Str(missing="")
    terms_of_service = fields.Str(missing="")
    format = fields.Str(missing="")
    schema_version = fields.Str(missing="")
    checksum = fields.Str(missing="")
    owner = fields.Str(missing="")
    license = fields.Str(missing="")
    validity_start = fields.Str(missing="")
    validity_end = fields.Str(missing="")
    random = fields.Str(missing="")

class DocumentSchema(Schema):
    metadata = fields.Nested(DocumentMetadataSchemaV1, required=True)
    page_content = fields.Str(required=True)


SCHEMA_VERSIONS = {
    "1.0": DocumentMetadataSchemaV1,
    "2.0": DocumentMetadataSchemaV2
}

def get_schema_version(version):
    return SCHEMA_VERSIONS.get(version, DocumentMetadataSchemaV1)
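A small sketch of validating metadata against a versioned schema; the field values are placeholders, and optional fields default to "" via their missing="" declarations.

from level_2.schema.semantic.semantic_schema import get_schema_version

MetadataSchema = get_schema_version("2.0")  # unknown versions fall back to V1
loaded = MetadataSchema().load({
    "user_id": "u1",
    "memory_id": "m1",
    "ltm_memory_id": "ltm1",
    "st_memory_id": "st1",
    "buffer_id": "b1",
})
print(loaded)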
7  level_2/shared/chunk_strategy.py  Normal file

@@ -0,0 +1,7 @@
from enum import Enum

class ChunkStrategy(Enum):
    EXACT = 'exact'
    PARAGRAPH = 'paragraph'
    SENTENCE = 'sentence'
    VANILLA = 'vanilla'
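Because each strategy carries a string value, a caller holding a plain string can recover the member by value, as standard Python enums allow:

from level_2.shared.chunk_strategy import ChunkStrategy

strategy = ChunkStrategy('paragraph')  # lookup by value
assert strategy is ChunkStrategy.PARAGRAPH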

@@ -148,15 +148,15 @@ async def main():
    # tasks_list = """tasks": [{"task_order": "1", "task_name": "Fetch Information", "operation": "fetch from vector store", "original_query": "I want to know how does Buck adapt to life in the wild"]"""
    out_tasks = """here are the result_tasks [{'task_order': '1', 'task_name': 'Save Information', 'operation': 'save to vector store', 'original_query': 'Add to notes who is Buck and get info saved yesterday about him'}, {'docs': [{'semantic_search_term': "Add to notes who is Buck", 'document_summary': 'Buck was a dog stolen from his home', 'document_relevance': '0.75', 'attention_modulators_list': [{'frequency': '0.33', 'saliency': '0.75', 'relevance': '0.74'}]}], 'user_query': 'I want to know who buck is and check my notes from yesterday'}, {'task_order': '2', 'task_name': 'Check historical data', 'operation': 'check historical data', 'original_query': ' check my notes from yesterday'}, ' Data saved yesterday about Buck include information that he was stolen from home and that he was a pretty dog ']"""

    await _add_to_episodic(user_input=user_input, result_tasks=out_tasks, tasks_list=None, attention_modulators=modulator, params=params)
    # await _add_to_episodic(user_input=user_input, result_tasks=out_tasks, tasks_list=None, attention_modulators=modulator, params=params)
    # await delete_from_episodic()
    # aa = await get_from_episodic(observation="summary")
    aa = await get_from_episodic(observation="summary")
    # await delete_from_buffer()
    modulator_changed = {"relevance": 0.9, "saliency": 0.9, "frequency": 0.9}
    await add_to_buffer(adjusted_modulator=modulator_changed)
    # await add_to_buffer(adjusted_modulator=modulator_changed)

    # aa = await get_from_buffer(observation="summary")
    # print(aa)
    print(aa)

if __name__ == "__main__":
    import asyncio


@@ -103,11 +103,13 @@ class BaseMemory:
        loader_settings: dict = None,
        params: Optional[dict] = None,
        namespace: Optional[str] = None,
        custom_fields: Optional[str] = None,

    ):

        return await self.vector_db.add_memories(
            observation=observation, loader_settings=loader_settings,
            params=params, namespace=namespace
            params=params, namespace=namespace, custom_fields=custom_fields
        )
        # Add other db_type conditions if necessary


@@ -3,29 +3,28 @@
import logging
from io import BytesIO

import sys
import os

from marshmallow import Schema, fields
from level_2.loaders.loaders import _document_loader
# Add the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

logging.basicConfig(level=logging.INFO)
import marvin
import requests
from dotenv import load_dotenv
from langchain.document_loaders import PyPDFLoader
from langchain.retrievers import WeaviateHybridSearchRetriever
from weaviate.gql.get import HybridFusion

load_dotenv()
from typing import Optional

import tracemalloc

tracemalloc.start()

import os
from datetime import datetime
from langchain.embeddings.openai import OpenAIEmbeddings
from dotenv import load_dotenv
from level_2.schema.semantic.semantic_schema import DocumentSchema, SCHEMA_VERSIONS, DocumentMetadataSchemaV1
from langchain.schema import Document
import uuid
import weaviate

load_dotenv()

@@ -103,122 +102,114 @@ class WeaviateVectorDB(VectorDB):
        )
        return client

    def _document_loader(self, observation: str, loader_settings: dict):
        # Check the format of the document
        document_format = loader_settings.get("format", "text")

        if document_format == "PDF":
            if loader_settings.get("source") == "url":
                pdf_response = requests.get(loader_settings["path"])
                pdf_stream = BytesIO(pdf_response.content)
                contents = pdf_stream.read()
                tmp_location = os.path.join("/tmp", "tmp.pdf")
                with open(tmp_location, "wb") as tmp_file:
                    tmp_file.write(contents)

                # Process the PDF using PyPDFLoader
                loader = PyPDFLoader(tmp_location)
                # adapt this for different chunking strategies
                pages = loader.load_and_split()
                return pages
            elif loader_settings.get("source") == "file":
                # Process the PDF using PyPDFLoader
                # might need adapting for different loaders + OCR
                # need to test the path
                loader = PyPDFLoader(loader_settings["path"])
                pages = loader.load_and_split()
                return pages

        elif document_format == "text":
            # Process the text directly
            return observation

        else:
            raise ValueError(f"Unsupported document format: {document_format}")
    # def _document_loader(self, observation: str, loader_settings: dict):
    #     # Check the format of the document
    #     document_format = loader_settings.get("format", "text")
    #
    #     if document_format == "PDF":
    #         if loader_settings.get("source") == "url":
    #             pdf_response = requests.get(loader_settings["path"])
    #             pdf_stream = BytesIO(pdf_response.content)
    #             contents = pdf_stream.read()
    #             tmp_location = os.path.join("/tmp", "tmp.pdf")
    #             with open(tmp_location, "wb") as tmp_file:
    #                 tmp_file.write(contents)
    #
    #             # Process the PDF using PyPDFLoader
    #             loader = PyPDFLoader(tmp_location)
    #             # adapt this for different chunking strategies
    #             pages = loader.load_and_split()
    #             return pages
    #         elif loader_settings.get("source") == "file":
    #             # Process the PDF using PyPDFLoader
    #             # might need adapting for different loaders + OCR
    #             # need to test the path
    #             loader = PyPDFLoader(loader_settings["path"])
    #             pages = loader.load_and_split()
    #             return pages
    #
    #     elif document_format == "text":
    #         # Process the text directly
    #         return observation
    #
    #     else:
    #         raise ValueError(f"Unsupported document format: {document_format}")
    def _stuct(self, observation, params, custom_fields=None):
        """Utility function to create the document structure with optional custom fields."""
        # Dynamically construct metadata
        metadata = {
            key: str(getattr(self, key, params.get(key, "")))
            for key in [
                "user_id", "memory_id", "ltm_memory_id",
                "st_memory_id", "buffer_id", "version",
                "agreement_id", "privacy_policy", "terms_of_service",
                "format", "schema_version", "checksum",
                "owner", "license", "validity_start", "validity_end"
            ]
        }
        # Merge with custom fields if provided
        if custom_fields:
            metadata.update(custom_fields)

        # Construct document data
        document_data = {
            "metadata": metadata,
            "page_content": observation
        }

        def get_document_schema_based_on_version(version):
            metadata_schema_class = SCHEMA_VERSIONS.get(version, DocumentMetadataSchemaV1)
            class DynamicDocumentSchema(Schema):
                metadata = fields.Nested(metadata_schema_class, required=True)
                page_content = fields.Str(required=True)

            return DynamicDocumentSchema

        # Validate and deserialize
        schema_version = params.get("schema_version", "1.0")  # Default to "1.0" if not provided
        CurrentDocumentSchema = get_document_schema_based_on_version(schema_version)
        loaded_document = CurrentDocumentSchema().load(document_data)
        return [loaded_document]
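For reference, a standalone sketch of the dynamic-schema pattern _stuct relies on: the document schema is assembled on the fly around whichever metadata schema matches the requested version. The helper name here is illustrative, not part of the codebase.

from marshmallow import Schema, fields
from level_2.schema.semantic.semantic_schema import SCHEMA_VERSIONS, DocumentMetadataSchemaV1

def build_document_schema(version):
    # Pick the metadata schema for this version, defaulting to V1.
    metadata_schema = SCHEMA_VERSIONS.get(version, DocumentMetadataSchemaV1)

    class DynamicDocumentSchema(Schema):
        metadata = fields.Nested(metadata_schema, required=True)
        page_content = fields.Str(required=True)

    return DynamicDocumentSchema

document = build_document_schema("1.0")().load({
    "metadata": {"user_id": "u1", "memory_id": "m1", "ltm_memory_id": "l1",
                 "st_memory_id": "s1", "buffer_id": "b1"},
    "page_content": "sample text",
})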
    async def add_memories(
        self, observation: str, loader_settings: dict = None, params: dict = None, namespace: str = None
    ):
    async def add_memories(self, observation, loader_settings=None, params=None, namespace=None, custom_fields=None):
        # Update Weaviate memories here
        print(self.namespace)
        if namespace is None:
            namespace = self.namespace
        retriever = self.init_weaviate(namespace)

        def _stuct(observation, params):
            """Utility function to not repeat metadata structure"""
            # needs smarter solution, like dynamic generation of metadata
            return [
                Document(
                    metadata={
                        # "text": observation,
                        "user_id": str(self.user_id),
                        "memory_id": str(self.memory_id),
                        "ltm_memory_id": str(self.ltm_memory_id),
                        "st_memory_id": str(self.st_memory_id),
                        "buffer_id": str(self.buffer_id),
                        "version": params.get("version", None) or "",
                        "agreement_id": params.get("agreement_id", None) or "",
                        "privacy_policy": params.get("privacy_policy", None) or "",
                        "terms_of_service": params.get("terms_of_service", None) or "",
                        "format": params.get("format", None) or "",
                        "schema_version": params.get("schema_version", None) or "",
                        "checksum": params.get("checksum", None) or "",
                        "owner": params.get("owner", None) or "",
                        "license": params.get("license", None) or "",
                        "validity_start": params.get("validity_start", None) or "",
                        "validity_end": params.get("validity_end", None) or ""
                        # **source_metadata,
                    },
                    page_content=observation,
                )
            ]

        retriever = self.init_weaviate(namespace)  # Assuming `init_weaviate` is a method of the class
        if loader_settings:
            # Load the document
            document = self._document_loader(observation, loader_settings)
            print("DOC LENGTH", len(document))
            for doc in document:
                document_to_load = _stuct(doc.page_content, params)
                retriever.add_documents(
                    document_to_load
                )

        return retriever.add_documents(
            _stuct(observation, params)
        )
            # Assuming _document_loader returns a list of documents
            documents = _document_loader(observation, loader_settings)
            for doc in documents:
                document_to_load = self._stuct(doc.page_content, params, custom_fields)
                print("here is the doc to load1", document_to_load)
                retriever.add_documents([
                    Document(metadata=document_to_load[0]['metadata'], page_content=document_to_load[0]['page_content'])])
        else:
            document_to_load = self._stuct(observation, params, custom_fields)
            retriever.add_documents([
                Document(metadata=document_to_load[0]['metadata'], page_content=document_to_load[0]['page_content'])])
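A hedged usage sketch for the updated add_memories signature; db stands in for an already-initialized WeaviateVectorDB instance, whose construction is not shown in this diff.

# Inside an async context, with `db` an initialized WeaviateVectorDB:
await db.add_memories(
    observation="Buck adapts quickly to life in the wild.",
    params={"schema_version": "2.0"},
    namespace="SEMANTICMEMORY",
    custom_fields={"random": "extra value"},  # merged into metadata by _stuct
)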
    async def fetch_memories(
        self, observation: str, namespace: str, params: dict = None, n_of_observations =int(2)
        self, observation: str, namespace: str, params: dict = None, n_of_observations: int = 2
    ):
        """
        Get documents from weaviate.
        Fetch documents from weaviate.

        Parameters:
        - observation (str): User query.
        - namespace (str): Type of memory we access.
        - params (dict, optional):
        - n_of_observations (int, optional): For weaviate, equals to autocut, defaults to 1. Ranges from 1 to 3. Check weaviate docs for more info.
        Parameters:
        - observation (str): User query.
        - namespace (str): Type of memory accessed.
        - params (dict, optional): Filtering parameters.
        - n_of_observations (int, optional): For weaviate, equals to autocut. Defaults to 2. Ranges from 1 to 3.

        Returns:
        Describe the return type and what the function returns.

        Args a json containing:
            query (str): The query string.
            path (list): The path for filtering, e.g., ['year'].
            operator (str): The operator for filtering, e.g., 'Equal'.
            valueText (str): The value for filtering, e.g., '2017*'.
        Returns:
            List of documents matching the query.

        Example:
            get_from_weaviate(query="some query", path=['year'], operator='Equal', valueText='2017*')

            fetch_memories(query="some query", path=['year'], operator='Equal', valueText='2017*')
        """
        client = self.init_weaviate_client(self.namespace)

        print(self.namespace)
        print(str(datetime.now()))
        print(observation)
        if namespace is None:
        if not namespace:
            namespace = self.namespace

        params_user_id = {

@@ -227,78 +218,39 @@ class WeaviateVectorDB(VectorDB):
            "valueText": self.user_id,
        }

        def list_objects_of_class(class_name, schema):
            return [
                prop["name"]
                for class_obj in schema["classes"]
                if class_obj["class"] == class_name
                for prop in class_obj["properties"]
            ]

        base_query = client.query.get(
            namespace, list(list_objects_of_class(namespace, client.schema.get()))
        ).with_additional(
            ["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
        ).with_where(params_user_id).with_limit(10)

        if params:
            query_output = (
                client.query.get(
                    namespace,
                    [
                        # "text",
                        "user_id",
                        "memory_id",
                        "ltm_memory_id",
                        "st_memory_id",
                        "buffer_id",
                        "version",
                        "agreement_id",
                        "privacy_policy",
                        "terms_of_service",
                        "format",
                        "schema_version",
                        "checksum",
                        "owner",
                        "license",
                        "validity_start",
                        "validity_end",
                    ],
                )
                base_query
                .with_where(params)
                .with_near_text({"concepts": [observation]})
                .with_additional(
                    ["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
                )
                .with_where(params_user_id)
                .with_limit(10)
                .do()
            )
            return query_output
        else:
            query_output = (
                client.query.get(
                    namespace,

                    [
                        "text",
                        "user_id",
                        "memory_id",
                        "ltm_memory_id",
                        "st_memory_id",
                        "buffer_id",
                        "version",
                        "agreement_id",
                        "privacy_policy",
                        "terms_of_service",
                        "format",
                        "schema_version",
                        "checksum",
                        "owner",
                        "license",
                        "validity_start",
                        "validity_end",
                    ],
                )
                .with_additional(
                    ["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
                )
                base_query
                .with_hybrid(
                    query=observation,
                    fusion_type=HybridFusion.RELATIVE_SCORE
                )
                .with_autocut(n_of_observations)
                .with_where(params_user_id)
                .with_limit(10)
                .do()
            )
            return query_output

        return query_output
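And a matching sketch for fetch_memories, under the same assumption about db:

# Inside an async context, with `db` an initialized WeaviateVectorDB:
results = await db.fetch_memories(
    observation="how does Buck adapt to life in the wild",
    namespace="EPISODICMEMORY",
    params=None,  # no extra filter: takes the hybrid-search + autocut path
    n_of_observations=2,
)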

    async def delete_memories(self, params: dict = None):
        client = self.init_weaviate_client(self.namespace)