Added a few fixes and refactored the base app
Parent: f81fc276c4
Commit: 81b8fd923c
3 changed files with 55 additions and 39 deletions
@@ -13,9 +13,12 @@ class Operation(Base):
     id = Column(String, primary_key=True)
     user_id = Column(String, ForeignKey('users.id'), index=True)  # Link to User
+    operation_type = Column(String, nullable=True)
+    operation_params = Column(String, nullable=True)
+    test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
     created_at = Column(DateTime, default=datetime.utcnow)
     updated_at = Column(DateTime, onupdate=datetime.utcnow)
     memories = relationship("MemoryModel", back_populates="operation")

     # Relationships
     user = relationship("User", back_populates="operations")

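For context, a minimal sketch of how these three new columns are populated by the start_test flow changed below; the session and add_entity helpers are the project's own, and the parameter values here are placeholders, not part of the commit.

    # Sketch only: values are illustrative placeholders.
    job_id = str(uuid.uuid4())
    await add_entity(session, Operation(
        id=job_id,
        user_id="677",                                # example user id used in main()
        operation_type="llm_context",                 # one of the retriever_type values
        operation_params=str({"chunk_size": 500}),    # serialized test params (placeholder)
        test_set_id=str(uuid.uuid4()),
    ))
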
@@ -282,7 +282,10 @@ def generate_letter_uuid(length=8):
     letters = string.ascii_uppercase  # A-Z
     return ''.join(random.choice(letters) for _ in range(length))

-async def start_test(data, test_set=None, user_id=None, params=None, job_id=None, metadata=None, generate_test_set=False, only_llm_context=False):
+async def start_test(data, test_set=None, user_id=None, params=None, job_id=None, metadata=None, generate_test_set=False, retriever_type: str = None):
+    """retriever_type: one of "llm_context", "single_document_context", "multi_document_context", "cognitive_architecture"."""

     async with session_scope(session=AsyncSessionLocal()) as session:

@@ -294,9 +297,7 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
         await memory.add_memory_instance("ExampleMemory")
         existing_user = await Memory.check_existing_user(user_id, session)

-        if job_id is None:
-            job_id = str(uuid.uuid4())
-            await add_entity(session, Operation(id=job_id, user_id=user_id))

         if test_set_id is None:
             test_set_id = str(uuid.uuid4())

@@ -318,8 +319,13 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
             "path": data
         }

+        if job_id is None:
+            job_id = str(uuid.uuid4())
+        await add_entity(session, Operation(id=job_id, user_id=user_id, operation_params=str(test_params), operation_type=retriever_type, test_set_id=test_set_id))

-        async def run_test(test, loader_settings, metadata, test_id=None, only_llm_context=False):
+        async def run_test(test, loader_settings, metadata, test_id=None, retriever_type=False):

            if test_id is None:
                test_id = str(generate_letter_uuid()) + "_" + "SEMANTICMEMORY"

@@ -372,7 +378,7 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
            test_eval_pipeline = []

-           if only_llm_context:
+           if retriever_type == "llm_context":
                for test_qa in test_set:
                    context = ""
                    test_result = await run_eval(test_qa, context)

@@ -399,13 +405,13 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None

         results = []

-        if only_llm_context:
+        if retriever_type:
             test_id, result = await run_test(test=None, loader_settings=loader_settings, metadata=metadata,
-                                             only_llm_context=only_llm_context)
+                                             retriever_type=retriever_type)
             results.append(result)

         for param in test_params:
-            test_id, result = await run_test(param, loader_settings, metadata, only_llm_context=only_llm_context)
+            test_id, result = await run_test(param, loader_settings, metadata, retriever_type=retriever_type)
             results.append(result)

@@ -458,7 +464,7 @@ async def main():
     ]
     # "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
     #http://public-library.uk/ebooks/59/83.pdf
-    result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="677", params=None, metadata=metadata)
+    result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="677", params=None, metadata=metadata, retriever_type='llm_context')
     #
     # parser = argparse.ArgumentParser(description="Run tests against a document.")
     # parser.add_argument("--url", required=True, help="URL of the document to test.")

@@ -2,13 +2,14 @@
 # Make sure to install the following packages: dlt, langchain, duckdb, python-dotenv, openai, weaviate-client
 import logging

+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from marshmallow import Schema, fields
 from loaders.loaders import _document_loader
 # Add the parent directory to sys.path

 logging.basicConfig(level=logging.INFO)
-from langchain.retrievers import WeaviateHybridSearchRetriever
+from langchain.retrievers import WeaviateHybridSearchRetriever, ParentDocumentRetriever
 from weaviate.gql.get import HybridFusion
 import tracemalloc
 tracemalloc.start()

@@ -56,9 +57,8 @@ class WeaviateVectorDB(VectorDB):
         super().__init__(*args, **kwargs)
         self.init_weaviate(embeddings= self.embeddings, namespace = self.namespace)

-    def init_weaviate(self, embeddings =OpenAIEmbeddings() , namespace: str=None):
+    def init_weaviate(self, embeddings=OpenAIEmbeddings(), namespace=None, retriever_type=""):
         # Weaviate initialization logic
         # embeddings = OpenAIEmbeddings()
         auth_config = weaviate.auth.AuthApiKey(
             api_key=os.environ.get("WEAVIATE_API_KEY")
         )

@@ -67,28 +67,36 @@ class WeaviateVectorDB(VectorDB):
             auth_client_secret=auth_config,
             additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")},
         )
-        retriever = WeaviateHybridSearchRetriever(
-            client=client,
-            index_name=namespace,
-            text_key="text",
-            attributes=[],
-            embedding=embeddings,
-            create_schema_if_missing=True,
-        )
-        return retriever
-
-    def init_weaviate_client(self, namespace: str):
-        # Weaviate client initialization logic
-        auth_config = weaviate.auth.AuthApiKey(
-            api_key=os.environ.get("WEAVIATE_API_KEY")
-        )
-        client = weaviate.Client(
-            url=os.environ.get("WEAVIATE_URL"),
-            auth_client_secret=auth_config,
-            additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")},
-        )
-        return client
+        if retriever_type == "single_document_context":
+            retriever = WeaviateHybridSearchRetriever(
+                client=client,
+                index_name=namespace,
+                text_key="text",
+                attributes=[],
+                embedding=embeddings,
+                create_schema_if_missing=True,
+            )
+            return retriever
+        elif retriever_type == "multi_document_context":
+            retriever = WeaviateHybridSearchRetriever(
+                client=client,
+                index_name=namespace,
+                text_key="text",
+                attributes=[],
+                embedding=embeddings,
+                create_schema_if_missing=True,
+            )
+            return retriever
+        else:
+            return client
+        # child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
+        # store = InMemoryStore()
+        # retriever = ParentDocumentRetriever(
+        #     vectorstore=vectorstore,
+        #     docstore=store,
+        #     child_splitter=child_splitter,
+        # )

     from marshmallow import Schema, fields

     def create_document_structure(observation, params, metadata_schema_class=None):

@@ -140,7 +148,7 @@ class WeaviateVectorDB(VectorDB):
         # Update Weaviate memories here
         if namespace is None:
             namespace = self.namespace
-        retriever = self.init_weaviate(embeddings=embeddings,namespace = namespace)
+        retriever = self.init_weaviate(embeddings=embeddings, namespace=namespace, retriever_type="single_document_context")
         if loader_settings:
             # Assuming _document_loader returns a list of documents
             documents = await _document_loader(observation, loader_settings)

@@ -174,7 +182,7 @@ class WeaviateVectorDB(VectorDB):
         Example:
             fetch_memories(query="some query", search_type='text', additional_param='value')
         """
-        client = self.init_weaviate_client(self.namespace)
+        client = self.init_weaviate(namespace=self.namespace)
         if search_type is None:
             search_type = 'hybrid'

@@ -258,7 +266,7 @@ class WeaviateVectorDB(VectorDB):
     async def delete_memories(self, namespace:str, params: dict = None):
         if namespace is None:
             namespace = self.namespace
-        client = self.init_weaviate_client(self.namespace)
+        client = self.init_weaviate(namespace=self.namespace)
         if params:
             where_filter = {
                 "path": ["id"],

@@ -283,13 +291,12 @@ class WeaviateVectorDB(VectorDB):
         )

     def update_memories(self, observation, namespace: str, params: dict = None):
-        client = self.init_weaviate_client(self.namespace)
+        client = self.init_weaviate(namespace=self.namespace)

         client.data_object.update(
             data_object={
                 # "text": observation,
                 "user_id": str(self.user_id),
                 "buffer_id": str(self.buffer_id),
                 "version": params.get("version", None) or "",
                 "agreement_id": params.get("agreement_id", None) or "",
                 "privacy_policy": params.get("privacy_policy", None) or "",
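
For context, a minimal sketch of how the refactored init_weaviate is meant to be called for each retriever_type; db stands for an already-configured WeaviateVectorDB instance and the namespace value is a placeholder.

    # Sketch only: db is an existing WeaviateVectorDB; "my_namespace" is a placeholder.
    retriever = db.init_weaviate(namespace="my_namespace", retriever_type="single_document_context")
    multi_retriever = db.init_weaviate(namespace="my_namespace", retriever_type="multi_document_context")  # currently identical to the single-document branch
    client = db.init_weaviate(namespace="my_namespace")  # empty retriever_type falls through to the raw Weaviate client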