Added a few fixes and refactored the base app
This commit is contained in:
parent
f81fc276c4
commit
81b8fd923c
3 changed files with 55 additions and 39 deletions
|
|
@ -13,9 +13,12 @@ class Operation(Base):
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
id = Column(String, primary_key=True)
|
||||||
user_id = Column(String, ForeignKey('users.id'), index=True) # Link to User
|
user_id = Column(String, ForeignKey('users.id'), index=True) # Link to User
|
||||||
|
operation_type = Column(String, nullable=True)
|
||||||
|
operation_params = Column(String, nullable=True)
|
||||||
test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
|
test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
|
||||||
created_at = Column(DateTime, default=datetime.utcnow)
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
updated_at = Column(DateTime, onupdate=datetime.utcnow)
|
updated_at = Column(DateTime, onupdate=datetime.utcnow)
|
||||||
|
memories = relationship("MemoryModel", back_populates="operation")
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
user = relationship("User", back_populates="operations")
|
user = relationship("User", back_populates="operations")
|
||||||
|
|
|
||||||
|
|
@ -282,7 +282,10 @@ def generate_letter_uuid(length=8):
|
||||||
letters = string.ascii_uppercase # A-Z
|
letters = string.ascii_uppercase # A-Z
|
||||||
return ''.join(random.choice(letters) for _ in range(length))
|
return ''.join(random.choice(letters) for _ in range(length))
|
||||||
|
|
||||||
async def start_test(data, test_set=None, user_id=None, params=None, job_id=None, metadata=None, generate_test_set=False, only_llm_context=False):
|
async def start_test(data, test_set=None, user_id=None, params=None, job_id=None, metadata=None, generate_test_set=False, retriever_type:str=None):
|
||||||
|
|
||||||
|
|
||||||
|
"""retriever_type is one of: "llm_context", "single_document_context", "multi_document_context", "cognitive_architecture"."""
|
||||||
|
|
||||||
|
|
||||||
async with session_scope(session=AsyncSessionLocal()) as session:
|
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||||
|
|
@ -294,9 +297,7 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
|
||||||
await memory.add_memory_instance("ExampleMemory")
|
await memory.add_memory_instance("ExampleMemory")
|
||||||
existing_user = await Memory.check_existing_user(user_id, session)
|
existing_user = await Memory.check_existing_user(user_id, session)
|
||||||
|
|
||||||
if job_id is None:
|
|
||||||
job_id = str(uuid.uuid4())
|
|
||||||
await add_entity(session, Operation(id=job_id, user_id=user_id))
|
|
||||||
|
|
||||||
if test_set_id is None:
|
if test_set_id is None:
|
||||||
test_set_id = str(uuid.uuid4())
|
test_set_id = str(uuid.uuid4())
|
||||||
|
|
@ -318,8 +319,13 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
|
||||||
"path": data
|
"path": data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if job_id is None:
|
||||||
|
job_id = str(uuid.uuid4())
|
||||||
|
|
||||||
async def run_test(test, loader_settings, metadata, test_id=None,only_llm_context=False):
|
await add_entity(session, Operation(id=job_id, user_id=user_id, operation_params =str(test_params), operation_type=retriever_type, test_set_id=test_set_id))
|
||||||
|
|
||||||
|
|
||||||
|
async def run_test(test, loader_settings, metadata, test_id=None,retriever_type=False):
|
||||||
|
|
||||||
if test_id is None:
|
if test_id is None:
|
||||||
test_id = str(generate_letter_uuid()) + "_" +"SEMANTICMEMORY"
|
test_id = str(generate_letter_uuid()) + "_" +"SEMANTICMEMORY"
|
||||||
|
|
@ -372,7 +378,7 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
|
||||||
test_eval_pipeline =[]
|
test_eval_pipeline =[]
|
||||||
|
|
||||||
|
|
||||||
if only_llm_context:
|
if retriever_type == "llm_context":
|
||||||
for test_qa in test_set:
|
for test_qa in test_set:
|
||||||
context=""
|
context=""
|
||||||
test_result = await run_eval(test_qa, context)
|
test_result = await run_eval(test_qa, context)
|
||||||
|
|
@ -399,13 +405,13 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
if only_llm_context:
|
if retriever_type:
|
||||||
test_id, result = await run_test(test=None, loader_settings=loader_settings, metadata=metadata,
|
test_id, result = await run_test(test=None, loader_settings=loader_settings, metadata=metadata,
|
||||||
only_llm_context=only_llm_context)
|
retriever_type=retriever_type)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
for param in test_params:
|
for param in test_params:
|
||||||
test_id, result = await run_test(param, loader_settings, metadata, only_llm_context=only_llm_context)
|
test_id, result = await run_test(param, loader_settings, metadata, retriever_type=retriever_type)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -458,7 +464,7 @@ async def main():
|
||||||
]
|
]
|
||||||
# "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
|
# "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
|
||||||
#http://public-library.uk/ebooks/59/83.pdf
|
#http://public-library.uk/ebooks/59/83.pdf
|
||||||
result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="677", params=None, metadata=metadata)
|
result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="677", params=None, metadata=metadata, retriever_type='llm_context')
|
||||||
#
|
#
|
||||||
# parser = argparse.ArgumentParser(description="Run tests against a document.")
|
# parser = argparse.ArgumentParser(description="Run tests against a document.")
|
||||||
# parser.add_argument("--url", required=True, help="URL of the document to test.")
|
# parser.add_argument("--url", required=True, help="URL of the document to test.")
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,14 @@
|
||||||
# Make sure to install the following packages: dlt, langchain, duckdb, python-dotenv, openai, weaviate-client
|
# Make sure to install the following packages: dlt, langchain, duckdb, python-dotenv, openai, weaviate-client
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
from marshmallow import Schema, fields
|
from marshmallow import Schema, fields
|
||||||
from loaders.loaders import _document_loader
|
from loaders.loaders import _document_loader
|
||||||
# Add the parent directory to sys.path
|
# Add the parent directory to sys.path
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
from langchain.retrievers import WeaviateHybridSearchRetriever
|
from langchain.retrievers import WeaviateHybridSearchRetriever, ParentDocumentRetriever
|
||||||
from weaviate.gql.get import HybridFusion
|
from weaviate.gql.get import HybridFusion
|
||||||
import tracemalloc
|
import tracemalloc
|
||||||
tracemalloc.start()
|
tracemalloc.start()
|
||||||
|
|
@ -56,9 +57,8 @@ class WeaviateVectorDB(VectorDB):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.init_weaviate(embeddings= self.embeddings, namespace = self.namespace)
|
self.init_weaviate(embeddings= self.embeddings, namespace = self.namespace)
|
||||||
|
|
||||||
def init_weaviate(self, embeddings =OpenAIEmbeddings() , namespace: str=None):
|
def init_weaviate(self, embeddings=OpenAIEmbeddings(), namespace=None,retriever_type="",):
|
||||||
# Weaviate initialization logic
|
# Weaviate initialization logic
|
||||||
# embeddings = OpenAIEmbeddings()
|
|
||||||
auth_config = weaviate.auth.AuthApiKey(
|
auth_config = weaviate.auth.AuthApiKey(
|
||||||
api_key=os.environ.get("WEAVIATE_API_KEY")
|
api_key=os.environ.get("WEAVIATE_API_KEY")
|
||||||
)
|
)
|
||||||
|
|
@ -67,28 +67,36 @@ class WeaviateVectorDB(VectorDB):
|
||||||
auth_client_secret=auth_config,
|
auth_client_secret=auth_config,
|
||||||
additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")},
|
additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")},
|
||||||
)
|
)
|
||||||
retriever = WeaviateHybridSearchRetriever(
|
|
||||||
client=client,
|
|
||||||
index_name=namespace,
|
|
||||||
text_key="text",
|
|
||||||
attributes=[],
|
|
||||||
embedding=embeddings,
|
|
||||||
create_schema_if_missing=True,
|
|
||||||
)
|
|
||||||
return retriever
|
|
||||||
|
|
||||||
def init_weaviate_client(self, namespace: str):
|
|
||||||
# Weaviate client initialization logic
|
|
||||||
auth_config = weaviate.auth.AuthApiKey(
|
|
||||||
api_key=os.environ.get("WEAVIATE_API_KEY")
|
|
||||||
)
|
|
||||||
client = weaviate.Client(
|
|
||||||
url=os.environ.get("WEAVIATE_URL"),
|
|
||||||
auth_client_secret=auth_config,
|
|
||||||
additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")},
|
|
||||||
)
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
if retriever_type == "single_document_context":
|
||||||
|
retriever = WeaviateHybridSearchRetriever(
|
||||||
|
client=client,
|
||||||
|
index_name=namespace,
|
||||||
|
text_key="text",
|
||||||
|
attributes=[],
|
||||||
|
embedding=embeddings,
|
||||||
|
create_schema_if_missing=True,
|
||||||
|
)
|
||||||
|
return retriever
|
||||||
|
elif retriever_type == "multi_document_context":
|
||||||
|
retriever = WeaviateHybridSearchRetriever(
|
||||||
|
client=client,
|
||||||
|
index_name=namespace,
|
||||||
|
text_key="text",
|
||||||
|
attributes=[],
|
||||||
|
embedding=embeddings,
|
||||||
|
create_schema_if_missing=True,
|
||||||
|
)
|
||||||
|
return retriever
|
||||||
|
else :
|
||||||
|
return client
|
||||||
|
# child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
|
||||||
|
# store = InMemoryStore()
|
||||||
|
# retriever = ParentDocumentRetriever(
|
||||||
|
# vectorstore=vectorstore,
|
||||||
|
# docstore=store,
|
||||||
|
# child_splitter=child_splitter,
|
||||||
|
# )
|
||||||
from marshmallow import Schema, fields
|
from marshmallow import Schema, fields
|
||||||
|
|
||||||
def create_document_structure(observation, params, metadata_schema_class=None):
|
def create_document_structure(observation, params, metadata_schema_class=None):
|
||||||
|
|
@ -140,7 +148,7 @@ class WeaviateVectorDB(VectorDB):
|
||||||
# Update Weaviate memories here
|
# Update Weaviate memories here
|
||||||
if namespace is None:
|
if namespace is None:
|
||||||
namespace = self.namespace
|
namespace = self.namespace
|
||||||
retriever = self.init_weaviate(embeddings=embeddings,namespace = namespace)
|
retriever = self.init_weaviate(embeddings=embeddings,namespace = namespace, retriever_type="single_document_context")
|
||||||
if loader_settings:
|
if loader_settings:
|
||||||
# Assuming _document_loader returns a list of documents
|
# Assuming _document_loader returns a list of documents
|
||||||
documents = await _document_loader(observation, loader_settings)
|
documents = await _document_loader(observation, loader_settings)
|
||||||
|
|
@ -174,7 +182,7 @@ class WeaviateVectorDB(VectorDB):
|
||||||
Example:
|
Example:
|
||||||
fetch_memories(query="some query", search_type='text', additional_param='value')
|
fetch_memories(query="some query", search_type='text', additional_param='value')
|
||||||
"""
|
"""
|
||||||
client = self.init_weaviate_client(self.namespace)
|
client = self.init_weaviate(namespace =self.namespace)
|
||||||
if search_type is None:
|
if search_type is None:
|
||||||
search_type = 'hybrid'
|
search_type = 'hybrid'
|
||||||
|
|
||||||
|
|
@ -258,7 +266,7 @@ class WeaviateVectorDB(VectorDB):
|
||||||
async def delete_memories(self, namespace:str, params: dict = None):
|
async def delete_memories(self, namespace:str, params: dict = None):
|
||||||
if namespace is None:
|
if namespace is None:
|
||||||
namespace = self.namespace
|
namespace = self.namespace
|
||||||
client = self.init_weaviate_client(self.namespace)
|
client = self.init_weaviate(namespace = self.namespace)
|
||||||
if params:
|
if params:
|
||||||
where_filter = {
|
where_filter = {
|
||||||
"path": ["id"],
|
"path": ["id"],
|
||||||
|
|
@ -283,13 +291,12 @@ class WeaviateVectorDB(VectorDB):
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_memories(self, observation, namespace: str, params: dict = None):
|
def update_memories(self, observation, namespace: str, params: dict = None):
|
||||||
client = self.init_weaviate_client(self.namespace)
|
client = self.init_weaviate(namespace = self.namespace)
|
||||||
|
|
||||||
client.data_object.update(
|
client.data_object.update(
|
||||||
data_object={
|
data_object={
|
||||||
# "text": observation,
|
# "text": observation,
|
||||||
"user_id": str(self.user_id),
|
"user_id": str(self.user_id),
|
||||||
"buffer_id": str(self.buffer_id),
|
|
||||||
"version": params.get("version", None) or "",
|
"version": params.get("version", None) or "",
|
||||||
"agreement_id": params.get("agreement_id", None) or "",
|
"agreement_id": params.get("agreement_id", None) or "",
|
||||||
"privacy_policy": params.get("privacy_policy", None) or "",
|
"privacy_policy": params.get("privacy_policy", None) or "",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue