From 6e01e9af79f100330ea4cc702c6bf6f0dab4e002 Mon Sep 17 00:00:00 2001 From: Vasilije <8619304+Vasilije1990@users.noreply.github.com> Date: Tue, 12 Sep 2023 16:09:18 +0200 Subject: [PATCH] Split the files and fixed issues with: 1. wrong uuid for st memmory 2. weaviate checker logic that was not needed 3. decomposed vector db and factory classes into separate files --- .../level_2_pdf_vectorstore__dlt_contracts.py | 6 +- level_2/vectordb/basevectordb.py | 348 +---------------- level_2/vectordb/vectordb.py | 355 ++++++++++++++++++ 3 files changed, 375 insertions(+), 334 deletions(-) create mode 100644 level_2/vectordb/vectordb.py diff --git a/level_2/level_2_pdf_vectorstore__dlt_contracts.py b/level_2/level_2_pdf_vectorstore__dlt_contracts.py index a5a3dd52d..9799d651f 100644 --- a/level_2/level_2_pdf_vectorstore__dlt_contracts.py +++ b/level_2/level_2_pdf_vectorstore__dlt_contracts.py @@ -74,7 +74,9 @@ marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY") # Assuming OpenAIEmbeddings and other necessary imports are available -from vectordb.basevectordb import BaseMemory, PineconeVectorDB, WeaviateVectorDB +from vectordb.basevectordb import BaseMemory + + from modulators.modulators import DifferentiableLayer @@ -115,7 +117,7 @@ class EpisodicBuffer(BaseMemory): user_id, memory_id, index_name, db_type, namespace="BUFFERMEMORY" ) - self.st_memory_id = "blah" + self.st_memory_id = str( uuid.uuid4()) self.llm = ChatOpenAI( temperature=0.0, max_tokens=1200, diff --git a/level_2/vectordb/basevectordb.py b/level_2/vectordb/basevectordb.py index 664526281..3093286f7 100644 --- a/level_2/vectordb/basevectordb.py +++ b/level_2/vectordb/basevectordb.py @@ -2,6 +2,8 @@ import logging from io import BytesIO +from level_2.vectordb.vectordb import PineconeVectorDB, WeaviateVectorDB + logging.basicConfig(level=logging.INFO) import marvin import requests @@ -63,325 +65,8 @@ class VectorDBFactory: raise ValueError(f"Unsupported database type: {db_type}") -class VectorDB: - OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") - - def __init__( - self, - user_id: str, - index_name: str, - memory_id: str, - ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT, - st_memory_id: str = ST_MEMORY_ID_DEFAULT, - buffer_id: str = BUFFER_ID_DEFAULT, - namespace: str = None, - ): - self.user_id = user_id - self.index_name = index_name - self.namespace = namespace - self.memory_id = memory_id - self.ltm_memory_id = ltm_memory_id - self.st_memory_id = st_memory_id - self.buffer_id = buffer_id - -class PineconeVectorDB(VectorDB): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.init_pinecone(self.index_name) - - def init_pinecone(self, index_name): - # Pinecone initialization logic - pass -class WeaviateVectorDB(VectorDB): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.init_weaviate(self.namespace) - - def init_weaviate(self, namespace: str): - # Weaviate initialization logic - embeddings = OpenAIEmbeddings() - auth_config = weaviate.auth.AuthApiKey( - api_key=os.environ.get("WEAVIATE_API_KEY") - ) - client = weaviate.Client( - url=os.environ.get("WEAVIATE_URL"), - auth_client_secret=auth_config, - additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")}, - ) - retriever = WeaviateHybridSearchRetriever( - client=client, - index_name=namespace, - text_key="text", - attributes=[], - embedding=embeddings, - create_schema_if_missing=True, - ) - return retriever # If this is part of the initialization, call it here. - - def init_weaviate_client(self, namespace: str): - # Weaviate client initialization logic - auth_config = weaviate.auth.AuthApiKey( - api_key=os.environ.get("WEAVIATE_API_KEY") - ) - client = weaviate.Client( - url=os.environ.get("WEAVIATE_URL"), - auth_client_secret=auth_config, - additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")}, - ) - return client - - def _document_loader(self, observation: str, loader_settings: dict): - # Create an in-memory file-like object for the PDF content - - if loader_settings.get("format") == "PDF": - - if loader_settings.get("source") == "url": - pdf_response = requests.get(loader_settings["path"]) - pdf_stream = BytesIO(pdf_response.content) - contents = pdf_stream.read() - tmp_location = os.path.join("/tmp", "tmp.pdf") - with open(tmp_location, "wb") as tmp_file: - tmp_file.write(contents) - - # Process the PDF using PyPDFLoader - loader = PyPDFLoader(tmp_location) - # adapt this for different chunking strategies - pages = loader.load_and_split() - return pages - - if loader_settings.get("source") == "file": - # Process the PDF using PyPDFLoader - # might need adapting for different loaders + OCR - # need to test the path - loader = PyPDFLoader(loader_settings["path"]) - pages = loader.load_and_split() - - return pages - else: - # Process the text by just loading the base text - return observation - - - async def add_memories( - self, observation: str, loader_settings: dict = None, params: dict = None ,namespace:str=None - ): - # Update Weaviate memories here - print(self.namespace) - if namespace is None: - namespace = self.namespace - retriever = self.init_weaviate(namespace) - - def _stuct(observation, params): - """Utility function to not repeat metadata structure""" - # needs smarter solution, like dynamic generation of metadata - return [ - Document( - metadata={ - # "text": observation, - "user_id": str(self.user_id), - "memory_id": str(self.memory_id), - "ltm_memory_id": str(self.ltm_memory_id), - "st_memory_id": str(self.st_memory_id), - "buffer_id": str(self.buffer_id), - "version": params.get("version", None) or "", - "agreement_id": params.get("agreement_id", None) or "", - "privacy_policy": params.get("privacy_policy", None) or "", - "terms_of_service": params.get("terms_of_service", None) or "", - "format": params.get("format", None) or "", - "schema_version": params.get("schema_version", None) or "", - "checksum": params.get("checksum", None) or "", - "owner": params.get("owner", None) or "", - "license": params.get("license", None) or "", - "validity_start": params.get("validity_start", None) or "", - "validity_end": params.get("validity_end", None) or "" - # **source_metadata, - }, - page_content=observation, - ) - ] - - if loader_settings: - # Load the document - document = self._document_loader(observation, loader_settings) - print("DOC LENGTH", len(document)) - for doc in document: - document_to_load = _stuct(doc.page_content, params) - retriever.add_documents( - document_to_load - ) - - return retriever.add_documents( - _stuct(observation, params) - ) - - async def fetch_memories( - self, observation: str, namespace: str, params: dict = None, n_of_observations =int(2) - ): - """ - Get documents from weaviate. - - Parameters: - - observation (str): User query. - - namespace (str): Type of memory we access. - - params (dict, optional): - - n_of_observations (int, optional): For weaviate, equals to autocut, defaults to 1. Ranges from 1 to 3. Check weaviate docs for more info. - - Returns: - Describe the return type and what the function returns. - - Args a json containing: - query (str): The query string. - path (list): The path for filtering, e.g., ['year']. - operator (str): The operator for filtering, e.g., 'Equal'. - valueText (str): The value for filtering, e.g., '2017*'. - - Example: - get_from_weaviate(query="some query", path=['year'], operator='Equal', valueText='2017*') - - """ - client = self.init_weaviate_client(self.namespace) - - print(self.namespace) - print(str(datetime.now())) - print(observation) - if namespace is None: - namespace = self.namespace - - params_user_id = { - "path": ["user_id"], - "operator": "Like", - "valueText": self.user_id, - } - - if params: - query_output = ( - client.query.get( - namespace, - [ - # "text", - "user_id", - "memory_id", - "ltm_memory_id", - "st_memory_id", - "buffer_id", - "version", - "agreement_id", - "privacy_policy", - "terms_of_service", - "format", - "schema_version", - "checksum", - "owner", - "license", - "validity_start", - "validity_end", - ], - ) - .with_where(params) - .with_near_text({"concepts": [observation]}) - .with_additional( - ["id", "creationTimeUnix", "lastUpdateTimeUnix", "score",'distance'] - ) - .with_where(params_user_id) - .with_limit(10) - .do() - ) - return query_output - else: - query_output = ( - client.query.get( - namespace, - - [ - "text", - "user_id", - "memory_id", - "ltm_memory_id", - "st_memory_id", - "buffer_id", - "version", - "agreement_id", - "privacy_policy", - "terms_of_service", - "format", - "schema_version", - "checksum", - "owner", - "license", - "validity_start", - "validity_end", - ], - ) - .with_additional( - ["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance'] - ) - .with_hybrid( - query=observation, - fusion_type=HybridFusion.RELATIVE_SCORE - ) - .with_autocut(n_of_observations) - .with_where(params_user_id) - .with_limit(10) - .do() - ) - return query_output - - async def delete_memories(self, params: dict = None): - client = self.init_weaviate_client(self.namespace) - if params: - where_filter = { - "path": ["id"], - "operator": "Equal", - "valueText": params.get("id", None), - } - return client.batch.delete_objects( - class_name=self.namespace, - # Same `where` filter as in the GraphQL API - where=where_filter, - ) - else: - # Delete all objects - print("HERE IS THE USER ID", self.user_id) - return client.batch.delete_objects( - class_name=self.namespace, - where={ - "path": ["user_id"], - "operator": "Equal", - "valueText": self.user_id, - }, - ) - - def update_memories(self, observation, namespace: str, params: dict = None): - client = self.init_weaviate_client(self.namespace) - - client.data_object.update( - data_object={ - # "text": observation, - "user_id": str(self.user_id), - "memory_id": str(self.memory_id), - "ltm_memory_id": str(self.ltm_memory_id), - "st_memory_id": str(self.st_memory_id), - "buffer_id": str(self.buffer_id), - "version": params.get("version", None) or "", - "agreement_id": params.get("agreement_id", None) or "", - "privacy_policy": params.get("privacy_policy", None) or "", - "terms_of_service": params.get("terms_of_service", None) or "", - "format": params.get("format", None) or "", - "schema_version": params.get("schema_version", None) or "", - "checksum": params.get("checksum", None) or "", - "owner": params.get("owner", None) or "", - "license": params.get("license", None) or "", - "validity_start": params.get("validity_start", None) or "", - "validity_end": params.get("validity_end", None) or "" - # **source_metadata, - }, - class_name="Test", - uuid=params.get("id", None), - consistency_level=weaviate.data.replication.ConsistencyLevel.ALL, # default QUORUM - ) - return class BaseMemory: @@ -409,8 +94,8 @@ class BaseMemory: ) def init_client(self, namespace: str): - if self.db_type == "weaviate": - return self.vector_db.init_weaviate_client(namespace) + + return self.vector_db.init_weaviate_client(namespace) async def add_memories( self, @@ -419,11 +104,11 @@ class BaseMemory: params: Optional[dict] = None, namespace: Optional[str] = None, ): - if self.db_type == "weaviate": - return await self.vector_db.add_memories( - observation=observation, loader_settings=loader_settings, - params=params, namespace=namespace - ) + + return await self.vector_db.add_memories( + observation=observation, loader_settings=loader_settings, + params=params, namespace=namespace + ) # Add other db_type conditions if necessary async def fetch_memories( @@ -433,15 +118,14 @@ class BaseMemory: namespace: Optional[str] = None, n_of_observations: Optional[int] = 2, ): - if self.db_type == "weaviate": - return await self.vector_db.fetch_memories( - observation=observation, params=params, - namespace=namespace, - n_of_observations=n_of_observations - ) + + return await self.vector_db.fetch_memories( + observation=observation, params=params, + namespace=namespace, + n_of_observations=n_of_observations + ) async def delete_memories(self, params: Optional[str] = None): - if self.db_type == "weaviate": - return await self.vector_db.delete_memories(params) + return await self.vector_db.delete_memories(params) # Additional methods for specific Memory can be added here diff --git a/level_2/vectordb/vectordb.py b/level_2/vectordb/vectordb.py new file mode 100644 index 000000000..4fa4f0308 --- /dev/null +++ b/level_2/vectordb/vectordb.py @@ -0,0 +1,355 @@ + +# Make sure to install the following packages: dlt, langchain, duckdb, python-dotenv, openai, weaviate-client +import logging +from io import BytesIO + + + +logging.basicConfig(level=logging.INFO) +import marvin +import requests +from dotenv import load_dotenv +from langchain.document_loaders import PyPDFLoader +from langchain.retrievers import WeaviateHybridSearchRetriever +from weaviate.gql.get import HybridFusion + +load_dotenv() +from typing import Optional + +import tracemalloc + +tracemalloc.start() + +import os +from datetime import datetime +from langchain.embeddings.openai import OpenAIEmbeddings +from dotenv import load_dotenv +from langchain.schema import Document +import uuid +import weaviate + +load_dotenv() + + +LTM_MEMORY_ID_DEFAULT = "00000" +ST_MEMORY_ID_DEFAULT = "0000" +BUFFER_ID_DEFAULT = "0000" +class VectorDB: + OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") + + def __init__( + self, + user_id: str, + index_name: str, + memory_id: str, + ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT, + st_memory_id: str = ST_MEMORY_ID_DEFAULT, + buffer_id: str = BUFFER_ID_DEFAULT, + namespace: str = None, + ): + self.user_id = user_id + self.index_name = index_name + self.namespace = namespace + self.memory_id = memory_id + self.ltm_memory_id = ltm_memory_id + self.st_memory_id = st_memory_id + self.buffer_id = buffer_id + +class PineconeVectorDB(VectorDB): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_pinecone(self.index_name) + + def init_pinecone(self, index_name): + # Pinecone initialization logic + pass + + +class WeaviateVectorDB(VectorDB): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_weaviate(self.namespace) + + def init_weaviate(self, namespace: str): + # Weaviate initialization logic + embeddings = OpenAIEmbeddings() + auth_config = weaviate.auth.AuthApiKey( + api_key=os.environ.get("WEAVIATE_API_KEY") + ) + client = weaviate.Client( + url=os.environ.get("WEAVIATE_URL"), + auth_client_secret=auth_config, + additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")}, + ) + retriever = WeaviateHybridSearchRetriever( + client=client, + index_name=namespace, + text_key="text", + attributes=[], + embedding=embeddings, + create_schema_if_missing=True, + ) + return retriever # If this is part of the initialization, call it here. + + def init_weaviate_client(self, namespace: str): + # Weaviate client initialization logic + auth_config = weaviate.auth.AuthApiKey( + api_key=os.environ.get("WEAVIATE_API_KEY") + ) + client = weaviate.Client( + url=os.environ.get("WEAVIATE_URL"), + auth_client_secret=auth_config, + additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")}, + ) + return client + + def _document_loader(self, observation: str, loader_settings: dict): + # Create an in-memory file-like object for the PDF content + + if loader_settings.get("format") == "PDF": + + if loader_settings.get("source") == "url": + pdf_response = requests.get(loader_settings["path"]) + pdf_stream = BytesIO(pdf_response.content) + contents = pdf_stream.read() + tmp_location = os.path.join("/tmp", "tmp.pdf") + with open(tmp_location, "wb") as tmp_file: + tmp_file.write(contents) + + # Process the PDF using PyPDFLoader + loader = PyPDFLoader(tmp_location) + # adapt this for different chunking strategies + pages = loader.load_and_split() + return pages + + if loader_settings.get("source") == "file": + # Process the PDF using PyPDFLoader + # might need adapting for different loaders + OCR + # need to test the path + loader = PyPDFLoader(loader_settings["path"]) + pages = loader.load_and_split() + + return pages + else: + # Process the text by just loading the base text + return observation + + + async def add_memories( + self, observation: str, loader_settings: dict = None, params: dict = None ,namespace:str=None + ): + # Update Weaviate memories here + print(self.namespace) + if namespace is None: + namespace = self.namespace + retriever = self.init_weaviate(namespace) + + def _stuct(observation, params): + """Utility function to not repeat metadata structure""" + # needs smarter solution, like dynamic generation of metadata + return [ + Document( + metadata={ + # "text": observation, + "user_id": str(self.user_id), + "memory_id": str(self.memory_id), + "ltm_memory_id": str(self.ltm_memory_id), + "st_memory_id": str(self.st_memory_id), + "buffer_id": str(self.buffer_id), + "version": params.get("version", None) or "", + "agreement_id": params.get("agreement_id", None) or "", + "privacy_policy": params.get("privacy_policy", None) or "", + "terms_of_service": params.get("terms_of_service", None) or "", + "format": params.get("format", None) or "", + "schema_version": params.get("schema_version", None) or "", + "checksum": params.get("checksum", None) or "", + "owner": params.get("owner", None) or "", + "license": params.get("license", None) or "", + "validity_start": params.get("validity_start", None) or "", + "validity_end": params.get("validity_end", None) or "" + # **source_metadata, + }, + page_content=observation, + ) + ] + + if loader_settings: + # Load the document + document = self._document_loader(observation, loader_settings) + print("DOC LENGTH", len(document)) + for doc in document: + document_to_load = _stuct(doc.page_content, params) + retriever.add_documents( + document_to_load + ) + + return retriever.add_documents( + _stuct(observation, params) + ) + + async def fetch_memories( + self, observation: str, namespace: str, params: dict = None, n_of_observations =int(2) + ): + """ + Get documents from weaviate. + + Parameters: + - observation (str): User query. + - namespace (str): Type of memory we access. + - params (dict, optional): + - n_of_observations (int, optional): For weaviate, equals to autocut, defaults to 1. Ranges from 1 to 3. Check weaviate docs for more info. + + Returns: + Describe the return type and what the function returns. + + Args a json containing: + query (str): The query string. + path (list): The path for filtering, e.g., ['year']. + operator (str): The operator for filtering, e.g., 'Equal'. + valueText (str): The value for filtering, e.g., '2017*'. + + Example: + get_from_weaviate(query="some query", path=['year'], operator='Equal', valueText='2017*') + + """ + client = self.init_weaviate_client(self.namespace) + + print(self.namespace) + print(str(datetime.now())) + print(observation) + if namespace is None: + namespace = self.namespace + + params_user_id = { + "path": ["user_id"], + "operator": "Like", + "valueText": self.user_id, + } + + if params: + query_output = ( + client.query.get( + namespace, + [ + # "text", + "user_id", + "memory_id", + "ltm_memory_id", + "st_memory_id", + "buffer_id", + "version", + "agreement_id", + "privacy_policy", + "terms_of_service", + "format", + "schema_version", + "checksum", + "owner", + "license", + "validity_start", + "validity_end", + ], + ) + .with_where(params) + .with_near_text({"concepts": [observation]}) + .with_additional( + ["id", "creationTimeUnix", "lastUpdateTimeUnix", "score",'distance'] + ) + .with_where(params_user_id) + .with_limit(10) + .do() + ) + return query_output + else: + query_output = ( + client.query.get( + namespace, + + [ + "text", + "user_id", + "memory_id", + "ltm_memory_id", + "st_memory_id", + "buffer_id", + "version", + "agreement_id", + "privacy_policy", + "terms_of_service", + "format", + "schema_version", + "checksum", + "owner", + "license", + "validity_start", + "validity_end", + ], + ) + .with_additional( + ["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance'] + ) + .with_hybrid( + query=observation, + fusion_type=HybridFusion.RELATIVE_SCORE + ) + .with_autocut(n_of_observations) + .with_where(params_user_id) + .with_limit(10) + .do() + ) + return query_output + + async def delete_memories(self, params: dict = None): + client = self.init_weaviate_client(self.namespace) + if params: + where_filter = { + "path": ["id"], + "operator": "Equal", + "valueText": params.get("id", None), + } + return client.batch.delete_objects( + class_name=self.namespace, + # Same `where` filter as in the GraphQL API + where=where_filter, + ) + else: + # Delete all objects + print("HERE IS THE USER ID", self.user_id) + return client.batch.delete_objects( + class_name=self.namespace, + where={ + "path": ["user_id"], + "operator": "Equal", + "valueText": self.user_id, + }, + ) + + def update_memories(self, observation, namespace: str, params: dict = None): + client = self.init_weaviate_client(self.namespace) + + client.data_object.update( + data_object={ + # "text": observation, + "user_id": str(self.user_id), + "memory_id": str(self.memory_id), + "ltm_memory_id": str(self.ltm_memory_id), + "st_memory_id": str(self.st_memory_id), + "buffer_id": str(self.buffer_id), + "version": params.get("version", None) or "", + "agreement_id": params.get("agreement_id", None) or "", + "privacy_policy": params.get("privacy_policy", None) or "", + "terms_of_service": params.get("terms_of_service", None) or "", + "format": params.get("format", None) or "", + "schema_version": params.get("schema_version", None) or "", + "checksum": params.get("checksum", None) or "", + "owner": params.get("owner", None) or "", + "license": params.get("license", None) or "", + "validity_start": params.get("validity_start", None) or "", + "validity_end": params.get("validity_end", None) or "" + # **source_metadata, + }, + class_name="Test", + uuid=params.get("id", None), + consistency_level=weaviate.data.replication.ConsistencyLevel.ALL, # default QUORUM + ) + return