diff --git a/level_3/vectordb/basevectordb.py b/level_3/vectordb/basevectordb.py index bbcec0dde..18e022ce2 100644 --- a/level_3/vectordb/basevectordb.py +++ b/level_3/vectordb/basevectordb.py @@ -40,34 +40,28 @@ import json OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY") -LTM_MEMORY_ID_DEFAULT = "00000" -ST_MEMORY_ID_DEFAULT = "0000" -BUFFER_ID_DEFAULT = "0000" - - class VectorDBFactory: + def __init__(self): + self.db_map = { + "pinecone": PineconeVectorDB, + "weaviate": WeaviateVectorDB, + # Add more database types and their corresponding classes here + } + def create_vector_db( self, user_id: str, index_name: str, memory_id: str, - ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT, - st_memory_id: str = ST_MEMORY_ID_DEFAULT, - buffer_id: str = BUFFER_ID_DEFAULT, - db_type: str = "pinecone", + db_type: str = "weaviate", namespace: str = None, - embeddings = None, + embeddings=None, ): - db_map = {"pinecone": PineconeVectorDB, "weaviate": WeaviateVectorDB} - - if db_type in db_map: - return db_map[db_type]( + if db_type in self.db_map: + return self.db_map[db_type]( user_id, index_name, memory_id, - ltm_memory_id, - st_memory_id, - buffer_id, namespace, embeddings ) @@ -101,8 +95,61 @@ class BaseMemory: ) def init_client(self, embeddings, namespace: str): + return self.vector_db.init_client(embeddings, namespace) - return self.vector_db.init_weaviate_client(embeddings, namespace) + +# class VectorDBFactory: +# def create_vector_db( +# self, +# user_id: str, +# index_name: str, +# memory_id: str, +# db_type: str = "pinecone", +# namespace: str = None, +# embeddings = None, +# ): +# db_map = {"pinecone": PineconeVectorDB, "weaviate": WeaviateVectorDB} +# +# if db_type in db_map: +# return db_map[db_type]( +# user_id, +# index_name, +# memory_id, +# namespace, +# embeddings +# ) +# +# raise ValueError(f"Unsupported database type: {db_type}") +# +# class BaseMemory: +# def __init__( +# self, +# user_id: str, +# memory_id: Optional[str], +# index_name: Optional[str], +# db_type: str, +# namespace: str, +# embeddings: Optional[None], +# ): +# self.user_id = user_id +# self.memory_id = memory_id +# self.index_name = index_name +# self.namespace = namespace +# self.embeddings = embeddings +# self.db_type = db_type +# factory = VectorDBFactory() +# self.vector_db = factory.create_vector_db( +# self.user_id, +# self.index_name, +# self.memory_id, +# db_type=self.db_type, +# namespace=self.namespace, +# embeddings=self.embeddings +# ) +# +# def init_client(self, embeddings, namespace: str): +# +# return self.vector_db.init_weaviate_client(embeddings, namespace) def create_field(self, field_type, **kwargs): field_mapping = { diff --git a/level_3/vectordb/vectordb.py b/level_3/vectordb/vectordb.py index 2382e9d04..d413fe4d3 100644 --- a/level_3/vectordb/vectordb.py +++ b/level_3/vectordb/vectordb.py @@ -32,9 +32,6 @@ class VectorDB: user_id: str, index_name: str, memory_id: str, - ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT, - st_memory_id: str = ST_MEMORY_ID_DEFAULT, - buffer_id: str = BUFFER_ID_DEFAULT, namespace: str = None, embeddings = None, ): @@ -42,9 +39,6 @@ class VectorDB: self.index_name = index_name self.namespace = namespace self.memory_id = memory_id - self.ltm_memory_id = ltm_memory_id - self.st_memory_id = st_memory_id - self.buffer_id = buffer_id self.embeddings = embeddings class PineconeVectorDB(VectorDB): @@ -81,7 +75,7 @@ class WeaviateVectorDB(VectorDB): embedding=embeddings, create_schema_if_missing=True, ) - return retriever # If this is part of the initialization, call it here. + return retriever def init_weaviate_client(self, namespace: str): # Weaviate client initialization logic @@ -295,9 +289,6 @@ class WeaviateVectorDB(VectorDB): data_object={ # "text": observation, "user_id": str(self.user_id), - "memory_id": str(self.memory_id), - "ltm_memory_id": str(self.ltm_memory_id), - "st_memory_id": str(self.st_memory_id), "buffer_id": str(self.buffer_id), "version": params.get("version", None) or "", "agreement_id": params.get("agreement_id", None) or "",