1. wrong uuid for st memory 2. weaviate checker logic that was not needed 3. decomposed vector db and factory classes into separate files
355 lines
13 KiB
Python
355 lines
13 KiB
Python
|
|
# Make sure to install the following packages: dlt, langchain, duckdb, python-dotenv, openai, weaviate-client
|
|
import logging
|
|
from io import BytesIO
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
import marvin
|
|
import requests
|
|
from dotenv import load_dotenv
|
|
from langchain.document_loaders import PyPDFLoader
|
|
from langchain.retrievers import WeaviateHybridSearchRetriever
|
|
from weaviate.gql.get import HybridFusion
|
|
|
|
load_dotenv()
|
|
from typing import Optional
|
|
|
|
import tracemalloc
|
|
|
|
tracemalloc.start()
|
|
|
|
import os
|
|
from datetime import datetime
|
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
|
from dotenv import load_dotenv
|
|
from langchain.schema import Document
|
|
import uuid
|
|
import weaviate
|
|
|
|
load_dotenv()
|
|
|
|
|
|
# Fallback identifiers used when a caller does not pass explicit memory ids.
# NOTE(review): "ltm"/"st" presumably mean long-term / short-term memory - confirm.
LTM_MEMORY_ID_DEFAULT = "00000"
ST_MEMORY_ID_DEFAULT = "0000"
BUFFER_ID_DEFAULT = "0000"


class VectorDB:
    """Base class holding the identifiers shared by the vector-store backends.

    The constructor only records its arguments on the instance; connecting to
    a concrete store is left to subclasses.
    """

    # Read once, when the class body is executed; empty string when unset.
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")

    def __init__(
        self,
        user_id: str,
        index_name: str,
        memory_id: str,
        ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT,
        st_memory_id: str = ST_MEMORY_ID_DEFAULT,
        buffer_id: str = BUFFER_ID_DEFAULT,
        namespace: Optional[str] = None,
    ):
        """Store the identifiers on the instance.

        Args:
            user_id: Identifier of the user owning the stored data.
            index_name: Name of the index to use.
            memory_id: Identifier of the memory these records belong to.
            ltm_memory_id: Defaults to ``LTM_MEMORY_ID_DEFAULT``.
            st_memory_id: Defaults to ``ST_MEMORY_ID_DEFAULT``.
            buffer_id: Defaults to ``BUFFER_ID_DEFAULT``.
            namespace: Optional namespace; ``None`` when not supplied.
        """
        self.user_id = user_id
        self.index_name = index_name
        self.namespace = namespace
        self.memory_id = memory_id
        self.ltm_memory_id = ltm_memory_id
        self.st_memory_id = st_memory_id
        self.buffer_id = buffer_id
|
|
|
|
class PineconeVectorDB(VectorDB):
    """Pinecone flavour of ``VectorDB``; the backend set-up is still a stub."""

    def __init__(self, *args, **kwargs):
        """Forward every argument to ``VectorDB`` and run the Pinecone set-up hook."""
        super().__init__(*args, **kwargs)
        index = self.index_name
        self.init_pinecone(index)

    def init_pinecone(self, index_name):
        """Placeholder for the Pinecone initialization; currently does nothing."""
        # Pinecone initialization logic
        return None
|
|
|
|
|
|
class WeaviateVectorDB(VectorDB):
|
|
def __init__(self, *args, **kwargs):
    """Initialise the shared identifiers, then run the Weaviate set-up once.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor.
    """
    super().__init__(*args, **kwargs)
    # The retriever returned by init_weaviate is not kept on the instance.
    namespace = self.namespace
    self.init_weaviate(namespace)
|
|
|
|
def init_weaviate(self, namespace: str):
    """Build a hybrid-search retriever bound to the Weaviate class ``namespace``.

    Args:
        namespace: Name passed to the retriever as its ``index_name``.

    Returns:
        A ``WeaviateHybridSearchRetriever`` that uses OpenAI embeddings,
        keeps document text under the ``"text"`` key and is asked to create
        the schema when it is missing.
    """
    # Embeddings are still created before the client so that the order in
    # which configuration problems surface is unchanged.
    embeddings = OpenAIEmbeddings()
    # Reuse the shared client factory rather than repeating the auth / URL /
    # header wiring in two places.
    client = self.init_weaviate_client(namespace)
    return WeaviateHybridSearchRetriever(
        client=client,
        index_name=namespace,
        text_key="text",
        attributes=[],
        embedding=embeddings,
        create_schema_if_missing=True,
    )
|
|
|
|
def init_weaviate_client(self, namespace: str):
    """Create a Weaviate client configured from environment variables.

    Reads ``WEAVIATE_URL``, ``WEAVIATE_API_KEY`` and ``OPENAI_API_KEY`` from
    the process environment. ``namespace`` is accepted but not used here.

    Returns:
        The ``weaviate.Client`` instance.
    """
    env = os.environ
    api_key_auth = weaviate.auth.AuthApiKey(api_key=env.get("WEAVIATE_API_KEY"))
    # Forwarded as a request header alongside the Weaviate credentials.
    openai_header = {"X-OpenAI-Api-Key": env.get("OPENAI_API_KEY")}
    return weaviate.Client(
        url=env.get("WEAVIATE_URL"),
        auth_client_secret=api_key_auth,
        additional_headers=openai_header,
    )
|
|
|
|
def _document_loader(self, observation: str, loader_settings: dict):
|
|
# Create an in-memory file-like object for the PDF content
|
|
|
|
if loader_settings.get("format") == "PDF":
|
|
|
|
if loader_settings.get("source") == "url":
|
|
pdf_response = requests.get(loader_settings["path"])
|
|
pdf_stream = BytesIO(pdf_response.content)
|
|
contents = pdf_stream.read()
|
|
tmp_location = os.path.join("/tmp", "tmp.pdf")
|
|
with open(tmp_location, "wb") as tmp_file:
|
|
tmp_file.write(contents)
|
|
|
|
# Process the PDF using PyPDFLoader
|
|
loader = PyPDFLoader(tmp_location)
|
|
# adapt this for different chunking strategies
|
|
pages = loader.load_and_split()
|
|
return pages
|
|
|
|
if loader_settings.get("source") == "file":
|
|
# Process the PDF using PyPDFLoader
|
|
# might need adapting for different loaders + OCR
|
|
# need to test the path
|
|
loader = PyPDFLoader(loader_settings["path"])
|
|
pages = loader.load_and_split()
|
|
|
|
return pages
|
|
else:
|
|
# Process the text by just loading the base text
|
|
return observation
|
|
|
|
|
|
async def add_memories(
    self,
    observation: str,
    loader_settings: Optional[dict] = None,
    params: Optional[dict] = None,
    namespace: Optional[str] = None,
):
    """Store an observation, or a loaded document, as Weaviate documents.

    Args:
        observation: Text to store. When ``loader_settings`` is given it is
            handed to ``_document_loader`` instead of being stored directly.
        loader_settings: Optional loader configuration; when truthy, the
            loaded content is stored (one document per loaded page).
        params: Optional metadata values (``version``, ``agreement_id``,
            ``privacy_policy``, ...); missing or falsy values are stored as
            ``""``. ``None`` is treated as an empty mapping.
        namespace: Target namespace; defaults to ``self.namespace``.

    Returns:
        Whatever ``retriever.add_documents`` returns for the stored
        documents, or an empty list when the loader produced nothing.
    """
    # Tolerate the documented default instead of failing on ``None.get``.
    params = params or {}
    if namespace is None:
        namespace = self.namespace
    logging.info("Adding memories to namespace %s", namespace)
    retriever = self.init_weaviate(namespace)

    identity = {
        "user_id": str(self.user_id),
        "memory_id": str(self.memory_id),
        "ltm_memory_id": str(self.ltm_memory_id),
        "st_memory_id": str(self.st_memory_id),
        "buffer_id": str(self.buffer_id),
    }
    optional_keys = (
        "version",
        "agreement_id",
        "privacy_policy",
        "terms_of_service",
        "format",
        "schema_version",
        "checksum",
        "owner",
        "license",
        "validity_start",
        "validity_end",
    )

    def _build_document(text):
        """Wrap ``text`` in a Document carrying the shared metadata."""
        # needs smarter solution, like dynamic generation of metadata
        metadata = dict(identity)
        metadata.update((key, params.get(key) or "") for key in optional_keys)
        return Document(metadata=metadata, page_content=text)

    if loader_settings:
        loaded = self._document_loader(observation, loader_settings)
        # The loader hands back a plain string for non-PDF formats; iterating
        # it would walk single characters, so treat it as one chunk.
        if isinstance(loaded, str):
            texts = [loaded]
        else:
            texts = [page.page_content for page in loaded]
        logging.info("DOC LENGTH %s", len(texts))
    else:
        # Only the raw observation is stored when no loader is configured;
        # it is no longer stored a second time after loaded chunks.
        texts = [observation]

    if not texts:
        return []
    # One batched call instead of one call per chunk.
    return retriever.add_documents([_build_document(text) for text in texts])
|
|
|
|
async def fetch_memories(
    self,
    observation: str,
    namespace: str,
    params: Optional[dict] = None,
    n_of_observations: int = 2,
):
    """Query Weaviate for stored memories of the current user.

    Args:
        observation: User query.
        namespace: Weaviate class to query; ``None`` falls back to
            ``self.namespace``.
        params: Optional Weaviate ``where`` filter, e.g.
            ``{"path": ["year"], "operator": "Equal", "valueText": "2017*"}``.
            When given, a ``nearText`` search is run and this filter is
            combined with the user-id filter; otherwise a hybrid search
            is run.
        n_of_observations: Passed to ``with_autocut`` in the hybrid search;
            defaults to 2. Not used when ``params`` is given.

    Returns:
        The value returned by the query builder's ``.do()`` call. Both
        searches are limited to 10 results.
    """
    client = self.init_weaviate_client(self.namespace)
    if namespace is None:
        namespace = self.namespace
    logging.info("Fetching memories from namespace %s", namespace)
    logging.debug("Query: %s", observation)

    metadata_fields = [
        "user_id",
        "memory_id",
        "ltm_memory_id",
        "st_memory_id",
        "buffer_id",
        "version",
        "agreement_id",
        "privacy_policy",
        "terms_of_service",
        "format",
        "schema_version",
        "checksum",
        "owner",
        "license",
        "validity_start",
        "validity_end",
    ]
    additional_fields = [
        "id",
        "creationTimeUnix",
        "lastUpdateTimeUnix",
        "score",
        "distance",
    ]

    # NOTE(review): "Like" treats wildcard characters specially, whereas
    # delete_memories matches user_id with "Equal" - confirm which is intended.
    params_user_id = {
        "path": ["user_id"],
        "operator": "Like",
        "valueText": self.user_id,
    }

    if params:
        # A second with_where() call replaces the first one, so the caller's
        # filter and the user filter are merged into a single "And" filter.
        combined_filter = {
            "operator": "And",
            "operands": [params, params_user_id],
        }
        # NOTE(review): "text" is not requested in this branch, as in the
        # original code - confirm that is intentional.
        return (
            client.query.get(namespace, metadata_fields)
            .with_where(combined_filter)
            .with_near_text({"concepts": [observation]})
            .with_additional(additional_fields)
            .with_limit(10)
            .do()
        )

    return (
        client.query.get(namespace, ["text", *metadata_fields])
        .with_additional(additional_fields)
        .with_hybrid(
            query=observation,
            fusion_type=HybridFusion.RELATIVE_SCORE,
        )
        .with_autocut(n_of_observations)
        .with_where(params_user_id)
        .with_limit(10)
        .do()
    )
|
|
|
|
async def delete_memories(self, params: dict = None):
|
|
client = self.init_weaviate_client(self.namespace)
|
|
if params:
|
|
where_filter = {
|
|
"path": ["id"],
|
|
"operator": "Equal",
|
|
"valueText": params.get("id", None),
|
|
}
|
|
return client.batch.delete_objects(
|
|
class_name=self.namespace,
|
|
# Same `where` filter as in the GraphQL API
|
|
where=where_filter,
|
|
)
|
|
else:
|
|
# Delete all objects
|
|
print("HERE IS THE USER ID", self.user_id)
|
|
return client.batch.delete_objects(
|
|
class_name=self.namespace,
|
|
where={
|
|
"path": ["user_id"],
|
|
"operator": "Equal",
|
|
"valueText": self.user_id,
|
|
},
|
|
)
|
|
|
|
def update_memories(self, observation, namespace: str, params: dict = None):
|
|
client = self.init_weaviate_client(self.namespace)
|
|
|
|
client.data_object.update(
|
|
data_object={
|
|
# "text": observation,
|
|
"user_id": str(self.user_id),
|
|
"memory_id": str(self.memory_id),
|
|
"ltm_memory_id": str(self.ltm_memory_id),
|
|
"st_memory_id": str(self.st_memory_id),
|
|
"buffer_id": str(self.buffer_id),
|
|
"version": params.get("version", None) or "",
|
|
"agreement_id": params.get("agreement_id", None) or "",
|
|
"privacy_policy": params.get("privacy_policy", None) or "",
|
|
"terms_of_service": params.get("terms_of_service", None) or "",
|
|
"format": params.get("format", None) or "",
|
|
"schema_version": params.get("schema_version", None) or "",
|
|
"checksum": params.get("checksum", None) or "",
|
|
"owner": params.get("owner", None) or "",
|
|
"license": params.get("license", None) or "",
|
|
"validity_start": params.get("validity_start", None) or "",
|
|
"validity_end": params.get("validity_end", None) or ""
|
|
# **source_metadata,
|
|
},
|
|
class_name="Test",
|
|
uuid=params.get("id", None),
|
|
consistency_level=weaviate.data.replication.ConsistencyLevel.ALL, # default QUORUM
|
|
)
|
|
return
|