Added a few fixes and refactored the base app
This commit is contained in:
parent
987364606e
commit
076497ef15
2 changed files with 66 additions and 28 deletions
|
|
@ -40,34 +40,28 @@ import json
|
||||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||||
marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY")
|
marvin.settings.openai.api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
|
||||||
LTM_MEMORY_ID_DEFAULT = "00000"
|
|
||||||
ST_MEMORY_ID_DEFAULT = "0000"
|
|
||||||
BUFFER_ID_DEFAULT = "0000"
|
|
||||||
|
|
||||||
|
|
||||||
class VectorDBFactory:
|
class VectorDBFactory:
|
||||||
|
def __init__(self):
|
||||||
|
self.db_map = {
|
||||||
|
"pinecone": PineconeVectorDB,
|
||||||
|
"weaviate": WeaviateVectorDB,
|
||||||
|
# Add more database types and their corresponding classes here
|
||||||
|
}
|
||||||
|
|
||||||
def create_vector_db(
|
def create_vector_db(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
index_name: str,
|
index_name: str,
|
||||||
memory_id: str,
|
memory_id: str,
|
||||||
ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT,
|
db_type: str = "weaviate",
|
||||||
st_memory_id: str = ST_MEMORY_ID_DEFAULT,
|
|
||||||
buffer_id: str = BUFFER_ID_DEFAULT,
|
|
||||||
db_type: str = "pinecone",
|
|
||||||
namespace: str = None,
|
namespace: str = None,
|
||||||
embeddings = None,
|
embeddings=None,
|
||||||
):
|
):
|
||||||
db_map = {"pinecone": PineconeVectorDB, "weaviate": WeaviateVectorDB}
|
if db_type in self.db_map:
|
||||||
|
return self.db_map[db_type](
|
||||||
if db_type in db_map:
|
|
||||||
return db_map[db_type](
|
|
||||||
user_id,
|
user_id,
|
||||||
index_name,
|
index_name,
|
||||||
memory_id,
|
memory_id,
|
||||||
ltm_memory_id,
|
|
||||||
st_memory_id,
|
|
||||||
buffer_id,
|
|
||||||
namespace,
|
namespace,
|
||||||
embeddings
|
embeddings
|
||||||
)
|
)
|
||||||
|
|
@ -101,8 +95,61 @@ class BaseMemory:
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_client(self, embeddings, namespace: str):
|
def init_client(self, embeddings, namespace: str):
|
||||||
|
return self.vector_db.init_client(embeddings, namespace)
|
||||||
|
|
||||||
return self.vector_db.init_weaviate_client(embeddings, namespace)
|
|
||||||
|
# class VectorDBFactory:
|
||||||
|
# def create_vector_db(
|
||||||
|
# self,
|
||||||
|
# user_id: str,
|
||||||
|
# index_name: str,
|
||||||
|
# memory_id: str,
|
||||||
|
# db_type: str = "pinecone",
|
||||||
|
# namespace: str = None,
|
||||||
|
# embeddings = None,
|
||||||
|
# ):
|
||||||
|
# db_map = {"pinecone": PineconeVectorDB, "weaviate": WeaviateVectorDB}
|
||||||
|
#
|
||||||
|
# if db_type in db_map:
|
||||||
|
# return db_map[db_type](
|
||||||
|
# user_id,
|
||||||
|
# index_name,
|
||||||
|
# memory_id,
|
||||||
|
# namespace,
|
||||||
|
# embeddings
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# raise ValueError(f"Unsupported database type: {db_type}")
|
||||||
|
#
|
||||||
|
# class BaseMemory:
|
||||||
|
# def __init__(
|
||||||
|
# self,
|
||||||
|
# user_id: str,
|
||||||
|
# memory_id: Optional[str],
|
||||||
|
# index_name: Optional[str],
|
||||||
|
# db_type: str,
|
||||||
|
# namespace: str,
|
||||||
|
# embeddings: Optional[None],
|
||||||
|
# ):
|
||||||
|
# self.user_id = user_id
|
||||||
|
# self.memory_id = memory_id
|
||||||
|
# self.index_name = index_name
|
||||||
|
# self.namespace = namespace
|
||||||
|
# self.embeddings = embeddings
|
||||||
|
# self.db_type = db_type
|
||||||
|
# factory = VectorDBFactory()
|
||||||
|
# self.vector_db = factory.create_vector_db(
|
||||||
|
# self.user_id,
|
||||||
|
# self.index_name,
|
||||||
|
# self.memory_id,
|
||||||
|
# db_type=self.db_type,
|
||||||
|
# namespace=self.namespace,
|
||||||
|
# embeddings=self.embeddings
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# def init_client(self, embeddings, namespace: str):
|
||||||
|
#
|
||||||
|
# return self.vector_db.init_weaviate_client(embeddings, namespace)
|
||||||
|
|
||||||
def create_field(self, field_type, **kwargs):
|
def create_field(self, field_type, **kwargs):
|
||||||
field_mapping = {
|
field_mapping = {
|
||||||
|
|
|
||||||
|
|
@ -32,9 +32,6 @@ class VectorDB:
|
||||||
user_id: str,
|
user_id: str,
|
||||||
index_name: str,
|
index_name: str,
|
||||||
memory_id: str,
|
memory_id: str,
|
||||||
ltm_memory_id: str = LTM_MEMORY_ID_DEFAULT,
|
|
||||||
st_memory_id: str = ST_MEMORY_ID_DEFAULT,
|
|
||||||
buffer_id: str = BUFFER_ID_DEFAULT,
|
|
||||||
namespace: str = None,
|
namespace: str = None,
|
||||||
embeddings = None,
|
embeddings = None,
|
||||||
):
|
):
|
||||||
|
|
@ -42,9 +39,6 @@ class VectorDB:
|
||||||
self.index_name = index_name
|
self.index_name = index_name
|
||||||
self.namespace = namespace
|
self.namespace = namespace
|
||||||
self.memory_id = memory_id
|
self.memory_id = memory_id
|
||||||
self.ltm_memory_id = ltm_memory_id
|
|
||||||
self.st_memory_id = st_memory_id
|
|
||||||
self.buffer_id = buffer_id
|
|
||||||
self.embeddings = embeddings
|
self.embeddings = embeddings
|
||||||
|
|
||||||
class PineconeVectorDB(VectorDB):
|
class PineconeVectorDB(VectorDB):
|
||||||
|
|
@ -81,7 +75,7 @@ class WeaviateVectorDB(VectorDB):
|
||||||
embedding=embeddings,
|
embedding=embeddings,
|
||||||
create_schema_if_missing=True,
|
create_schema_if_missing=True,
|
||||||
)
|
)
|
||||||
return retriever # If this is part of the initialization, call it here.
|
return retriever
|
||||||
|
|
||||||
def init_weaviate_client(self, namespace: str):
|
def init_weaviate_client(self, namespace: str):
|
||||||
# Weaviate client initialization logic
|
# Weaviate client initialization logic
|
||||||
|
|
@ -295,9 +289,6 @@ class WeaviateVectorDB(VectorDB):
|
||||||
data_object={
|
data_object={
|
||||||
# "text": observation,
|
# "text": observation,
|
||||||
"user_id": str(self.user_id),
|
"user_id": str(self.user_id),
|
||||||
"memory_id": str(self.memory_id),
|
|
||||||
"ltm_memory_id": str(self.ltm_memory_id),
|
|
||||||
"st_memory_id": str(self.st_memory_id),
|
|
||||||
"buffer_id": str(self.buffer_id),
|
"buffer_id": str(self.buffer_id),
|
||||||
"version": params.get("version", None) or "",
|
"version": params.get("version", None) or "",
|
||||||
"agreement_id": params.get("agreement_id", None) or "",
|
"agreement_id": params.get("agreement_id", None) or "",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue