Added async elements

This commit is contained in:
Vasilije 2023-10-12 13:18:26 +02:00
parent 856e18b35d
commit 21fd10f2ca
9 changed files with 539 additions and 402 deletions

View file

@ -1,5 +1,5 @@
FROM python:3.11-slim FROM python:3.11
# Set build argument # Set build argument
ARG API_ENABLED ARG API_ENABLED

View file

@ -25,27 +25,37 @@ After that, you can run:
```docker compose build promethai_mem ``` ```docker compose build promethai_mem ```
## Run ## Run the level 3
Make sure you have Docker, Poetry, and Python 3.11 installed and postgres installed.
Copy the .env.example to .env and fill the variables
Start the docker:
```docker compose up promethai_mem ``` ```docker compose up promethai_mem ```
Use the poetry environment:
``` poetry shell ``` ``` poetry shell ```
Make sure to run Make sure to run to initialize DB tables
``` python scripts/create_database.py ``` ``` python scripts/create_database.py ```
After that, you can run: After that, you can run the RAG test manager.
``` python rag_test_manager.py \
```
python rag_test_manager.py \
--url "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf" \ --url "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf" \
--test_set "example_data/test_set.json" \ --test_set "example_data/test_set.json" \
--user_id "666" \ --user_id "666" \
--metadata "example_data/metadata.json" --metadata "example_data/metadata.json"
``` ```
Examples of metadata structure and test set are in the folder "example_data"
To see example of test_set.json and metadata.json, check the files in the folder "example_data"
## Clean database ## Clean database

View file

@ -1,14 +1,15 @@
import os import os
from sqlalchemy import create_engine from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.orm import sessionmaker from contextlib import asynccontextmanager
from contextlib import contextmanager
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
from time import sleep import asyncio
import sys import sys
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
# this is needed to import classes from other modules # this is needed to import classes from other modules
script_dir = os.path.dirname(os.path.abspath(__file__)) script_dir = os.path.dirname(os.path.abspath(__file__))
# Get the parent directory of your script and add it to sys.path # Get the parent directory of your script and add it to sys.path
@ -25,35 +26,43 @@ password = os.getenv('POSTGRES_PASSWORD')
database_name = os.getenv('POSTGRES_DB') database_name = os.getenv('POSTGRES_DB')
host = os.getenv('POSTGRES_HOST') host = os.getenv('POSTGRES_HOST')
# Use the asyncpg driver for async operation
SQLALCHEMY_DATABASE_URL = f"postgresql+asyncpg://{username}:{password}@{host}:5432/{database_name}"
SQLALCHEMY_DATABASE_URL = f"postgresql://{username}:{password}@{host}:5432/{database_name}" engine = create_async_engine(
engine = create_engine(
SQLALCHEMY_DATABASE_URL, SQLALCHEMY_DATABASE_URL,
pool_recycle=3600, # recycle connections after 1 hour pool_recycle=3600,
pool_pre_ping=True # test the connection for liveness upon each checkout echo=True # Enable logging for tutorial purposes
)
# Use AsyncSession for the session
AsyncSessionLocal = sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
) )
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() Base = declarative_base()
@contextmanager # Use asynccontextmanager to define an async context manager
def get_db(): @asynccontextmanager
db = SessionLocal() async def get_db():
db = AsyncSessionLocal()
try: try:
yield db yield db
finally: finally:
db.close() await db.close()
def safe_db_operation(db_op, *args, **kwargs): # Use async/await syntax for the async function
async def safe_db_operation(db_op, *args, **kwargs):
for attempt in range(MAX_RETRIES): for attempt in range(MAX_RETRIES):
with get_db() as db: async with get_db() as db:
try: try:
return db_op(db, *args, **kwargs) # Ensure your db_op is also async
return await db_op(db, *args, **kwargs)
except OperationalError as e: except OperationalError as e:
db.rollback() await db.rollback()
if "server closed the connection unexpectedly" in str(e) and attempt < MAX_RETRIES - 1: if "server closed the connection unexpectedly" in str(e) and attempt < MAX_RETRIES - 1:
sleep(RETRY_DELAY) await asyncio.sleep(RETRY_DELAY)
else: else:
raise raise

View file

@ -1,15 +1,31 @@
from contextlib import asynccontextmanager
import asyncio
from contextlib import asynccontextmanager
import logging
# from database import AsyncSessionLocal
@contextmanager logger = logging.getLogger(__name__)
def session_scope(session):
@asynccontextmanager
async def session_scope(session):
"""Provide a transactional scope around a series of operations.""" """Provide a transactional scope around a series of operations."""
# session = AsyncSessionLocal()
try: try:
yield session yield session
session.commit() await session.commit()
except Exception as e: except Exception as e:
session.rollback() await session.rollback()
logger.error(f"Session rollback due to: {str(e)}") logger.error(f"Session rollback due to: {str(e)}")
raise raise
finally: finally:
session.close() await session.close()
async def add_entity(session, entity):
async with session_scope(session) as s: # Use your async session_scope
s.add(entity) # No need to commit; session_scope takes care of it
s.commit()
return "Successfully added entity"

191
level_3/poetry.lock generated
View file

@ -169,6 +169,59 @@ files = [
{file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
] ]
[[package]]
name = "asyncpg"
version = "0.28.0"
description = "An asyncio PostgreSQL driver"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "asyncpg-0.28.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a6d1b954d2b296292ddff4e0060f494bb4270d87fb3655dd23c5c6096d16d83"},
{file = "asyncpg-0.28.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0740f836985fd2bd73dca42c50c6074d1d61376e134d7ad3ad7566c4f79f8184"},
{file = "asyncpg-0.28.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e907cf620a819fab1737f2dd90c0f185e2a796f139ac7de6aa3212a8af96c050"},
{file = "asyncpg-0.28.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86b339984d55e8202e0c4b252e9573e26e5afa05617ed02252544f7b3e6de3e9"},
{file = "asyncpg-0.28.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0c402745185414e4c204a02daca3d22d732b37359db4d2e705172324e2d94e85"},
{file = "asyncpg-0.28.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c88eef5e096296626e9688f00ab627231f709d0e7e3fb84bb4413dff81d996d7"},
{file = "asyncpg-0.28.0-cp310-cp310-win32.whl", hash = "sha256:90a7bae882a9e65a9e448fdad3e090c2609bb4637d2a9c90bfdcebbfc334bf89"},
{file = "asyncpg-0.28.0-cp310-cp310-win_amd64.whl", hash = "sha256:76aacdcd5e2e9999e83c8fbcb748208b60925cc714a578925adcb446d709016c"},
{file = "asyncpg-0.28.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a0e08fe2c9b3618459caaef35979d45f4e4f8d4f79490c9fa3367251366af207"},
{file = "asyncpg-0.28.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b24e521f6060ff5d35f761a623b0042c84b9c9b9fb82786aadca95a9cb4a893b"},
{file = "asyncpg-0.28.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99417210461a41891c4ff301490a8713d1ca99b694fef05dabd7139f9d64bd6c"},
{file = "asyncpg-0.28.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f029c5adf08c47b10bcdc857001bbef551ae51c57b3110964844a9d79ca0f267"},
{file = "asyncpg-0.28.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ad1d6abf6c2f5152f46fff06b0e74f25800ce8ec6c80967f0bc789974de3c652"},
{file = "asyncpg-0.28.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d7fa81ada2807bc50fea1dc741b26a4e99258825ba55913b0ddbf199a10d69d8"},
{file = "asyncpg-0.28.0-cp311-cp311-win32.whl", hash = "sha256:f33c5685e97821533df3ada9384e7784bd1e7865d2b22f153f2e4bd4a083e102"},
{file = "asyncpg-0.28.0-cp311-cp311-win_amd64.whl", hash = "sha256:5e7337c98fb493079d686a4a6965e8bcb059b8e1b8ec42106322fc6c1c889bb0"},
{file = "asyncpg-0.28.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1c56092465e718a9fdcc726cc3d9dcf3a692e4834031c9a9f871d92a75d20d48"},
{file = "asyncpg-0.28.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4acd6830a7da0eb4426249d71353e8895b350daae2380cb26d11e0d4a01c5472"},
{file = "asyncpg-0.28.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63861bb4a540fa033a56db3bb58b0c128c56fad5d24e6d0a8c37cb29b17c1c7d"},
{file = "asyncpg-0.28.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:a93a94ae777c70772073d0512f21c74ac82a8a49be3a1d982e3f259ab5f27307"},
{file = "asyncpg-0.28.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d14681110e51a9bc9c065c4e7944e8139076a778e56d6f6a306a26e740ed86d2"},
{file = "asyncpg-0.28.0-cp37-cp37m-win32.whl", hash = "sha256:8aec08e7310f9ab322925ae5c768532e1d78cfb6440f63c078b8392a38aa636a"},
{file = "asyncpg-0.28.0-cp37-cp37m-win_amd64.whl", hash = "sha256:319f5fa1ab0432bc91fb39b3960b0d591e6b5c7844dafc92c79e3f1bff96abef"},
{file = "asyncpg-0.28.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b337ededaabc91c26bf577bfcd19b5508d879c0ad009722be5bb0a9dd30b85a0"},
{file = "asyncpg-0.28.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4d32b680a9b16d2957a0a3cc6b7fa39068baba8e6b728f2e0a148a67644578f4"},
{file = "asyncpg-0.28.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4f62f04cdf38441a70f279505ef3b4eadf64479b17e707c950515846a2df197"},
{file = "asyncpg-0.28.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f20cac332c2576c79c2e8e6464791c1f1628416d1115935a34ddd7121bfc6a4"},
{file = "asyncpg-0.28.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:59f9712ce01e146ff71d95d561fb68bd2d588a35a187116ef05028675462d5ed"},
{file = "asyncpg-0.28.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fc9e9f9ff1aa0eddcc3247a180ac9e9b51a62311e988809ac6152e8fb8097756"},
{file = "asyncpg-0.28.0-cp38-cp38-win32.whl", hash = "sha256:9e721dccd3838fcff66da98709ed884df1e30a95f6ba19f595a3706b4bc757e3"},
{file = "asyncpg-0.28.0-cp38-cp38-win_amd64.whl", hash = "sha256:8ba7d06a0bea539e0487234511d4adf81dc8762249858ed2a580534e1720db00"},
{file = "asyncpg-0.28.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d009b08602b8b18edef3a731f2ce6d3f57d8dac2a0a4140367e194eabd3de457"},
{file = "asyncpg-0.28.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ec46a58d81446d580fb21b376ec6baecab7288ce5a578943e2fc7ab73bf7eb39"},
{file = "asyncpg-0.28.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b48ceed606cce9e64fd5480a9b0b9a95cea2b798bb95129687abd8599c8b019"},
{file = "asyncpg-0.28.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8858f713810f4fe67876728680f42e93b7e7d5c7b61cf2118ef9153ec16b9423"},
{file = "asyncpg-0.28.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5e18438a0730d1c0c1715016eacda6e9a505fc5aa931b37c97d928d44941b4bf"},
{file = "asyncpg-0.28.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e9c433f6fcdd61c21a715ee9128a3ca48be8ac16fa07be69262f016bb0f4dbd2"},
{file = "asyncpg-0.28.0-cp39-cp39-win32.whl", hash = "sha256:41e97248d9076bc8e4849da9e33e051be7ba37cd507cbd51dfe4b2d99c70e3dc"},
{file = "asyncpg-0.28.0-cp39-cp39-win_amd64.whl", hash = "sha256:3ed77f00c6aacfe9d79e9eff9e21729ce92a4b38e80ea99a58ed382f42ebd55b"},
{file = "asyncpg-0.28.0.tar.gz", hash = "sha256:7252cdc3acb2f52feaa3664280d3bcd78a46bd6c10bfd681acfffefa1120e278"},
]
[package.extras]
docs = ["Sphinx (>=5.3.0,<5.4.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"]
test = ["flake8 (>=5.0,<6.0)", "uvloop (>=0.15.3)"]
[[package]] [[package]]
name = "atlassian-python-api" name = "atlassian-python-api"
version = "3.41.2" version = "3.41.2"
@ -685,13 +738,13 @@ pdf = ["pypdf (>=3.3.0,<4.0.0)"]
[[package]] [[package]]
name = "deepeval" name = "deepeval"
version = "0.20.1" version = "0.20.6"
description = "DeepEval provides evaluation and unit testing to accelerate development of LLMs and Agents." description = "DeepEval provides evaluation and unit testing to accelerate development of LLMs and Agents."
optional = false optional = false
python-versions = "*" python-versions = "*"
files = [ files = [
{file = "deepeval-0.20.1-py3-none-any.whl", hash = "sha256:f9880a1246a2a8ba77d88b1d2f977759d34741df6d584bb3c55fadc95c52bc89"}, {file = "deepeval-0.20.6-py3-none-any.whl", hash = "sha256:aa0b96fa062f63398858fe2af1c4982ee9e4d53cd3e322c7bbc3812fe1267614"},
{file = "deepeval-0.20.1.tar.gz", hash = "sha256:e3e36745f5e77bc6055def0b98e7a3274c87564f498f50337b670a291fde32a5"}, {file = "deepeval-0.20.6.tar.gz", hash = "sha256:502c6bb8bc27069d4bbac171c2aac1a760ec8e5c11e3c87a7a8ed2a81ef21db6"},
] ]
[package.dependencies] [package.dependencies]
@ -702,6 +755,7 @@ pytest = "*"
requests = "*" requests = "*"
rich = "*" rich = "*"
sentence-transformers = "*" sentence-transformers = "*"
sentry-sdk = "*"
tabulate = "*" tabulate = "*"
tqdm = "*" tqdm = "*"
transformers = "*" transformers = "*"
@ -3519,6 +3573,51 @@ files = [
{file = "sentencepiece-0.1.99.tar.gz", hash = "sha256:189c48f5cb2949288f97ccdb97f0473098d9c3dcf5a3d99d4eabe719ec27297f"}, {file = "sentencepiece-0.1.99.tar.gz", hash = "sha256:189c48f5cb2949288f97ccdb97f0473098d9c3dcf5a3d99d4eabe719ec27297f"},
] ]
[[package]]
name = "sentry-sdk"
version = "1.32.0"
description = "Python client for Sentry (https://sentry.io)"
optional = false
python-versions = "*"
files = [
{file = "sentry-sdk-1.32.0.tar.gz", hash = "sha256:935e8fbd7787a3702457393b74b13d89a5afb67185bc0af85c00cb27cbd42e7c"},
{file = "sentry_sdk-1.32.0-py2.py3-none-any.whl", hash = "sha256:eeb0b3550536f3bbc05bb1c7e0feb3a78d74acb43b607159a606ed2ec0a33a4d"},
]
[package.dependencies]
certifi = "*"
urllib3 = {version = ">=1.26.11", markers = "python_version >= \"3.6\""}
[package.extras]
aiohttp = ["aiohttp (>=3.5)"]
arq = ["arq (>=0.23)"]
asyncpg = ["asyncpg (>=0.23)"]
beam = ["apache-beam (>=2.12)"]
bottle = ["bottle (>=0.12.13)"]
celery = ["celery (>=3)"]
chalice = ["chalice (>=1.16.0)"]
clickhouse-driver = ["clickhouse-driver (>=0.2.0)"]
django = ["django (>=1.8)"]
falcon = ["falcon (>=1.4)"]
fastapi = ["fastapi (>=0.79.0)"]
flask = ["blinker (>=1.1)", "flask (>=0.11)", "markupsafe"]
grpcio = ["grpcio (>=1.21.1)"]
httpx = ["httpx (>=0.16.0)"]
huey = ["huey (>=2)"]
loguru = ["loguru (>=0.5)"]
opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"]
pure-eval = ["asttokens", "executing", "pure-eval"]
pymongo = ["pymongo (>=3.1)"]
pyspark = ["pyspark (>=2.4.4)"]
quart = ["blinker (>=1.1)", "quart (>=0.16.1)"]
rq = ["rq (>=0.6)"]
sanic = ["sanic (>=0.8)"]
sqlalchemy = ["sqlalchemy (>=1.2)"]
starlette = ["starlette (>=0.19.1)"]
starlite = ["starlite (>=1.48)"]
tornado = ["tornado (>=5)"]
[[package]] [[package]]
name = "setuptools" name = "setuptools"
version = "68.1.2" version = "68.1.2"
@ -3816,52 +3915,52 @@ files = [
[[package]] [[package]]
name = "sqlalchemy" name = "sqlalchemy"
version = "2.0.20" version = "2.0.21"
description = "Database Abstraction Library" description = "Database Abstraction Library"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "SQLAlchemy-2.0.20-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759b51346aa388c2e606ee206c0bc6f15a5299f6174d1e10cadbe4530d3c7a98"}, {file = "SQLAlchemy-2.0.21-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1e7dc99b23e33c71d720c4ae37ebb095bebebbd31a24b7d99dfc4753d2803ede"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1506e988ebeaaf316f183da601f24eedd7452e163010ea63dbe52dc91c7fc70e"}, {file = "SQLAlchemy-2.0.21-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7f0c4ee579acfe6c994637527c386d1c22eb60bc1c1d36d940d8477e482095d4"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5768c268df78bacbde166b48be788b83dddaa2a5974b8810af422ddfe68a9bc8"}, {file = "SQLAlchemy-2.0.21-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f7d57a7e140efe69ce2d7b057c3f9a595f98d0bbdfc23fd055efdfbaa46e3a5"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3f0dd6d15b6dc8b28a838a5c48ced7455c3e1fb47b89da9c79cc2090b072a50"}, {file = "SQLAlchemy-2.0.21-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca38746eac23dd7c20bec9278d2058c7ad662b2f1576e4c3dbfcd7c00cc48fa"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:243d0fb261f80a26774829bc2cee71df3222587ac789b7eaf6555c5b15651eed"}, {file = "SQLAlchemy-2.0.21-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3cf229704074bce31f7f47d12883afee3b0a02bb233a0ba45ddbfe542939cca4"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6eb6d77c31e1bf4268b4d61b549c341cbff9842f8e115ba6904249c20cb78a61"}, {file = "SQLAlchemy-2.0.21-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fb87f763b5d04a82ae84ccff25554ffd903baafba6698e18ebaf32561f2fe4aa"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-win32.whl", hash = "sha256:bcb04441f370cbe6e37c2b8d79e4af9e4789f626c595899d94abebe8b38f9a4d"}, {file = "SQLAlchemy-2.0.21-cp310-cp310-win32.whl", hash = "sha256:89e274604abb1a7fd5c14867a412c9d49c08ccf6ce3e1e04fffc068b5b6499d4"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-win_amd64.whl", hash = "sha256:d32b5ffef6c5bcb452723a496bad2d4c52b346240c59b3e6dba279f6dcc06c14"}, {file = "SQLAlchemy-2.0.21-cp310-cp310-win_amd64.whl", hash = "sha256:e36339a68126ffb708dc6d1948161cea2a9e85d7d7b0c54f6999853d70d44430"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dd81466bdbc82b060c3c110b2937ab65ace41dfa7b18681fdfad2f37f27acdd7"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bf8eebccc66829010f06fbd2b80095d7872991bfe8415098b9fe47deaaa58063"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6fe7d61dc71119e21ddb0094ee994418c12f68c61b3d263ebaae50ea8399c4d4"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b977bfce15afa53d9cf6a632482d7968477625f030d86a109f7bdfe8ce3c064a"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4e571af672e1bb710b3cc1a9794b55bce1eae5aed41a608c0401885e3491179"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ff3dc2f60dbf82c9e599c2915db1526d65415be323464f84de8db3e361ba5b9"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3364b7066b3c7f4437dd345d47271f1251e0cfb0aba67e785343cdbdb0fff08c"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44ac5c89b6896f4740e7091f4a0ff2e62881da80c239dd9408f84f75a293dae9"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1be86ccea0c965a1e8cd6ccf6884b924c319fcc85765f16c69f1ae7148eba64b"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:87bf91ebf15258c4701d71dcdd9c4ba39521fb6a37379ea68088ce8cd869b446"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1d35d49a972649b5080557c603110620a86aa11db350d7a7cb0f0a3f611948a0"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b69f1f754d92eb1cc6b50938359dead36b96a1dcf11a8670bff65fd9b21a4b09"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-win32.whl", hash = "sha256:27d554ef5d12501898d88d255c54eef8414576f34672e02fe96d75908993cf53"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-win32.whl", hash = "sha256:af520a730d523eab77d754f5cf44cc7dd7ad2d54907adeb3233177eeb22f271b"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-win_amd64.whl", hash = "sha256:411e7f140200c02c4b953b3dbd08351c9f9818d2bd591b56d0fa0716bd014f1e"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-win_amd64.whl", hash = "sha256:141675dae56522126986fa4ca713739d00ed3a6f08f3c2eb92c39c6dfec463ce"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3c6aceebbc47db04f2d779db03afeaa2c73ea3f8dcd3987eb9efdb987ffa09a3"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7614f1eab4336df7dd6bee05bc974f2b02c38d3d0c78060c5faa4cd1ca2af3b8"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d3f175410a6db0ad96b10bfbb0a5530ecd4fcf1e2b5d83d968dd64791f810ed"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d59cb9e20d79686aa473e0302e4a82882d7118744d30bb1dfb62d3c47141b3ec"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea8186be85da6587456c9ddc7bf480ebad1a0e6dcbad3967c4821233a4d4df57"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a95aa0672e3065d43c8aa80080cdd5cc40fe92dc873749e6c1cf23914c4b83af"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c3d99ba99007dab8233f635c32b5cd24fb1df8d64e17bc7df136cedbea427897"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:8c323813963b2503e54d0944813cd479c10c636e3ee223bcbd7bd478bf53c178"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:76fdfc0f6f5341987474ff48e7a66c3cd2b8a71ddda01fa82fedb180b961630a"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:419b1276b55925b5ac9b4c7044e999f1787c69761a3c9756dec6e5c225ceca01"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-win32.whl", hash = "sha256:d3793dcf5bc4d74ae1e9db15121250c2da476e1af8e45a1d9a52b1513a393459"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-win32.whl", hash = "sha256:4615623a490e46be85fbaa6335f35cf80e61df0783240afe7d4f544778c315a9"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-win_amd64.whl", hash = "sha256:79fde625a0a55220d3624e64101ed68a059c1c1f126c74f08a42097a72ff66a9"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-win_amd64.whl", hash = "sha256:cca720d05389ab1a5877ff05af96551e58ba65e8dc65582d849ac83ddde3e231"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:599ccd23a7146e126be1c7632d1d47847fa9f333104d03325c4e15440fc7d927"}, {file = "SQLAlchemy-2.0.21-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b4eae01faee9f2b17f08885e3f047153ae0416648f8e8c8bd9bc677c5ce64be9"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1a58052b5a93425f656675673ef1f7e005a3b72e3f2c91b8acca1b27ccadf5f4"}, {file = "SQLAlchemy-2.0.21-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3eb7c03fe1cd3255811cd4e74db1ab8dca22074d50cd8937edf4ef62d758cdf4"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79543f945be7a5ada9943d555cf9b1531cfea49241809dd1183701f94a748624"}, {file = "SQLAlchemy-2.0.21-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2d494b6a2a2d05fb99f01b84cc9af9f5f93bf3e1e5dbdafe4bed0c2823584c1"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63e73da7fb030ae0a46a9ffbeef7e892f5def4baf8064786d040d45c1d6d1dc5"}, {file = "SQLAlchemy-2.0.21-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b19ae41ef26c01a987e49e37c77b9ad060c59f94d3b3efdfdbf4f3daaca7b5fe"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3ce5e81b800a8afc870bb8e0a275d81957e16f8c4b62415a7b386f29a0cb9763"}, {file = "SQLAlchemy-2.0.21-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:fc6b15465fabccc94bf7e38777d665b6a4f95efd1725049d6184b3a39fd54880"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cb0d3e94c2a84215532d9bcf10229476ffd3b08f481c53754113b794afb62d14"}, {file = "SQLAlchemy-2.0.21-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:014794b60d2021cc8ae0f91d4d0331fe92691ae5467a00841f7130fe877b678e"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-win32.whl", hash = "sha256:8dd77fd6648b677d7742d2c3cc105a66e2681cc5e5fb247b88c7a7b78351cf74"}, {file = "SQLAlchemy-2.0.21-cp38-cp38-win32.whl", hash = "sha256:0268256a34806e5d1c8f7ee93277d7ea8cc8ae391f487213139018b6805aeaf6"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-win_amd64.whl", hash = "sha256:6f8a934f9dfdf762c844e5164046a9cea25fabbc9ec865c023fe7f300f11ca4a"}, {file = "SQLAlchemy-2.0.21-cp38-cp38-win_amd64.whl", hash = "sha256:73c079e21d10ff2be54a4699f55865d4b275fd6c8bd5d90c5b1ef78ae0197301"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:26a3399eaf65e9ab2690c07bd5cf898b639e76903e0abad096cd609233ce5208"}, {file = "SQLAlchemy-2.0.21-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:785e2f2c1cb50d0a44e2cdeea5fd36b5bf2d79c481c10f3a88a8be4cfa2c4615"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4cde2e1096cbb3e62002efdb7050113aa5f01718035ba9f29f9d89c3758e7e4e"}, {file = "SQLAlchemy-2.0.21-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c111cd40910ffcb615b33605fc8f8e22146aeb7933d06569ac90f219818345ef"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1b09ba72e4e6d341bb5bdd3564f1cea6095d4c3632e45dc69375a1dbe4e26ec"}, {file = "SQLAlchemy-2.0.21-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9cba4e7369de663611ce7460a34be48e999e0bbb1feb9130070f0685e9a6b66"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b74eeafaa11372627ce94e4dc88a6751b2b4d263015b3523e2b1e57291102f0"}, {file = "SQLAlchemy-2.0.21-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50a69067af86ec7f11a8e50ba85544657b1477aabf64fa447fd3736b5a0a4f67"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:77d37c1b4e64c926fa3de23e8244b964aab92963d0f74d98cbc0783a9e04f501"}, {file = "SQLAlchemy-2.0.21-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ccb99c3138c9bde118b51a289d90096a3791658da9aea1754667302ed6564f6e"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:eefebcc5c555803065128401a1e224a64607259b5eb907021bf9b175f315d2a6"}, {file = "SQLAlchemy-2.0.21-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:513fd5b6513d37e985eb5b7ed89da5fd9e72354e3523980ef00d439bc549c9e9"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-win32.whl", hash = "sha256:3423dc2a3b94125094897118b52bdf4d37daf142cbcf26d48af284b763ab90e9"}, {file = "SQLAlchemy-2.0.21-cp39-cp39-win32.whl", hash = "sha256:f9fefd6298433b6e9188252f3bff53b9ff0443c8fde27298b8a2b19f6617eeb9"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-win_amd64.whl", hash = "sha256:5ed61e3463021763b853628aef8bc5d469fe12d95f82c74ef605049d810f3267"}, {file = "SQLAlchemy-2.0.21-cp39-cp39-win_amd64.whl", hash = "sha256:2e617727fe4091cedb3e4409b39368f424934c7faa78171749f704b49b4bb4ce"},
{file = "SQLAlchemy-2.0.20-py3-none-any.whl", hash = "sha256:63a368231c53c93e2b67d0c5556a9836fdcd383f7e3026a39602aad775b14acf"}, {file = "SQLAlchemy-2.0.21-py3-none-any.whl", hash = "sha256:ea7da25ee458d8f404b93eb073116156fd7d8c2a776d8311534851f28277b4ce"},
{file = "SQLAlchemy-2.0.20.tar.gz", hash = "sha256:ca8a5ff2aa7f3ade6c498aaafce25b1eaeabe4e42b73e25519183e4566a16fc6"}, {file = "SQLAlchemy-2.0.21.tar.gz", hash = "sha256:05b971ab1ac2994a14c56b35eaaa91f86ba080e9ad481b20d99d77f381bb6258"},
] ]
[package.dependencies] [package.dependencies]
@ -4778,4 +4877,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "cf14b96576c57633fea1b0cc1f7266a5fb265d4ff4ce2d3479a036483082b3b4" content-hash = "b3f795baef70806c8dab1058b644d95f3c36b1c2fae53a7e5784a633731bc268"

View file

@ -40,11 +40,13 @@ weaviate-client = "^3.22.1"
python-multipart = "^0.0.6" python-multipart = "^0.0.6"
deep-translator = "^1.11.4" deep-translator = "^1.11.4"
humanize = "^4.8.0" humanize = "^4.8.0"
deepeval = "^0.20.1" deepeval = "^0.20.6"
pymupdf = "^1.23.3" pymupdf = "^1.23.3"
psycopg2 = "^2.9.8" psycopg2 = "^2.9.8"
llama-index = "^0.8.39.post2" llama-index = "^0.8.39.post2"
llama-hub = "^0.0.34" llama-hub = "^0.0.34"
sqlalchemy = "^2.0.21"
asyncpg = "^0.28.0"

View file

@ -20,6 +20,8 @@ from database.database import engine
from vectorstore_manager import Memory from vectorstore_manager import Memory
import uuid import uuid
from contextlib import contextmanager from contextlib import contextmanager
from database.database import AsyncSessionLocal
from database.database_crud import session_scope
import random import random
import string import string
@ -30,103 +32,65 @@ dotenv.load_dotenv()
import openai import openai
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
openai.api_key = os.getenv("OPENAI_API_KEY", "") openai.api_key = os.getenv("OPENAI_API_KEY", "")
@contextmanager
def session_scope(session): async def retrieve_latest_test_case(session, user_id, memory_id):
"""Provide a transactional scope around a series of operations."""
try: try:
yield session # Use await with session.execute() and row.fetchone() or row.all() for async query execution
session.commit() result = await session.execute(
except Exception as e:
session.rollback()
logger.error(f"Session rollback due to: {str(e)}")
raise
finally:
session.close()
def retrieve_latest_test_case(session, user_id, memory_id):
"""
Retrieve the most recently created test case from the database filtered by user_id and memory_id.
Parameters:
- session (Session): A database session.
- user_id (int/str): The ID of the user to filter test cases by.
- memory_id (int/str): The ID of the memory to filter test cases by.
Returns:
- Object: The most recent test case attributes filtered by user_id and memory_id, or None if an error occurs.
"""
try:
return (
session.query(TestSet.attributes_list) session.query(TestSet.attributes_list)
.filter_by(user_id=user_id, memory_id=memory_id) .filter_by(user_id=user_id, memory_id=memory_id)
.order_by(TestSet.created_at.desc()) .order_by(TestSet.created_at).first()
.first()
) )
return result.scalar_one_or_none() # scalar_one_or_none() is a non-blocking call
except Exception as e: except Exception as e:
logger.error(f"An error occurred while retrieving the latest test case: {str(e)}") logger.error(f"An error occurred while retrieving the latest test case: {str(e)}")
return None return None
async def add_entity(session, entity):
def add_entity(session, entity): async with session_scope(session) as s: # Use your async session_scope
""" s.add(entity) # No need to commit; session_scope takes care of it
Add an entity (like TestOutput, Session, etc.) to the database. s.commit()
Parameters:
- session (Session): A database session.
- entity (Base): An instance of an SQLAlchemy model.
Returns:
- str: A message indicating whether the addition was successful.
"""
with session_scope(session):
session.add(entity)
session.commit()
return "Successfully added entity" return "Successfully added entity"
async def retrieve_job_by_id(session, user_id, job_id):
def retrieve_job_by_id(session, user_id, job_id):
"""
Retrieve a job by user ID and job ID.
Parameters:
- session (Session): A database session.
- user_id (int/str): The ID of the user.
- job_id (int/str): The ID of the job to retrieve.
Returns:
- Object: The job attributes filtered by user_id and job_id, or None if an error occurs.
"""
try: try:
return ( result = await session.execute(
session.query(Session.id) session.query(Session.id)
.filter_by(user_id=user_id, id=job_id) .filter_by(user_id=user_id, id=job_id)
.order_by(Session.created_at.desc()) .order_by(Session.created_at)
.first()
) )
return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"An error occurred while retrieving the job: {str(e)}") logger.error(f"An error occurred while retrieving the job: {str(e)}")
return None return None
def fetch_job_id(session, user_id=None, memory_id=None, job_id=None): async def fetch_job_id(session, user_id=None, memory_id=None, job_id=None):
try: try:
return ( result = await session.execute(
session.query(Session.id) session.query(Session.id)
.filter_by(user_id=user_id, id=job_id) .filter_by(user_id=user_id, id=job_id)
.order_by(Session.created_at.desc()) .order_by(Session.created_at).first()
.first()
) )
return result.scalar_one_or_none()
except Exception as e: except Exception as e:
# Handle exceptions as per your application's requirements. logger.error(f"An error occurred: {str(e)}")
print(f"An error occurred: {str(e)}")
return None return None
def compare_output(output, expected_output): async def fetch_test_set_id(session, user_id, id):
"""Compare the output against the expected output.""" try:
pass # Await the execution of the query and fetching of the result
result = await session.execute(
session.query(TestSet.id)
.filter_by(user_id=user_id, id=id)
.order_by(TestSet.created_at).desc().first()
)
return result.scalar_one_or_none() # scalar_one_or_none() is a non-blocking call
except Exception as e:
logger.error(f"An error occurred while retrieving the test set: {str(e)}")
return None
# Adding "embeddings" to the parameter variants function
def generate_param_variants(base_params=None, increments=None, ranges=None, included_params=None): def generate_param_variants(base_params=None, increments=None, ranges=None, included_params=None):
"""Generate parameter variants for testing. """Generate parameter variants for testing.
@ -142,38 +106,39 @@ def generate_param_variants(base_params=None, increments=None, ranges=None, incl
list: A list of dictionaries containing parameter variants. list: A list of dictionaries containing parameter variants.
""" """
# Default base values # Default values
defaults = { defaults = {
'chunk_size': 500, 'chunk_size': 250,
'chunk_overlap': 20, 'chunk_overlap': 20,
'similarity_score': 0.5, 'similarity_score': 0.5,
'metadata_variation': 0, 'metadata_variation': 0,
'search_type': 'hybrid' 'search_type': 'hybrid',
'embeddings': 'openai' # Default value added for 'embeddings'
} }
# Update defaults with provided base parameters # Update defaults with provided base parameters
params = {**defaults, **(base_params if base_params is not None else {})} params = {**defaults, **(base_params or {})}
default_increments = { default_increments = {
'chunk_size': 500, 'chunk_size': 150,
'chunk_overlap': 10, 'chunk_overlap': 10,
'similarity_score': 0.1, 'similarity_score': 0.1,
'metadata_variation': 1 'metadata_variation': 1
} }
# Update default increments with provided increments # Update default increments with provided increments
increments = {**default_increments, **(increments if increments is not None else {})} increments = {**default_increments, **(increments or {})}
# Default ranges # Default ranges
default_ranges = { default_ranges = {
'chunk_size': 3, 'chunk_size': 2,
'chunk_overlap': 3, 'chunk_overlap': 2,
'similarity_score': 3, 'similarity_score': 2,
'metadata_variation': 3 'metadata_variation': 2
} }
# Update default ranges with provided ranges # Update default ranges with provided ranges
ranges = {**default_ranges, **(ranges if ranges is not None else {})} ranges = {**default_ranges, **(ranges or {})}
# Generate parameter variant ranges # Generate parameter variant ranges
param_ranges = { param_ranges = {
@ -181,10 +146,9 @@ def generate_param_variants(base_params=None, increments=None, ranges=None, incl
for key in ['chunk_size', 'chunk_overlap', 'similarity_score', 'metadata_variation'] for key in ['chunk_size', 'chunk_overlap', 'similarity_score', 'metadata_variation']
} }
# Add search_type and embeddings with possible values
param_ranges['cognitive_architecture'] = ["simple_index", "cognitive_architecture"]
# Add search_type with possible values
param_ranges['search_type'] = ['text', 'hybrid', 'bm25', 'generate', 'generate_grouped'] param_ranges['search_type'] = ['text', 'hybrid', 'bm25', 'generate', 'generate_grouped']
param_ranges['embeddings'] = ['openai', 'cohere', 'huggingface'] # Added 'embeddings' values
# Filter param_ranges based on included_params # Filter param_ranges based on included_params
if included_params is not None: if included_params is not None:
@ -197,6 +161,10 @@ def generate_param_variants(base_params=None, increments=None, ranges=None, incl
return param_variants return param_variants
# Generate parameter variants and display a sample of the generated combinations
async def generate_chatgpt_output(query:str, context:str=None): async def generate_chatgpt_output(query:str, context:str=None):
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
@ -278,195 +246,200 @@ def generate_letter_uuid(length=8):
letters = string.ascii_uppercase # A-Z letters = string.ascii_uppercase # A-Z
return ''.join(random.choice(letters) for _ in range(length)) return ''.join(random.choice(letters) for _ in range(length))
async def start_test(data, test_set=None, user_id=None, params=None, job_id=None, metadata=None):
def fetch_test_set_id(session, user_id, id):
try: async with session_scope(session=AsyncSessionLocal()) as session:
return ( memory = await Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
session.query(TestSet.id) job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id)
.filter_by(user_id=user_id, id=id) test_set_id = await fetch_test_set_id(session, user_id=user_id, id=job_id)
.order_by(TestSet.created_at)
.desc().first() if job_id is None:
job_id = str(uuid.uuid4())
await add_entity(session, Operation(id=job_id, user_id=user_id))
if test_set_id is None:
test_set_id = str(uuid.uuid4())
await add_entity(session, TestSet(id=test_set_id, user_id=user_id, content=str(test_set)))
if params is None:
data_format = data_format_route(data) # Assume data_format_route is predefined
data_location = data_location_route(data) # Assume data_location_route is predefined
test_params = generate_param_variants(
included_params=['chunk_size', 'chunk_overlap', 'search_type'])
print("Here are the test params", str(test_params))
loader_settings = {
"format": f"{data_format}",
"source": f"{data_location}",
"path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
}
async def run_test(test, loader_settings, metadata):
test_id = str(generate_letter_uuid()) + "_" + "SEMANTICEMEMORY"
await memory.add_memory_instance("ExampleMemory")
existing_user = await Memory.check_existing_user(user_id, session)
await memory.manage_memory_attributes(existing_user)
test_class = test_id + "_class"
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
dynamic_memory_class = getattr(memory, test_class.lower(), None)
# Assuming add_method_to_class and dynamic_method_call are predefined methods in Memory class
if dynamic_memory_class is not None:
await memory.add_method_to_class(dynamic_memory_class, 'add_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
if dynamic_memory_class is not None:
await memory.add_method_to_class(dynamic_memory_class, 'fetch_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
print(f"Trying to access: {test_class.lower()}")
print("Available memory classes:", await memory.list_memory_classes())
print(f"Trying to check: ", test)
loader_settings.update(test)
async def run_load_test_element(test, loader_settings, metadata, test_id):
test_class = test_id + "_class"
# memory.test_class
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
dynamic_memory_class = getattr(memory, test_class.lower(), None)
if dynamic_memory_class is not None:
await memory.add_method_to_class(dynamic_memory_class, 'add_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
if dynamic_memory_class is not None:
await memory.add_method_to_class(dynamic_memory_class, 'fetch_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
print(f"Trying to access: {test_class.lower()}")
print("Available memory classes:", await memory.list_memory_classes())
print(f"Trying to check: ", test)
# print("Here is the loader settings", str(loader_settings))
# print("Here is the medatadata", str(metadata))
load_action = await memory.dynamic_method_call(dynamic_memory_class, 'add_memories',
observation='some_observation', params=metadata,
loader_settings=loader_settings)
async def run_search_eval_element(test_item, test_id):
test_class = test_id + "_class"
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
dynamic_memory_class = getattr(memory, test_class.lower(), None)
retrieve_action = await memory.dynamic_method_call(dynamic_memory_class, 'fetch_memories',
observation=test_item["question"],
search_type=test_item["search_type"])
test_result = await eval_test(query=test_item["question"], expected_output=test_item["answer"],
context=str(retrieve_action))
print(test_result)
delete_mems = await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
namespace=test_id)
return test_result
test_load_pipeline = await asyncio.gather(
*(run_load_test_element(test_item,loader_settings, metadata, test_id) for test_item in test_set)
)
test_eval_pipeline = await asyncio.gather(
*(run_search_eval_element(test_item, test_id) for test_item in test_set)
)
logging.info("Results of the eval pipeline %s", str(test_eval_pipeline))
await add_entity(session, TestOutput(id=test_id, user_id=user_id, test_results=str(test_eval_pipeline)))
# # Gather and run all tests in parallel
results = await asyncio.gather(
*(run_test(test, loader_settings, metadata) for test in test_params)
) )
except Exception as e: return results
logger.error(f"An error occurred while retrieving the job: {str(e)}")
return None
async def start_test(data, test_set=None, user_id=None, params=None, job_id=None ,metadata=None):
Session = sessionmaker(bind=engine)
session = Session()
#do i need namespace in memory instance, fix it
memory = Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
job_id = fetch_job_id(session, user_id = user_id,job_id =job_id)
test_set_id = fetch_test_set_id(session, user_id=user_id, id=job_id)
if job_id is None:
job_id = str(uuid.uuid4())
logging.info("we are adding a new job ID")
add_entity(session, Operation(id = job_id, user_id = user_id))
if test_set_id is None:
test_set_id = str(uuid.uuid4())
add_entity(session, TestSet(id = test_set_id, user_id = user_id, content = str(test_set)))
if params is None:
data_format = data_format_route(data)
data_location = data_location_route(data)
test_params = generate_param_variants( included_params=['chunk_size', 'chunk_overlap', 'similarity_score'])
loader_settings = {
"format": f"{data_format}",
"source": f"{data_location}",
"path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
}
for test in test_params:
test_id = str(generate_letter_uuid()) + "_" + "SEMANTICEMEMORY"
# Adding a memory instance
memory.add_memory_instance("ExampleMemory")
# Managing memory attributes
existing_user = Memory.check_existing_user(user_id, session)
print("here is the existing user", existing_user)
memory.manage_memory_attributes(existing_user)
test_class = test_id + "_class"
# memory.test_class
memory.add_dynamic_memory_class(test_id.lower(), test_id)
dynamic_memory_class = getattr(memory, test_class.lower(), None)
if dynamic_memory_class is not None:
memory.add_method_to_class(dynamic_memory_class, 'add_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
if dynamic_memory_class is not None:
memory.add_method_to_class(dynamic_memory_class, 'fetch_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
print(f"Trying to access: {test_class.lower()}")
print("Available memory classes:", memory.list_memory_classes())
print(f"Trying to check: ", test)
loader_settings.update(test)
load_action = await memory.dynamic_method_call(dynamic_memory_class, 'add_memories',
observation='some_observation', params=metadata, loader_settings=loader_settings)
loader_settings = {key: value for key, value in loader_settings.items() if key not in test}
test_result_collection =[]
for test in test_set:
retrieve_action = await memory.dynamic_method_call(dynamic_memory_class, 'fetch_memories',
observation=test["question"])
test_results = await eval_test( query=test["question"], expected_output=test["answer"], context= str(retrieve_action))
test_result_collection.append(test_results)
print(test_results)
if dynamic_memory_class is not None:
memory.add_method_to_class(dynamic_memory_class, 'delete_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
delete_mems = await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
namespace =test_id)
print(test_result_collection)
add_entity(session, TestOutput(id=test_id, user_id=user_id, test_results=str(test_result_collection)))
async def main(): async def main():
params = {
"version": "1.0",
"agreement_id": "AG123456",
"privacy_policy": "https://example.com/privacy",
"terms_of_service": "https://example.com/terms",
"format": "json",
"schema_version": "1.1",
"checksum": "a1b2c3d4e5f6",
"owner": "John Doe",
"license": "MIT",
"validity_start": "2023-08-01",
"validity_end": "2024-07-31",
}
test_set = [
{
"question": "Who is the main character in 'The Call of the Wild'?",
"answer": "Buck"
},
{
"question": "Who wrote 'The Call of the Wild'?",
"answer": "Jack London"
},
{
"question": "Where does Buck live at the start of the book?",
"answer": "In the Santa Clara Valley, at Judge Millers place."
},
{
"question": "Why is Buck kidnapped?",
"answer": "He is kidnapped to be sold as a sled dog in the Yukon during the Klondike Gold Rush."
},
{
"question": "How does Buck become the leader of the sled dog team?",
"answer": "Buck becomes the leader after defeating the original leader, Spitz, in a fight."
}
]
result = await start_test("https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf", test_set=test_set, user_id="666", params=None, metadata=params)
# #
# params = { # parser = argparse.ArgumentParser(description="Run tests against a document.")
# "version": "1.0", # parser.add_argument("--url", required=True, help="URL of the document to test.")
# "agreement_id": "AG123456", # parser.add_argument("--test_set", required=True, help="Path to JSON file containing the test set.")
# "privacy_policy": "https://example.com/privacy", # parser.add_argument("--user_id", required=True, help="User ID.")
# "terms_of_service": "https://example.com/terms", # parser.add_argument("--params", help="Additional parameters in JSON format.")
# "format": "json", # parser.add_argument("--metadata", required=True, help="Path to JSON file containing metadata.")
# "schema_version": "1.1",
# "checksum": "a1b2c3d4e5f6",
# "owner": "John Doe",
# "license": "MIT",
# "validity_start": "2023-08-01",
# "validity_end": "2024-07-31",
# }
# #
# test_set = [ # args = parser.parse_args()
# {
# "question": "Who is the main character in 'The Call of the Wild'?",
# "answer": "Buck"
# },
# {
# "question": "Who wrote 'The Call of the Wild'?",
# "answer": "Jack London"
# },
# {
# "question": "Where does Buck live at the start of the book?",
# "answer": "In the Santa Clara Valley, at Judge Millers place."
# },
# {
# "question": "Why is Buck kidnapped?",
# "answer": "He is kidnapped to be sold as a sled dog in the Yukon during the Klondike Gold Rush."
# },
# {
# "question": "How does Buck become the leader of the sled dog team?",
# "answer": "Buck becomes the leader after defeating the original leader, Spitz, in a fight."
# }
# ]
# result = await start_test("https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf", test_set=test_set, user_id="666", params=None, metadata=params)
# #
parser = argparse.ArgumentParser(description="Run tests against a document.") # try:
parser.add_argument("--url", required=True, help="URL of the document to test.") # with open(args.test_set, "r") as file:
parser.add_argument("--test_set", required=True, help="Path to JSON file containing the test set.") # test_set = json.load(file)
parser.add_argument("--user_id", required=True, help="User ID.") # if not isinstance(test_set, list): # Expecting a list
parser.add_argument("--params", help="Additional parameters in JSON format.") # raise TypeError("Parsed test_set JSON is not a list.")
parser.add_argument("--metadata", required=True, help="Path to JSON file containing metadata.") # except Exception as e:
# print(f"Error loading test_set: {str(e)}")
args = parser.parse_args() # return
#
try: # try:
with open(args.test_set, "r") as file: # with open(args.metadata, "r") as file:
test_set = json.load(file) # metadata = json.load(file)
if not isinstance(test_set, list): # Expecting a list # if not isinstance(metadata, dict):
raise TypeError("Parsed test_set JSON is not a list.") # raise TypeError("Parsed metadata JSON is not a dictionary.")
except Exception as e: # except Exception as e:
print(f"Error loading test_set: {str(e)}") # print(f"Error loading metadata: {str(e)}")
return # return
#
try: # if args.params:
with open(args.metadata, "r") as file: # try:
metadata = json.load(file) # params = json.loads(args.params)
if not isinstance(metadata, dict): # if not isinstance(params, dict):
raise TypeError("Parsed metadata JSON is not a dictionary.") # raise TypeError("Parsed params JSON is not a dictionary.")
except Exception as e: # except json.JSONDecodeError as e:
print(f"Error loading metadata: {str(e)}") # print(f"Error parsing params: {str(e)}")
return # return
# else:
if args.params: # params = None
try: # #clean up params here
params = json.loads(args.params) # await start_test(args.url, test_set, args.user_id, params=None, metadata=metadata)
if not isinstance(params, dict):
raise TypeError("Parsed params JSON is not a dictionary.")
except json.JSONDecodeError as e:
print(f"Error parsing params: {str(e)}")
return
else:
params = None
#clean up params here
await start_test(args.url, test_set, args.user_id, params=None, metadata=metadata)
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio

View file

@ -179,17 +179,17 @@ class WeaviateVectorDB(VectorDB):
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance'] ["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
).with_where(params_user_id).with_limit(10) ).with_where(params_user_id).with_limit(10)
n_of_observations = kwargs.get('n_of_observations', 2)
try: try:
if search_type == 'text': if search_type == 'text':
query_output = ( query_output = (
base_query base_query
.with_near_text({"concepts": [observation]}) .with_near_text({"concepts": [observation]})
.with_autocut(n_of_observations)
.do() .do()
) )
elif search_type == 'hybrid': elif search_type == 'hybrid':
n_of_observations = kwargs.get('n_of_observations', 2)
query_output = ( query_output = (
base_query base_query
.with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE) .with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
@ -200,6 +200,7 @@ class WeaviateVectorDB(VectorDB):
query_output = ( query_output = (
base_query base_query
.with_bm25(query=observation) .with_bm25(query=observation)
.with_autocut(n_of_observations)
.do() .do()
) )
elif search_type == 'generate': elif search_type == 'generate':
@ -208,6 +209,7 @@ class WeaviateVectorDB(VectorDB):
base_query base_query
.with_generate(single_prompt=generate_prompt) .with_generate(single_prompt=generate_prompt)
.with_near_text({"concepts": [observation]}) .with_near_text({"concepts": [observation]})
.with_autocut(n_of_observations)
.do() .do()
) )
elif search_type == 'generate_grouped': elif search_type == 'generate_grouped':
@ -216,6 +218,7 @@ class WeaviateVectorDB(VectorDB):
base_query base_query
.with_generate(grouped_task=generate_prompt) .with_generate(grouped_task=generate_prompt)
.with_near_text({"concepts": [observation]}) .with_near_text({"concepts": [observation]})
.with_autocut(n_of_observations)
.do() .do()
) )
else: else:

View file

@ -1,5 +1,7 @@
import logging import logging
from sqlalchemy.future import select
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
import marvin import marvin
from dotenv import load_dotenv from dotenv import load_dotenv
@ -15,6 +17,7 @@ from models.operation import Operation
load_dotenv() load_dotenv()
import ast import ast
import tracemalloc import tracemalloc
from database.database_crud import session_scope, add_entity
tracemalloc.start() tracemalloc.start()
@ -48,7 +51,7 @@ class DynamicBaseMemory(BaseMemory):
self.associations = [] self.associations = []
def add_method(self, method_name): async def add_method(self, method_name):
""" """
Add a method to the memory class. Add a method to the memory class.
@ -60,7 +63,7 @@ class DynamicBaseMemory(BaseMemory):
""" """
self.methods.add(method_name) self.methods.add(method_name)
def add_attribute(self, attribute_name): async def add_attribute(self, attribute_name):
""" """
Add an attribute to the memory class. Add an attribute to the memory class.
@ -72,7 +75,7 @@ class DynamicBaseMemory(BaseMemory):
""" """
self.attributes.add(attribute_name) self.attributes.add(attribute_name)
def get_attribute(self, attribute_name): async def get_attribute(self, attribute_name):
""" """
Check if the attribute is in the memory class. Check if the attribute is in the memory class.
@ -84,7 +87,7 @@ class DynamicBaseMemory(BaseMemory):
""" """
return attribute_name in self.attributes return attribute_name in self.attributes
def add_association(self, associated_memory): async def add_association(self, associated_memory):
""" """
Add an association to another memory class. Add an association to another memory class.
@ -149,26 +152,26 @@ class Memory:
self.OPENAI_TEMPERATURE = float(os.getenv("OPENAI_TEMPERATURE", 0.0)) self.OPENAI_TEMPERATURE = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
self.OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") self.OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
@classmethod @classmethod
def create_memory(cls, user_id: str, session, **kwargs): async def create_memory(cls, user_id: str, session, **kwargs):
""" """
Class method that acts as a factory method for creating Memory instances. Class method that acts as a factory method for creating Memory instances.
It performs necessary DB checks or updates before instance creation. It performs necessary DB checks or updates before instance creation.
""" """
existing_user = cls.check_existing_user(user_id, session) existing_user = await cls.check_existing_user(user_id, session)
if existing_user: if existing_user:
# Handle existing user scenario... # Handle existing user scenario...
memory_id = cls.check_existing_memory(user_id, session) memory_id = await cls.check_existing_memory(user_id, session)
logging.info(f"Existing user {user_id} found in the DB. Memory ID: {memory_id}") logging.info(f"Existing user {user_id} found in the DB. Memory ID: {memory_id}")
else: else:
# Handle new user scenario... # Handle new user scenario...
memory_id = cls.handle_new_user(user_id, session) memory_id = await cls.handle_new_user(user_id, session)
logging.info(f"New user {user_id} created in the DB. Memory ID: {memory_id}") logging.info(f"New user {user_id} created in the DB. Memory ID: {memory_id}")
return cls(user_id=user_id, session=session, memory_id=memory_id, **kwargs) return cls(user_id=user_id, session=session, memory_id=memory_id, **kwargs)
def list_memory_classes(self): async def list_memory_classes(self):
""" """
Lists all available memory classes in the memory instance. Lists all available memory classes in the memory instance.
""" """
@ -176,32 +179,36 @@ class Memory:
return [attr for attr in dir(self) if attr.endswith("_class")] return [attr for attr in dir(self) if attr.endswith("_class")]
@staticmethod @staticmethod
def check_existing_user(user_id: str, session): async def check_existing_user(user_id: str, session):
"""Check if a user exists in the DB and return it.""" """Check if a user exists in the DB and return it."""
return session.query(User).filter_by(id=user_id).first() result = await session.execute(
select(User).where(User.id == user_id)
)
return result.scalar_one_or_none()
@staticmethod @staticmethod
def check_existing_memory(user_id: str, session): async def check_existing_memory(user_id: str, session):
"""Check if a user memory exists in the DB and return it.""" """Check if a user memory exists in the DB and return it."""
return session.query(MemoryModel.id).filter_by(user_id=user_id).first() result = await session.execute(
select(MemoryModel.id).where(MemoryModel.user_id == user_id)
)
return result.scalar_one_or_none()
@staticmethod @staticmethod
def handle_new_user(user_id: str, session): async def handle_new_user(user_id: str, session):
"""Handle new user creation in the DB and return the new memory ID.""" """Handle new user creation in the DB and return the new memory ID."""
#handle these better in terms of retry and error handling #handle these better in terms of retry and error handling
memory_id = str(uuid.uuid4()) memory_id = str(uuid.uuid4())
new_user = User(id=user_id) new_user = User(id=user_id)
session.add(new_user) await add_entity(session, new_user)
session.commit()
memory = MemoryModel(id=memory_id, user_id=user_id, methods_list=str(['Memory', 'SemanticMemory', 'EpisodicMemory']), memory = MemoryModel(id=memory_id, user_id=user_id, methods_list=str(['Memory', 'SemanticMemory', 'EpisodicMemory']),
attributes_list=str(['user_id', 'index_name', 'db_type', 'knowledge_source', 'knowledge_type', 'memory_id', 'long_term_memory', 'short_term_memory', 'namespace'])) attributes_list=str(['user_id', 'index_name', 'db_type', 'knowledge_source', 'knowledge_type', 'memory_id', 'long_term_memory', 'short_term_memory', 'namespace']))
session.add(memory) await add_entity(session, memory)
session.commit()
return memory_id return memory_id
def add_memory_instance(self, memory_class_name: str): async def add_memory_instance(self, memory_class_name: str):
"""Add a new memory instance to the memory_instances list.""" """Add a new memory instance to the memory_instances list."""
instance = DynamicBaseMemory(memory_class_name, self.user_id, instance = DynamicBaseMemory(memory_class_name, self.user_id,
self.memory_id, self.index_name, self.memory_id, self.index_name,
@ -209,15 +216,26 @@ class Memory:
print("The following instance was defined", instance) print("The following instance was defined", instance)
self.memory_instances.append(instance) self.memory_instances.append(instance)
def manage_memory_attributes(self, existing_user): async def query_method(self):
methods_list = await self.session.execute(
select(MemoryModel.methods_list).where(MemoryModel.id == self.memory_id)
)
methods_list = methods_list.scalar_one_or_none()
return methods_list
async def manage_memory_attributes(self, existing_user):
"""Manage memory attributes based on the user existence.""" """Manage memory attributes based on the user existence."""
if existing_user: if existing_user:
print(f"ID before query: {self.memory_id}, type: {type(self.memory_id)}") print(f"ID before query: {self.memory_id}, type: {type(self.memory_id)}")
attributes_list = self.session.query(MemoryModel.attributes_list).filter_by(id=self.memory_id[0]).scalar()
# attributes_list = await self.session.query(MemoryModel.attributes_list).filter_by(id=self.memory_id[0]).scalar()
attributes_list = await self.query_method()
logging.info(f"Attributes list: {attributes_list}") logging.info(f"Attributes list: {attributes_list}")
if attributes_list is not None: if attributes_list is not None:
attributes_list = ast.literal_eval(attributes_list) attributes_list = ast.literal_eval(attributes_list)
self.handle_attributes(attributes_list) await self.handle_attributes(attributes_list)
else: else:
logging.warning("attributes_list is None!") logging.warning("attributes_list is None!")
else: else:
@ -225,20 +243,25 @@ class Memory:
'knowledge_source', 'knowledge_type', 'knowledge_source', 'knowledge_type',
'memory_id', 'long_term_memory', 'memory_id', 'long_term_memory',
'short_term_memory', 'namespace'] 'short_term_memory', 'namespace']
self.handle_attributes(attributes_list) await self.handle_attributes(attributes_list)
def handle_attributes(self, attributes_list): async def handle_attributes(self, attributes_list):
"""Handle attributes for existing memory instances.""" """Handle attributes for existing memory instances."""
for attr in attributes_list: for attr in attributes_list:
self.memory_class.add_attribute(attr) await self.memory_class.add_attribute(attr)
def manage_memory_methods(self, existing_user): async def manage_memory_methods(self, existing_user):
""" """
Manage memory methods based on the user existence. Manage memory methods based on the user existence.
""" """
if existing_user: if existing_user:
# Fetch existing methods from the database # Fetch existing methods from the database
methods_list = self.session.query(MemoryModel.methods_list).filter_by(id=self.memory_id).scalar() # methods_list = await self.session.query(MemoryModel.methods_list).filter_by(id=self.memory_id).scalar()
methods_list = await self.session.execute(
select(MemoryModel.methods_list).where(MemoryModel.id == self.memory_id[0])
)
methods_list = methods_list.scalar_one_or_none()
methods_list = ast.literal_eval(methods_list) methods_list = ast.literal_eval(methods_list)
else: else:
# Define default methods for a new user # Define default methods for a new user
@ -260,20 +283,20 @@ class Memory:
return await method(*args, **kwargs) return await method(*args, **kwargs)
raise AttributeError(f"{dynamic_base_memory_instance.name} object has no attribute {method_name}") raise AttributeError(f"{dynamic_base_memory_instance.name} object has no attribute {method_name}")
def add_dynamic_memory_class(self, class_name: str, namespace: str): async def add_dynamic_memory_class(self, class_name: str, namespace: str):
logging.info("Here is the memory id %s", self.memory_id[0]) logging.info("Here is the memory id %s", self.memory_id[0])
new_memory_class = DynamicBaseMemory(class_name, self.user_id, self.memory_id[0], self.index_name, new_memory_class = DynamicBaseMemory(class_name, self.user_id, self.memory_id[0], self.index_name,
self.db_type, namespace) self.db_type, namespace)
setattr(self, f"{class_name.lower()}_class", new_memory_class) setattr(self, f"{class_name.lower()}_class", new_memory_class)
return new_memory_class return new_memory_class
def add_attribute_to_class(self, class_instance, attribute_name: str): async def add_attribute_to_class(self, class_instance, attribute_name: str):
#add this to database for a particular user and load under memory id #add this to database for a particular user and load under memory id
class_instance.add_attribute(attribute_name) await class_instance.add_attribute(attribute_name)
def add_method_to_class(self, class_instance, method_name: str): async def add_method_to_class(self, class_instance, method_name: str):
#add this to database for a particular user and load under memory id #add this to database for a particular user and load under memory id
class_instance.add_method(method_name) await class_instance.add_method(method_name)
@ -297,35 +320,37 @@ async def main():
} }
loader_settings = { loader_settings = {
"format": "PDF", "format": "PDF",
"source": "url", "source": "URL",
"path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf" "path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
} }
# memory_instance = Memory(namespace='SEMANTICMEMORY') # memory_instance = Memory(namespace='SEMANTICMEMORY')
# sss = await memory_instance.dynamic_method_call(memory_instance.semantic_memory_class, 'fetch_memories', observation='some_observation') # sss = await memory_instance.dynamic_method_call(memory_instance.semantic_memory_class, 'fetch_memories', observation='some_observation')
#
# Create a DB session
Session = sessionmaker(bind=engine)
session = Session()
memory = Memory.create_memory("676", session, namespace='SEMANTICMEMORY')
# Adding a memory instance from database.database_crud import session_scope
memory.add_memory_instance("ExampleMemory") from database.database import AsyncSessionLocal
async with session_scope(AsyncSessionLocal()) as session:
memory = await Memory.create_memory("676", session, namespace='SEMANTICMEMORY')
# Adding a memory instance
await memory.add_memory_instance("ExampleMemory")
# Managing memory attributes # Managing memory attributes
existing_user = Memory.check_existing_user("676", session) existing_user = await Memory.check_existing_user("676", session)
print("here is the existing user", existing_user) print("here is the existing user", existing_user)
memory.manage_memory_attributes(existing_user) await memory.manage_memory_attributes(existing_user)
# aeehuvyq_semanticememory_class
memory.add_dynamic_memory_class('SemanticMemory', 'SEMANTICMEMORY') await memory.add_dynamic_memory_class('SemanticMemory', 'SEMANTICMEMORY')
memory.add_method_to_class(memory.semanticmemory_class, 'add_memories') await memory.add_method_to_class(memory.semanticmemory_class, 'add_memories')
memory.add_method_to_class(memory.semanticmemory_class, 'fetch_memories') await memory.add_method_to_class(memory.semanticmemory_class, 'fetch_memories')
sss = await memory.dynamic_method_call(memory.semanticmemory_class, 'add_memories', sss = await memory.dynamic_method_call(memory.semanticmemory_class, 'add_memories',
observation='some_observation', params=params) observation='some_observation', params=params, loader_settings=loader_settings)
susu = await memory.dynamic_method_call(memory.semanticmemory_class, 'fetch_memories', susu = await memory.dynamic_method_call(memory.semanticmemory_class, 'fetch_memories',
observation='some_observation') observation='some_observation')
print(susu) print(susu)
# Adding a dynamic memory class # Adding a dynamic memory class
# dynamic_memory = memory.add_dynamic_memory_class("DynamicMemory", "ExampleNamespace") # dynamic_memory = memory.add_dynamic_memory_class("DynamicMemory", "ExampleNamespace")
@ -337,8 +362,8 @@ async def main():
# print(sss) # print(sss)
# load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params) load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params)
# print(load_jack_london) print(load_jack_london)
modulator = {"relevance": 0.1, "frequency": 0.1} modulator = {"relevance": 0.1, "frequency": 0.1}