Added a few fixes and refactored the base app

This commit is contained in:
Vasilije 2023-10-25 15:28:20 +02:00
parent f81fc276c4
commit 81b8fd923c
3 changed files with 55 additions and 39 deletions

View file

@@ -13,9 +13,12 @@ class Operation(Base):
id = Column(String, primary_key=True)
user_id = Column(String, ForeignKey('users.id'), index=True) # Link to User
operation_type = Column(String, nullable=True)
operation_params = Column(String, nullable=True)
test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow)
memories = relationship("MemoryModel", back_populates="operation")
# Relationships
user = relationship("User", back_populates="operations")
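The three new columns (operation_type, operation_params, test_set_id) are what start_test writes further down in this commit; a minimal sketch of populating them, assuming the add_entity helper and async session used elsewhere in the diff (the literal values are illustrative only):

import uuid

operation = Operation(
    id=str(uuid.uuid4()),                    # used as the job_id below
    user_id="677",                           # placeholder user, as in main() below
    operation_type="llm_context",            # one of the retriever_type values
    operation_params=str({"chunk_size": 400}),  # illustrative params, stored as a plain string
    test_set_id=str(uuid.uuid4()),
)
# await add_entity(session, operation)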

View file

@@ -282,7 +282,10 @@ def generate_letter_uuid(length=8):
letters = string.ascii_uppercase # A-Z
return ''.join(random.choice(letters) for _ in range(length))
async def start_test(data, test_set=None, user_id=None, params=None, job_id=None, metadata=None, generate_test_set=False, only_llm_context=False):
async def start_test(data, test_set=None, user_id=None, params=None, job_id=None, metadata=None, generate_test_set=False, retriever_type: str = None):
"""retriever_type: one of "llm_context", "single_document_context", "multi_document_context", "cognitive_architecture"."""
async with session_scope(session=AsyncSessionLocal()) as session:
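Since retriever_type is now a free-form string, a small guard can fail fast on typos instead of silently falling through to the default branch later on; a minimal sketch (ALLOWED_RETRIEVER_TYPES and validate_retriever_type are illustrative names, not part of this commit):

ALLOWED_RETRIEVER_TYPES = {
    "llm_context",
    "single_document_context",
    "multi_document_context",
    "cognitive_architecture",
}

def validate_retriever_type(retriever_type):
    # None is allowed and means "no retriever selected"; anything else must be a known value.
    if retriever_type is not None and retriever_type not in ALLOWED_RETRIEVER_TYPES:
        raise ValueError(f"Unknown retriever_type: {retriever_type!r}")
    return retriever_type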
@@ -294,9 +297,7 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
await memory.add_memory_instance("ExampleMemory")
existing_user = await Memory.check_existing_user(user_id, session)
if job_id is None:
job_id = str(uuid.uuid4())
await add_entity(session, Operation(id=job_id, user_id=user_id))
if test_set_id is None:
test_set_id = str(uuid.uuid4())
@@ -318,8 +319,13 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
"path": data
}
if job_id is None:
job_id = str(uuid.uuid4())
async def run_test(test, loader_settings, metadata, test_id=None,only_llm_context=False):
await add_entity(session, Operation(id=job_id, user_id=user_id, operation_params=str(test_params), operation_type=retriever_type, test_set_id=test_set_id))
async def run_test(test, loader_settings, metadata, test_id=None, retriever_type=None):
if test_id is None:
test_id = str(generate_letter_uuid()) + "_" +"SEMANTICMEMORY"
@@ -372,7 +378,7 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
test_eval_pipeline =[]
if only_llm_context:
if retriever_type == "llm_context":
for test_qa in test_set:
context=""
test_result = await run_eval(test_qa, context)
@@ -399,13 +405,13 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
results = []
if only_llm_context:
if retriever_type:
test_id, result = await run_test(test=None, loader_settings=loader_settings, metadata=metadata,
only_llm_context=only_llm_context)
retriever_type=retriever_type)
results.append(result)
for param in test_params:
test_id, result = await run_test(param, loader_settings, metadata, only_llm_context=only_llm_context)
test_id, result = await run_test(param, loader_settings, metadata, retriever_type=retriever_type)
results.append(result)
@@ -458,7 +464,7 @@ async def main():
]
# "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
#http://public-library.uk/ebooks/59/83.pdf
result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="677", params=None, metadata=metadata)
result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="677", params=None, metadata=metadata, retriever_type='llm_context')
#
# parser = argparse.ArgumentParser(description="Run tests against a document.")
# parser.add_argument("--url", required=True, help="URL of the document to test.")
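The same entry point can also be pointed at the vector-store-backed retrievers introduced below; a hedged example reusing the placeholder path, user_id and metadata from main() (not part of this commit):

# result = await start_test(
#     ".data/3ZCCCW.pdf",
#     test_set=test_set,
#     user_id="677",
#     params=None,
#     metadata=metadata,
#     retriever_type="single_document_context",  # or "multi_document_context"
# )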

View file

@@ -2,13 +2,14 @@
# Make sure to install the following packages: dlt, langchain, duckdb, python-dotenv, openai, weaviate-client
import logging
from langchain.text_splitter import RecursiveCharacterTextSplitter
from marshmallow import Schema, fields
from loaders.loaders import _document_loader
# Add the parent directory to sys.path
logging.basicConfig(level=logging.INFO)
from langchain.retrievers import WeaviateHybridSearchRetriever
from langchain.retrievers import WeaviateHybridSearchRetriever, ParentDocumentRetriever
from weaviate.gql.get import HybridFusion
import tracemalloc
tracemalloc.start()
@@ -56,9 +57,8 @@ class WeaviateVectorDB(VectorDB):
super().__init__(*args, **kwargs)
self.init_weaviate(embeddings= self.embeddings, namespace = self.namespace)
def init_weaviate(self, embeddings =OpenAIEmbeddings() , namespace: str=None):
def init_weaviate(self, embeddings=OpenAIEmbeddings(), namespace=None, retriever_type=""):
# Weaviate initialization logic
# embeddings = OpenAIEmbeddings()
auth_config = weaviate.auth.AuthApiKey(
api_key=os.environ.get("WEAVIATE_API_KEY")
)
@@ -67,28 +67,36 @@ class WeaviateVectorDB(VectorDB):
auth_client_secret=auth_config,
additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")},
)
retriever = WeaviateHybridSearchRetriever(
client=client,
index_name=namespace,
text_key="text",
attributes=[],
embedding=embeddings,
create_schema_if_missing=True,
)
return retriever
def init_weaviate_client(self, namespace: str):
# Weaviate client initialization logic
auth_config = weaviate.auth.AuthApiKey(
api_key=os.environ.get("WEAVIATE_API_KEY")
)
client = weaviate.Client(
url=os.environ.get("WEAVIATE_URL"),
auth_client_secret=auth_config,
additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")},
)
return client
if retriever_type == "single_document_context":
retriever = WeaviateHybridSearchRetriever(
client=client,
index_name=namespace,
text_key="text",
attributes=[],
embedding=embeddings,
create_schema_if_missing=True,
)
return retriever
elif retriever_type == "multi_document_context":
retriever = WeaviateHybridSearchRetriever(
client=client,
index_name=namespace,
text_key="text",
attributes=[],
embedding=embeddings,
create_schema_if_missing=True,
)
return retriever
else:
return client
# child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
# store = InMemoryStore()
# retriever = ParentDocumentRetriever(
# vectorstore=vectorstore,
# docstore=store,
# child_splitter=child_splitter,
# )
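The commented-out block above sketches a ParentDocumentRetriever path that this commit does not wire up; one way it could be completed, assuming a LangChain Weaviate vectorstore over the same client and namespace (nothing below is part of the commit):

from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Weaviate

vectorstore = Weaviate(client, index_name=namespace, text_key="text", embedding=embeddings)
store = InMemoryStore()  # parent documents kept in memory for this sketch only
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
)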
from marshmallow import Schema, fields
def create_document_structure(observation, params, metadata_schema_class=None):
@@ -140,7 +148,7 @@ class WeaviateVectorDB(VectorDB):
# Update Weaviate memories here
if namespace is None:
namespace = self.namespace
retriever = self.init_weaviate(embeddings=embeddings,namespace = namespace)
retriever = self.init_weaviate(embeddings=embeddings, namespace=namespace, retriever_type="single_document_context")
if loader_settings:
# Assuming _document_loader returns a list of documents
documents = await _document_loader(observation, loader_settings)
@@ -174,7 +182,7 @@ class WeaviateVectorDB(VectorDB):
Example:
fetch_memories(query="some query", search_type='text', additional_param='value')
"""
client = self.init_weaviate_client(self.namespace)
client = self.init_weaviate(namespace=self.namespace)
if search_type is None:
search_type = 'hybrid'
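Because the default retriever_type returns the bare weaviate client, fetch_memories can issue hybrid queries directly; a rough sketch of such a query against the client created just above (field names follow the text/user_id convention used in this class, and this is not the commit's exact query):

response = (
    client.query
    .get(self.namespace, ["text", "user_id"])
    .with_hybrid(query="some query", alpha=0.5)  # alpha balances keyword vs. vector scoring
    .with_limit(10)
    .do()
)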
@@ -258,7 +266,7 @@ class WeaviateVectorDB(VectorDB):
async def delete_memories(self, namespace:str, params: dict = None):
if namespace is None:
namespace = self.namespace
client = self.init_weaviate_client(self.namespace)
client = self.init_weaviate(namespace=self.namespace)
if params:
where_filter = {
"path": ["id"],
@@ -283,13 +291,12 @@ class WeaviateVectorDB(VectorDB):
)
def update_memories(self, observation, namespace: str, params: dict = None):
client = self.init_weaviate_client(self.namespace)
client = self.init_weaviate(namespace=self.namespace)
client.data_object.update(
data_object={
# "text": observation,
"user_id": str(self.user_id),
"buffer_id": str(self.buffer_id),
"version": params.get("version", None) or "",
"agreement_id": params.get("agreement_id", None) or "",
"privacy_policy": params.get("privacy_policy", None) or "",