Updated and tested retry logic, still more to be done

This commit is contained in:
Vasilije 2023-10-08 21:23:30 +02:00
parent 44c595d929
commit 8638a7efe6
30 changed files with 479 additions and 142 deletions

0
level_3/__init__.py Normal file
View file

0
level_3/auth/__init__.py Normal file
View file

View file

View file

View file

View file

View file

View file

View file

View file

@ -1,17 +0,0 @@
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from database.database import Base


class MemoryAssociation(Base):
    """Link between two memories owned by one user.

    Each row records that `source_memory_id` is associated with
    `target_memory_id` for the user identified by `user_id`.
    """

    __tablename__ = 'memory_associations'

    # Surrogate primary key.
    id = Column(Integer, primary_key=True)
    # Owning user; plain string, no FK constraint declared here.
    user_id = Column(String)
    # The two associated memory ids.
    source_memory_id = Column(String)
    target_memory_id = Column(String)

View file

@ -12,11 +12,14 @@ class Operation(Base):
__tablename__ = 'operations'
id = Column(String, primary_key=True)
session_id = Column(String, ForeignKey('sessions.id'), index=True)
user_id = Column(String, ForeignKey('users.id'), index=True) # Link to User
test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow)
session = relationship("Session", back_populates="operations")
# Relationships
user = relationship("User", back_populates="operations")
test_set = relationship("TestSet", back_populates="operations")
def __repr__(self):
return f"<Operation(id={self.id}, session_id={self.session_id}, created_at={self.created_at}, updated_at={self.updated_at})>"
return f"<Operation(id={self.id}, user_id={self.user_id}, created_at={self.created_at}, updated_at={self.updated_at})>"

View file

@ -14,10 +14,12 @@ class Session(Base):
id = Column(String, primary_key=True)
user_id = Column(String, ForeignKey('users.id'), index=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow)
# Corrected relationship name
user = relationship("User", back_populates="sessions")
operations = relationship("Operation", back_populates="session", cascade="all, delete-orphan")
# operations = relationship("Operation", back_populates="session", cascade="all, delete-orphan")
def __repr__(self):
return f"<Session(id={self.id}, user_id={self.user_id}, created_at={self.created_at}, updated_at={self.updated_at})>"

View file

@ -1,22 +0,0 @@
# test_output.py
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from database.database import Base


class TestOutput(Base):
    """Output row produced for a specific test set."""

    __tablename__ = 'test_outputs'

    id = Column(String, primary_key=True)
    # Owning test set.
    test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, onupdate=datetime.utcnow)

    test_set = relationship("TestSet", back_populates="test_outputs")

    def __repr__(self):
        return f"<TestOutput(id={self.id}, test_set_id={self.test_set_id}, created_at={self.created_at}, updated_at={self.updated_at})>"

View file

@ -0,0 +1,40 @@
# test_output.py
# NOTE: the previous revision carried this import header twice verbatim
# (a paste/merge artifact); the duplicate block has been removed.
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from database.database import Base


class TestOutput(Base):
    """
    Represents the output result of a specific test set.
    """

    __tablename__ = 'test_outputs'

    id = Column(String, primary_key=True)
    user_id = Column(String, ForeignKey('users.id'), index=True)  # Added user_id field
    test_set_id = Column(String, ForeignKey('test_sets.id'), index=True)
    operation_id = Column(String, ForeignKey('operations.id'), index=True)
    # Serialized results payload; nullable until the run completes.
    test_results = Column(String, nullable=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, onupdate=datetime.utcnow)

    # Relationships
    user = relationship("User", back_populates="test_outputs")  # Added relationship with User
    test_set = relationship("TestSet", back_populates="test_outputs")
    operation = relationship("Operation", backref="test_outputs")

    def __repr__(self):
        return f"<TestOutput(id={self.id}, user_id={self.user_id}, test_set_id={self.test_set_id}, operation_id={self.operation_id}, created_at={self.created_at}, updated_at={self.updated_at})>"

View file

@ -12,12 +12,15 @@ class TestSet(Base):
__tablename__ = 'test_sets'
id = Column(String, primary_key=True)
user_id = Column(String, ForeignKey('users.id'), index=True)
content = Column(String, ForeignKey('users.id'), index=True)
user_id = Column(String, ForeignKey('users.id'), index=True) # Ensure uniqueness
content = Column(String, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow)
user = relationship("User", back_populates="test_sets")
operations = relationship("Operation", back_populates="test_set")
test_outputs = relationship("TestOutput", back_populates="test_set", cascade="all, delete-orphan")
def __repr__(self):

View file

@ -1,6 +1,6 @@
# user.py
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
import os
@ -14,12 +14,16 @@ class User(Base):
id = Column(String, primary_key=True)
name = Column(String, nullable=False, unique=True, index=True)
session_id = Column(String, nullable=True, unique=True)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, onupdate=datetime.utcnow)
# Relationships
memories = relationship("MemoryModel", back_populates="user", cascade="all, delete-orphan")
operations = relationship("Operation", back_populates="user", cascade="all, delete-orphan")
sessions = relationship("Session", back_populates="user", cascade="all, delete-orphan")
test_sets = relationship("TestSet", back_populates="user", cascade="all, delete-orphan")
test_outputs = relationship("TestOutput", back_populates="user", cascade="all, delete-orphan")
metadatas = relationship("MetaDatas", back_populates="user")
def __repr__(self):

View file

@ -40,7 +40,7 @@ weaviate-client = "^3.22.1"
python-multipart = "^0.0.6"
deep-translator = "^1.11.4"
humanize = "^4.8.0"
deepeval = "^0.20.0"
deepeval = "^0.20.1"
pymupdf = "^1.23.3"
psycopg2 = "^2.9.8"
llama-index = "^0.8.39.post2"

View file

@ -1,23 +1,122 @@
from enum import Enum
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from deepeval.metrics.overall_score import OverallScoreMetric
from deepeval.test_case import LLMTestCase
from deepeval.run_test import assert_test, run_test
from gptcache.embedding import openai
from marvin import ai_classifier
from models.sessions import Session
from models.testset import TestSet
from models.testoutput import TestOutput
from models.metadatas import MetaDatas
from models.operation import Operation
from sqlalchemy.orm import sessionmaker
from database.database import engine
from vectorstore_manager import Memory
import uuid
from contextlib import contextmanager
import random
import string
import itertools
import logging
import dotenv
dotenv.load_dotenv()
import openai
logger = logging.getLogger(__name__)
openai.api_key = os.getenv("OPENAI_API_KEY", "")
@contextmanager
def session_scope(session):
    """Provide a transactional scope around a series of operations.

    Yields the given session, commits when the body finishes cleanly,
    rolls back (logging the cause) and re-raises on any exception, and
    always closes the session. The commit deliberately sits inside the
    try so a failing commit is also rolled back.
    """
    try:
        yield session
        session.commit()
    except Exception as e:
        session.rollback()
        logger.error(f"Session rollback due to: {str(e)}")
        raise
    finally:
        session.close()
def retrieve_latest_test_case(session, user_id, memory_id):
    """Fetch the newest TestSet attributes row for a user/memory pair.

    Parameters:
    - session (Session): A database session.
    - user_id (int/str): The ID of the user to filter test cases by.
    - memory_id (int/str): The ID of the memory to filter test cases by.

    Returns:
    - Object: the most recent matching row, or None if the query fails.

    NOTE(review): TestSet.attributes_list and a memory_id column are not
    visible on the TestSet model shown elsewhere — confirm they exist.
    """
    try:
        newest_first = (
            session.query(TestSet.attributes_list)
            .filter_by(user_id=user_id, memory_id=memory_id)
            .order_by(TestSet.created_at.desc())
        )
        return newest_first.first()
    except Exception as e:
        logger.error(f"An error occurred while retrieving the latest test case: {str(e)}")
        return None
def retrieve_test_cases():
    """Retrieve test cases from a database or a file.

    Placeholder — not implemented yet; returns None.
    """
    pass
def add_entity(session, entity):
    """
    Add an entity (like TestOutput, Session, etc.) to the database.

    Parameters:
    - session (Session): A database session.
    - entity (Base): An instance of an SQLAlchemy model.

    Returns:
    - str: A message indicating whether the addition was successful.
    """
    with session_scope(session):
        session.add(entity)
        # session_scope commits (and closes) on exit; the explicit
        # session.commit() that used to sit here was a redundant double
        # commit and has been removed.
        return "Successfully added entity"
def check_params(chunk_size, chunk_overlap, chunk_strategy, loader_strategy, query, context, metadata):
    """Check parameters for test case runs and set defaults if necessary.

    Placeholder — not implemented yet; returns None.
    """
    pass
def retrieve_job_by_id(session, user_id, job_id):
    """
    Retrieve a job by user ID and job ID.

    Parameters:
    - session (Session): A database session.
    - user_id (int/str): The ID of the user.
    - job_id (int/str): The ID of the job to retrieve.

    Returns:
    - Object: The job attributes filtered by user_id and job_id, or None if an error occurs.
    """
    # NOTE: in the previous revision the def of run_load was pasted into
    # the middle of this docstring; the two functions are now separated.
    try:
        return (
            session.query(Session.id)
            .filter_by(user_id=user_id, id=job_id)
            .order_by(Session.created_at.desc())
            .first()
        )
    except Exception as e:
        logger.error(f"An error occurred while retrieving the job: {str(e)}")
        return None


def run_load(test_id, document, **kwargs):
    """Run load for the given test_id and document with other parameters.

    Placeholder — not implemented yet; returns None.
    """
    pass
def fetch_job_id(session, user_id=None, memory_id=None, job_id=None):
    """Fetch the most recent Session id for the given user/job pair.

    Parameters:
    - session: a database session.
    - user_id: user to filter by.
    - memory_id: currently unused — kept for interface compatibility.
    - job_id: session/job id to filter by.

    Returns:
    - the newest matching Session.id row, or None if an error occurs.
    """
    try:
        return (
            session.query(Session.id)
            .filter_by(user_id=user_id, id=job_id)
            .order_by(Session.created_at.desc())
            .first()
        )
    except Exception as e:
        # Log via the module logger, consistent with the sibling retrieval
        # helpers, instead of printing to stdout.
        logger.error(f"An error occurred: {str(e)}")
        return None
def compare_output(output, expected_output):
    """Compare a produced output against the expected output.

    Placeholder — not implemented yet; returns None.
    """
    pass
def generate_param_variants(base_params):
    """Generate parameter variants for testing.

    Produces three incremented variants (+1, +2, +3) of 'chunk_size'
    followed by three of 'chunk_overlap', each as a one-key dict.
    """
    variants = []
    for step in range(1, 4):
        variants.append({'chunk_size': base_params['chunk_size'] + step})
    for step in range(1, 4):
        variants.append({'chunk_overlap': base_params['chunk_overlap'] + step})
    # Add more parameter variations here as needed
    return variants
def run_tests_with_variants(document, base_params, param_variants, expected_output):
    """Run tests with various parameter variants and validate the output."""
    for variant in param_variants:
        run_id = str(uuid.uuid4())  # fresh test id per variant
        merged_params = {**base_params, **variant}  # variant overrides base
        result = run_load(run_id, document, **merged_params)
        compare_output(result, expected_output)  # validate against expectation
def generate_param_variants(base_params=None, increments=None, ranges=None, included_params=None):
    """Generate parameter variants for testing.

    Args:
        base_params (dict): Base parameter values.
        increments (dict): Increment values for each parameter variant.
        ranges (dict): Range (number of variants) to generate for each parameter.
        included_params (list, optional): Parameters to include in the combinations.
            If None, all parameters are included.

    Returns:
        list: A list of dictionaries containing parameter variants.

    NOTE: this function's source was interleaved with fragments of a
    deleted run_rag_tests function in the previous paste; it has been
    reassembled into one coherent definition.
    """
    # Default base values
    defaults = {
        'chunk_size': 500,
        'chunk_overlap': 20,
        'similarity_score': 0.5,
        'metadata_variation': 0
    }
    # Update defaults with provided base parameters
    params = {**defaults, **(base_params if base_params is not None else {})}

    # Default increments
    default_increments = {
        'chunk_size': 500,
        'chunk_overlap': 10,
        'similarity_score': 0.1,
        'metadata_variation': 1
    }
    # Update default increments with provided increments
    increments = {**default_increments, **(increments if increments is not None else {})}

    # Default ranges
    default_ranges = {
        'chunk_size': 3,
        'chunk_overlap': 3,
        'similarity_score': 3,
        'metadata_variation': 3
    }
    # Update default ranges with provided ranges
    ranges = {**default_ranges, **(ranges if ranges is not None else {})}

    # Generate parameter variant ranges: base + i * increment, i in [0, range)
    param_ranges = {
        key: [params[key] + i * increments.get(key, 1) for i in range(ranges.get(key, 1))]
        for key in ['chunk_size', 'chunk_overlap', 'similarity_score', 'metadata_variation']
    }
    param_ranges['cognitive_architecture'] = ["simple_index", "cognitive_architecture"]
    param_ranges['search_strategy'] = ["similarity_score", "fusion_score"]

    # Filter param_ranges based on included_params
    if included_params is not None:
        param_ranges = {key: val for key, val in param_ranges.items() if key in included_params}

    # Generate all combinations of parameter variants
    keys = param_ranges.keys()
    values = param_ranges.values()
    return [dict(zip(keys, combination)) for combination in itertools.product(*values)]
async def generate_chatgpt_output(query: str, context: str = None):
    """Ask gpt-3.5-turbo to answer `query`, with `context` injected as a
    prior assistant message; returns the model's reply text.

    NOTE(review): declared async but calls the blocking create() API —
    confirm whether the async variant was intended.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "assistant", "content": f"{context}"},
        {"role": "user", "content": query},
    ]
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=messages,
    )
    llm_output = response.choices[0].message.content
    return llm_output
# Set test id and run initial load test
test_id = str(uuid.uuid4())
output = run_load(test_id, document, **base_params)
compare_output(output, expected_output)
# Generate parameter variants for further tests
param_variants = generate_param_variants(base_params)
# Run tests with varied parameters for the single document
run_tests_with_variants(document, base_params, param_variants, expected_output)
# Assuming two documents are concatenated and treated as one
combined_document = document + document
# Run initial load test for combined document
output = run_load(test_id, combined_document, **base_params)
compare_output(output, expected_output)
# Run tests with varied parameters for the combined document
run_tests_with_variants(combined_document, base_params, param_variants, expected_output)
def test_0():
query = "How does photosynthesis work?"
output = "Photosynthesis is the process by which green plants and some other organisms use sunlight to synthesize foods with the help of chlorophyll pigment."
expected_output = "Photosynthesis is the process by which green plants and some other organisms use sunlight to synthesize food with the help of chlorophyll pigment."
context = "Biology"
async def eval_test(query=None, output=None, expected_output=None, context=None):
# query = "How does photosynthesis work?"
# output = "Photosynthesis is the process by which green plants and some other organisms use sunlight to synthesize foods with the help of chlorophyll pigment."
# expected_output = "Photosynthesis is the process by which green plants and some other organisms use sunlight to synthesize food with the help of chlorophyll pigment."
# context = "Biology"
result_output = await generate_chatgpt_output(query, context)
test_case = LLMTestCase(
query=query,
output=output,
output=result_output,
expected_output=expected_output,
context=context,
)
@ -103,4 +226,204 @@ def test_0():
test_result = run_test(test_case, metrics=[metric])
# You can also inspect the test result class
print(test_result)
print(test_result)
def data_format_route(data_string: str):
    """Classify `data_string` into one of the supported data formats.

    Returns the name of the matched FormatRoute member.
    """
    # The enum docstring is consumed by the ai_classifier decorator; keep it.
    @ai_classifier
    class FormatRoute(Enum):
        """Represents classifier for the data format"""

        PDF = "PDF"
        UNSTRUCTURED_WEB = "UNSTRUCTURED_WEB"
        GITHUB = "GITHUB"
        TEXT = "TEXT"
        CSV = "CSV"
        WIKIPEDIA = "WIKIPEDIA"

    return FormatRoute(data_string).name
def data_location_route(data_string: str):
    """Classify `data_string` into one of the supported data locations.

    Returns the name of the matched LocationRoute member.
    """
    # The enum docstring is consumed by the ai_classifier decorator; keep it.
    @ai_classifier
    class LocationRoute(Enum):
        """Represents classifier for the data location"""

        DEVICE = "DEVICE"
        URL = "URL"
        DATABASE = "DATABASE"

    return LocationRoute(data_string).name
def dynamic_test_manager(data, test_set=None, user=None, params=None):
    """Create an evaluation query/answer dataset via deepeval.

    NOTE(review): the `data`, `test_set`, `user` and `params` arguments are
    currently ignored — the dataset is built from a hard-coded sentence.
    """
    from deepeval.dataset import create_evaluation_query_answer_pairs

    # fetch random chunks from the document
    # feed them to the evaluation pipeline
    dataset = create_evaluation_query_answer_pairs(
        "Python is a great language for mathematical expression and machine learning.")
    return dataset
def generate_letter_uuid(length=8):
    """Generate a random string of uppercase letters with the specified length."""
    # random.choices draws `length` independent letters from A-Z in one call.
    return ''.join(random.choices(string.ascii_uppercase, k=length))
def fetch_test_set_id(session, user_id, id):
    """Fetch the newest TestSet.id for the given user and test-set id.

    Returns the most recent matching row, or None if an error occurs.

    NOTE(review): the `id` parameter shadows the builtin `id()`; callers
    pass it by keyword (id=...), so renaming it would break them.
    """
    try:
        newest_first = (
            session.query(TestSet.id)
            .filter_by(user_id=user_id, id=id)
            .order_by(TestSet.created_at.desc())
        )
        return newest_first.first()
    except Exception as e:
        logger.error(f"An error occurred while retrieving the job: {str(e)}")
        return None
async def start_test(data, test_set=None, user_id=None, params=None, job_id=None, metadata=None):
    """Run a parameter-sweep RAG test over `data` and persist the results.

    For each generated parameter variant: creates a namespaced memory,
    loads the document with that variant's loader settings, replays every
    question in `test_set` against it, and stores the collected results
    as a TestOutput row.

    Parameters:
    - data: source document reference (URL/path) fed to the routers.
    - test_set: list of {"question", "answer"} dicts to evaluate.
    - user_id: owner of the run.
    - params: optional explicit parameters; when None a default sweep is generated.
    - job_id: optional existing Operation id; a new one is created when absent.
    - metadata: extra params forwarded to the memory loader.
    """
    Session = sessionmaker(bind=engine)
    session = Session()

    job_id = fetch_job_id(session, user_id=user_id, job_id=job_id)
    test_set_id = fetch_test_set_id(session, user_id=user_id, id=job_id)

    if job_id is None:
        job_id = str(uuid.uuid4())
        logging.info("we are adding a new job ID")
        add_entity(session, Operation(id=job_id, user_id=user_id))

    if test_set_id is None:
        test_set_id = str(uuid.uuid4())
        add_entity(session, TestSet(id=test_set_id, user_id=user_id, content=str(test_set)))

    if params is None:
        data_format = data_format_route(data)
        data_location = data_location_route(data)
        test_params = generate_param_variants(
            included_params=['chunk_size', 'chunk_overlap', 'similarity_score'])

    # NOTE(review): when `params` is provided, data_format/data_location/
    # test_params are never bound and the code below would raise NameError —
    # confirm the intended behavior for that path.
    loader_settings = {
        "format": f"{data_format}",
        "source": f"{data_location}",
        "path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
    }

    for test in test_params:
        test_id = str(generate_letter_uuid()) + "_" + "SEMANTICEMEMORY"
        # handle test data here
        Session = sessionmaker(bind=engine)
        session = Session()
        memory = Memory.create_memory(user_id, session, namespace=test_id)

        # Adding a memory instance
        memory.add_memory_instance("ExampleMemory")

        # Managing memory attributes
        existing_user = Memory.check_existing_user(user_id, session)
        print("here is the existing user", existing_user)
        memory.manage_memory_attributes(existing_user)

        test_class = test_id + "_class"
        memory.add_dynamic_memory_class(test_id.lower(), test_id)
        dynamic_memory_class = getattr(memory, test_class.lower(), None)

        if dynamic_memory_class is not None:
            memory.add_method_to_class(dynamic_memory_class, 'add_memories')
        else:
            print(f"No attribute named {test_class.lower()} in memory.")

        if dynamic_memory_class is not None:
            memory.add_method_to_class(dynamic_memory_class, 'fetch_memories')
        else:
            print(f"No attribute named {test_class.lower()} in memory.")

        print(f"Trying to access: {test_class.lower()}")
        print("Available memory classes:", memory.list_memory_classes())
        print(f"Trying to check: ", test)

        loader_settings.update(test)
        load_action = await memory.dynamic_method_call(
            dynamic_memory_class, 'add_memories',
            observation='some_observation', params=metadata, loader_settings=loader_settings)
        # Strip this variant's keys again so the next iteration starts clean.
        loader_settings = {key: value for key, value in loader_settings.items() if key not in test}

        # Renamed from `test` (which shadowed the outer loop variable) and
        # fixed the `test_result_colletion` typo.
        test_result_collection = []
        for case in test_set:
            retrieve_action = await memory.dynamic_method_call(
                dynamic_memory_class, 'fetch_memories', observation=case["question"])
            test_results = await eval_test(
                query=case["question"], expected_output=case["answer"],
                context=str(retrieve_action))
            test_result_collection.append(test_results)
            print(test_results)

        # NOTE(review): the TestOutput model defines `test_results`, not
        # `content` — confirm this keyword is correct.
        add_entity(session, TestOutput(id=test_id, user_id=user_id, content=str(test_result_collection)))
async def main():
    """Kick off a start_test run against 'The Call of the Wild' PDF."""
    # Run-level metadata forwarded to the loader.
    params = {
        "version": "1.0",
        "agreement_id": "AG123456",
        "privacy_policy": "https://example.com/privacy",
        "terms_of_service": "https://example.com/terms",
        "format": "json",
        "schema_version": "1.1",
        "checksum": "a1b2c3d4e5f6",
        "owner": "John Doe",
        "license": "MIT",
        "validity_start": "2023-08-01",
        "validity_end": "2024-07-31",
    }

    # Question/answer pairs used to evaluate retrieval quality.
    test_set = [
        {
            "question": "Who is the main character in 'The Call of the Wild'?",
            "answer": "Buck"
        },
        {
            "question": "Who wrote 'The Call of the Wild'?",
            "answer": "Jack London"
        },
        {
            "question": "Where does Buck live at the start of the book?",
            "answer": "In the Santa Clara Valley, at Judge Millers place."
        },
        {
            "question": "Why is Buck kidnapped?",
            "answer": "He is kidnapped to be sold as a sled dog in the Yukon during the Klondike Gold Rush."
        },
        {
            "question": "How does Buck become the leader of the sled dog team?",
            "answer": "Buck becomes the leader after defeating the original leader, Spitz, in a fight."
        }
    ]

    result = await start_test(
        "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf",
        test_set=test_set, user_id="666", params=None, metadata=params)


if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

View file

View file

@ -12,8 +12,8 @@ import models.memory
import models.metadatas
import models.operation
import models.sessions
import models.test_output
import models.test_set
import models.testoutput
import models.testset
import models.user
from sqlalchemy import create_engine, text
import psycopg2

View file

View file

View file

View file

@ -14,8 +14,8 @@ from langchain.document_loaders import PyPDFLoader
from langchain.retrievers import WeaviateHybridSearchRetriever
from weaviate.gql.get import HybridFusion
from models.sessions import Session
from models.test_set import TestSet
from models.test_output import TestOutput
from models.testset import TestSet
from models.testoutput import TestOutput
from models.metadatas import MetaDatas
from models.operation import Operation
from sqlalchemy.orm import sessionmaker

View file

View file

@ -1,7 +1,7 @@
from langchain.document_loaders import PyPDFLoader
import sys, os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from shared.chunk_strategy import ChunkStrategy
from level_3.shared.chunk_strategy import ChunkStrategy
import re
def chunk_data(chunk_strategy=None, source_data=None, chunk_size=None, chunk_overlap=None):

View file

View file

@ -1,23 +1,22 @@
import os
from io import BytesIO
import sys, os
import fitz
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from chunkers.chunkers import chunk_data
from level_3.vectordb.chunkers.chunkers import chunk_data
from llama_hub.file.base import SimpleDirectoryReader
from langchain.document_loaders import PyPDFLoader
import requests
def _document_loader( observation: str, loader_settings: dict):
# Check the format of the document
document_format = loader_settings.get("format", "text")
loader_strategy = loader_settings.get("strategy", "VANILLA")
chunk_size = loader_settings.get("chunk_size", 100)
chunk_size = loader_settings.get("chunk_size", 500)
chunk_overlap = loader_settings.get("chunk_overlap", 20)
print("LOADER SETTINGS", loader_settings)
if document_format == "PDF":
if loader_settings.get("source") == "url":
if loader_settings.get("source") == "URL":
pdf_response = requests.get(loader_settings["path"])
pdf_stream = BytesIO(pdf_response.content)
with fitz.open(stream=pdf_stream, filetype='pdf') as doc:

View file

@ -1,29 +1,20 @@
# Make sure to install the following packages: dlt, langchain, duckdb, python-dotenv, openai, weaviate-client
import logging
from io import BytesIO
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from marshmallow import Schema, fields
from loaders.loaders import _document_loader
# Add the parent directory to sys.path
logging.basicConfig(level=logging.INFO)
import marvin
import requests
from langchain.document_loaders import PyPDFLoader
from langchain.retrievers import WeaviateHybridSearchRetriever
from weaviate.gql.get import HybridFusion
import tracemalloc
tracemalloc.start()
import os
from datetime import datetime
from langchain.embeddings.openai import OpenAIEmbeddings
from dotenv import load_dotenv
from schema.semantic.semantic_schema import DocumentSchema, SCHEMA_VERSIONS, DocumentMetadataSchemaV1
from langchain.schema import Document
import weaviate
@ -125,10 +116,11 @@ class WeaviateVectorDB(VectorDB):
# Update Weaviate memories here
if namespace is None:
namespace = self.namespace
retriever = self.init_weaviate(namespace) # Assuming `init_weaviate` is a method of the class
retriever = self.init_weaviate(namespace)
if loader_settings:
# Assuming _document_loader returns a list of documents
documents = _document_loader(observation, loader_settings)
logging.info("here are the docs %s", str(documents))
for doc in documents:
document_to_load = self._stuct(doc.page_content, params, metadata_schema_class)
print("here is the doc to load1", document_to_load)

View file

@ -8,8 +8,8 @@ from database.database import engine # Ensure you have database engine defined
from models.user import User
from models.memory import MemoryModel
from models.sessions import Session
from models.test_set import TestSet
from models.test_output import TestOutput
from models.testset import TestSet
from models.testoutput import TestOutput
from models.metadatas import MetaDatas
from models.operation import Operation
load_dotenv()
@ -103,6 +103,13 @@ class Memory:
return cls(user_id=user_id, session=session, memory_id=memory_id, **kwargs)
def list_memory_classes(self):
"""
Lists all available memory classes in the memory instance.
"""
# Use a list comprehension to filter attributes that end with '_class'
return [attr for attr in dir(self) if attr.endswith("_class")]
@staticmethod
def check_existing_user(user_id: str, session):
"""Check if a user exists in the DB and return it."""
@ -141,8 +148,11 @@ class Memory:
print(f"ID before query: {self.memory_id}, type: {type(self.memory_id)}")
attributes_list = self.session.query(MemoryModel.attributes_list).filter_by(id=self.memory_id[0]).scalar()
logging.info(f"Attributes list: {attributes_list}")
attributes_list = ast.literal_eval(attributes_list)
self.handle_attributes(attributes_list)
if attributes_list is not None:
attributes_list = ast.literal_eval(attributes_list)
self.handle_attributes(attributes_list)
else:
logging.warning("attributes_list is None!")
else:
attributes_list = ['user_id', 'index_name', 'db_type',
'knowledge_source', 'knowledge_type',