diff --git a/level_3/models/operation.py b/level_3/models/operation.py index 6be8b4df0..3c3877286 100644 --- a/level_3/models/operation.py +++ b/level_3/models/operation.py @@ -13,9 +13,12 @@ class Operation(Base): id = Column(String, primary_key=True) user_id = Column(String, ForeignKey('users.id'), index=True) # Link to User + operation_type = Column(String, nullable=True) + operation_params = Column(String, nullable=True) test_set_id = Column(String, ForeignKey('test_sets.id'), index=True) created_at = Column(DateTime, default=datetime.utcnow) updated_at = Column(DateTime, onupdate=datetime.utcnow) + memories = relationship("MemoryModel", back_populates="operation") # Relationships user = relationship("User", back_populates="operations") diff --git a/level_3/rag_test_manager.py b/level_3/rag_test_manager.py index edb0ed787..5035f2f69 100644 --- a/level_3/rag_test_manager.py +++ b/level_3/rag_test_manager.py @@ -282,7 +282,10 @@ def generate_letter_uuid(length=8): letters = string.ascii_uppercase # A-Z return ''.join(random.choice(letters) for _ in range(length)) -async def start_test(data, test_set=None, user_id=None, params=None, job_id=None, metadata=None, generate_test_set=False, only_llm_context=False): +async def start_test(data, test_set=None, user_id=None, params=None, job_id=None, metadata=None, generate_test_set=False, retriever_type:str=None): + + + """retriever_type = "llm_context, single_document_context, multi_document_context, "cognitive_architecture""""" async with session_scope(session=AsyncSessionLocal()) as session: @@ -294,9 +297,7 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None await memory.add_memory_instance("ExampleMemory") existing_user = await Memory.check_existing_user(user_id, session) - if job_id is None: - job_id = str(uuid.uuid4()) - await add_entity(session, Operation(id=job_id, user_id=user_id)) + if test_set_id is None: test_set_id = str(uuid.uuid4()) @@ -318,8 +319,13 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None "path": data } + if job_id is None: + job_id = str(uuid.uuid4()) - async def run_test(test, loader_settings, metadata, test_id=None,only_llm_context=False): + await add_entity(session, Operation(id=job_id, user_id=user_id, operation_params =str(test_params), operation_type=retriever_type, test_set_id=test_set_id)) + + + async def run_test(test, loader_settings, metadata, test_id=None,retriever_type=False): if test_id is None: test_id = str(generate_letter_uuid()) + "_" +"SEMANTICMEMORY" @@ -372,7 +378,7 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None test_eval_pipeline =[] - if only_llm_context: + if retriever_type == "llm_context": for test_qa in test_set: context="" test_result = await run_eval(test_qa, context) @@ -399,13 +405,13 @@ async def start_test(data, test_set=None, user_id=None, params=None, job_id=None results = [] - if only_llm_context: + if retriever_type: test_id, result = await run_test(test=None, loader_settings=loader_settings, metadata=metadata, - only_llm_context=only_llm_context) + retriever_type=retriever_type) results.append(result) for param in test_params: - test_id, result = await run_test(param, loader_settings, metadata, only_llm_context=only_llm_context) + test_id, result = await run_test(param, loader_settings, metadata, retriever_type=retriever_type) results.append(result) @@ -458,7 +464,7 @@ async def main(): ] # "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf" #http://public-library.uk/ebooks/59/83.pdf - result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="677", params=None, metadata=metadata) + result = await start_test(".data/3ZCCCW.pdf", test_set=test_set, user_id="677", params=None, metadata=metadata, retriever_type='llm_context') # # parser = argparse.ArgumentParser(description="Run tests against a document.") # parser.add_argument("--url", required=True, help="URL of the document to test.") diff --git a/level_3/vectordb/vectordb.py b/level_3/vectordb/vectordb.py index d413fe4d3..e11655db1 100644 --- a/level_3/vectordb/vectordb.py +++ b/level_3/vectordb/vectordb.py @@ -2,13 +2,14 @@ # Make sure to install the following packages: dlt, langchain, duckdb, python-dotenv, openai, weaviate-client import logging +from langchain.text_splitter import RecursiveCharacterTextSplitter from marshmallow import Schema, fields from loaders.loaders import _document_loader # Add the parent directory to sys.path logging.basicConfig(level=logging.INFO) -from langchain.retrievers import WeaviateHybridSearchRetriever +from langchain.retrievers import WeaviateHybridSearchRetriever, ParentDocumentRetriever from weaviate.gql.get import HybridFusion import tracemalloc tracemalloc.start() @@ -56,9 +57,8 @@ class WeaviateVectorDB(VectorDB): super().__init__(*args, **kwargs) self.init_weaviate(embeddings= self.embeddings, namespace = self.namespace) - def init_weaviate(self, embeddings =OpenAIEmbeddings() , namespace: str=None): + def init_weaviate(self, embeddings=OpenAIEmbeddings(), namespace=None,retriever_type="",): # Weaviate initialization logic - # embeddings = OpenAIEmbeddings() auth_config = weaviate.auth.AuthApiKey( api_key=os.environ.get("WEAVIATE_API_KEY") ) @@ -67,28 +67,36 @@ class WeaviateVectorDB(VectorDB): auth_client_secret=auth_config, additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")}, ) - retriever = WeaviateHybridSearchRetriever( - client=client, - index_name=namespace, - text_key="text", - attributes=[], - embedding=embeddings, - create_schema_if_missing=True, - ) - return retriever - - def init_weaviate_client(self, namespace: str): - # Weaviate client initialization logic - auth_config = weaviate.auth.AuthApiKey( - api_key=os.environ.get("WEAVIATE_API_KEY") - ) - client = weaviate.Client( - url=os.environ.get("WEAVIATE_URL"), - auth_client_secret=auth_config, - additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")}, - ) - return client + if retriever_type == "single_document_context": + retriever = WeaviateHybridSearchRetriever( + client=client, + index_name=namespace, + text_key="text", + attributes=[], + embedding=embeddings, + create_schema_if_missing=True, + ) + return retriever + elif retriever_type == "multi_document_context": + retriever = WeaviateHybridSearchRetriever( + client=client, + index_name=namespace, + text_key="text", + attributes=[], + embedding=embeddings, + create_schema_if_missing=True, + ) + return retriever + else : + return client + # child_splitter = RecursiveCharacterTextSplitter(chunk_size=400) + # store = InMemoryStore() + # retriever = ParentDocumentRetriever( + # vectorstore=vectorstore, + # docstore=store, + # child_splitter=child_splitter, + # ) from marshmallow import Schema, fields def create_document_structure(observation, params, metadata_schema_class=None): @@ -140,7 +148,7 @@ class WeaviateVectorDB(VectorDB): # Update Weaviate memories here if namespace is None: namespace = self.namespace - retriever = self.init_weaviate(embeddings=embeddings,namespace = namespace) + retriever = self.init_weaviate(embeddings=embeddings,namespace = namespace, retriever_type="single_document_context") if loader_settings: # Assuming _document_loader returns a list of documents documents = await _document_loader(observation, loader_settings) @@ -174,7 +182,7 @@ class WeaviateVectorDB(VectorDB): Example: fetch_memories(query="some query", search_type='text', additional_param='value') """ - client = self.init_weaviate_client(self.namespace) + client = self.init_weaviate(namespace =self.namespace) if search_type is None: search_type = 'hybrid' @@ -258,7 +266,7 @@ class WeaviateVectorDB(VectorDB): async def delete_memories(self, namespace:str, params: dict = None): if namespace is None: namespace = self.namespace - client = self.init_weaviate_client(self.namespace) + client = self.init_weaviate(namespace = self.namespace) if params: where_filter = { "path": ["id"], @@ -283,13 +291,12 @@ class WeaviateVectorDB(VectorDB): ) def update_memories(self, observation, namespace: str, params: dict = None): - client = self.init_weaviate_client(self.namespace) + client = self.init_weaviate(namespace = self.namespace) client.data_object.update( data_object={ # "text": observation, "user_id": str(self.user_id), - "buffer_id": str(self.buffer_id), "version": params.get("version", None) or "", "agreement_id": params.get("agreement_id", None) or "", "privacy_policy": params.get("privacy_policy", None) or "",