Merge pull request #23 from topoteretes/add_async_elements

Added async elements: switched the database layer to SQLAlchemy's async engine over asyncpg, converted the session and CRUD helpers to async context managers, and made the Memory class and the RAG test manager fully async/await.
This commit is contained in:
Vasilije 2023-10-12 13:24:22 +02:00 committed by GitHub
commit 6ba24d162a
9 changed files with 539 additions and 402 deletions
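The nine files, as far as the hunks below identify them, appear to be: the Dockerfile, the level-3 README, database/database.py, database/database_crud.py, level_3/poetry.lock, pyproject.toml, rag_test_manager.py, the Weaviate vector-DB adapter, and vectorstore_manager.py.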

View file

@@ -1,5 +1,5 @@
FROM python:3.11-slim
FROM python:3.11
# Set build argument
ARG API_ENABLED

View file

@@ -25,27 +25,37 @@ After that, you can run:
```docker compose build promethai_mem ```
## Run
## Run level 3
Make sure you have Docker, Poetry, Python 3.11, and Postgres installed.
Copy .env.example to .env and fill in the variables.
Start the Docker container:
```docker compose up promethai_mem ```
Use the poetry environment:
``` poetry shell ```
Make sure to run
Run the following to initialize the DB tables:
``` python scripts/create_database.py ```
After that, you can run:
After that, you can run the RAG test manager.
``` python rag_test_manager.py \
```
python rag_test_manager.py \
--url "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf" \
--test_set "example_data/test_set.json" \
--user_id "666" \
--metadata "example_data/metadata.json"
```
To see examples of test_set.json and metadata.json, check the files in the folder "example_data"
Examples of metadata structure and test set are in the folder "example_data"
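For orientation, a minimal sketch of the two files, with shapes taken from the sample data in rag_test_manager.py (the authoritative versions live in "example_data"): test_set.json is a list of question/answer pairs, and metadata.json is a flat dictionary of document attributes.
```
[
  {"question": "Who wrote 'The Call of the Wild'?", "answer": "Jack London"},
  {"question": "Who is the main character in 'The Call of the Wild'?", "answer": "Buck"}
]
```
```
{"version": "1.0", "owner": "John Doe", "license": "MIT", "format": "json"}
```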
## Clean database

View file

@@ -1,14 +1,15 @@
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import declarative_base, sessionmaker
from contextlib import asynccontextmanager
from sqlalchemy.exc import OperationalError
from time import sleep
import asyncio
import sys
from dotenv import load_dotenv
load_dotenv()
# this is needed to import classes from other modules
script_dir = os.path.dirname(os.path.abspath(__file__))
# Get the parent directory of your script and add it to sys.path
@@ -25,35 +26,43 @@ password = os.getenv('POSTGRES_PASSWORD')
database_name = os.getenv('POSTGRES_DB')
host = os.getenv('POSTGRES_HOST')
# Use the asyncpg driver for async operation
SQLALCHEMY_DATABASE_URL = f"postgresql+asyncpg://{username}:{password}@{host}:5432/{database_name}"
SQLALCHEMY_DATABASE_URL = f"postgresql://{username}:{password}@{host}:5432/{database_name}"
engine = create_engine(
engine = create_async_engine(
    SQLALCHEMY_DATABASE_URL,
    pool_recycle=3600,  # recycle connections after 1 hour
    pool_pre_ping=True  # test the connection for liveness upon each checkout
    pool_recycle=3600,
    echo=True  # Enable logging for tutorial purposes
)
# Use AsyncSession for the session
AsyncSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
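# Note: SQLAlchemy 2.0 also ships async_sessionmaker, which could be used here
# instead of sessionmaker(class_=AsyncSession); both produce AsyncSession objects.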
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
@contextmanager
def get_db():
    db = SessionLocal()
# Use asynccontextmanager to define an async context manager
@asynccontextmanager
async def get_db():
    db = AsyncSessionLocal()
    try:
        yield db
    finally:
        db.close()
        await db.close()
def safe_db_operation(db_op, *args, **kwargs):
# Use async/await syntax for the async function
async def safe_db_operation(db_op, *args, **kwargs):
    for attempt in range(MAX_RETRIES):
        with get_db() as db:
        async with get_db() as db:
            try:
                return db_op(db, *args, **kwargs)
                # Ensure your db_op is also async
                return await db_op(db, *args, **kwargs)
            except OperationalError as e:
                db.rollback()
                await db.rollback()
                if "server closed the connection unexpectedly" in str(e) and attempt < MAX_RETRIES - 1:
                    sleep(RETRY_DELAY)
                    await asyncio.sleep(RETRY_DELAY)
                else:
                    raise
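A minimal usage sketch for the retry helper: any coroutine that takes the session as its first argument can be passed as db_op and is retried on dropped connections. The `count_users` coroutine and the `models.user` import path are assumptions for illustration; MAX_RETRIES and RETRY_DELAY are module constants from this file.
```python
from sqlalchemy import select
from models.user import User  # assumed model location

async def count_users(db):
    # db is the AsyncSession that safe_db_operation checked out via get_db()
    result = await db.execute(select(User))
    return len(result.scalars().all())

users_total = asyncio.run(safe_db_operation(count_users))
```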

View file

@@ -1,15 +1,31 @@
from contextlib import asynccontextmanager
import asyncio
from contextlib import asynccontextmanager
import logging
# from database import AsyncSessionLocal
@contextmanager
def session_scope(session):
logger = logging.getLogger(__name__)
@asynccontextmanager
async def session_scope(session):
    """Provide a transactional scope around a series of operations."""
    # session = AsyncSessionLocal()
    try:
        yield session
        session.commit()
        await session.commit()
    except Exception as e:
        session.rollback()
        await session.rollback()
        logger.error(f"Session rollback due to: {str(e)}")
        raise
    finally:
        session.close()
        await session.close()
async def add_entity(session, entity):
    async with session_scope(session) as s:  # Use your async session_scope
        s.add(entity)  # No need to commit; session_scope commits on exit
        return "Successfully added entity"

level_3/poetry.lock
View file

@@ -169,6 +169,59 @@ files = [
{file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
]
[[package]]
name = "asyncpg"
version = "0.28.0"
description = "An asyncio PostgreSQL driver"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "asyncpg-0.28.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a6d1b954d2b296292ddff4e0060f494bb4270d87fb3655dd23c5c6096d16d83"},
{file = "asyncpg-0.28.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0740f836985fd2bd73dca42c50c6074d1d61376e134d7ad3ad7566c4f79f8184"},
{file = "asyncpg-0.28.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e907cf620a819fab1737f2dd90c0f185e2a796f139ac7de6aa3212a8af96c050"},
{file = "asyncpg-0.28.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86b339984d55e8202e0c4b252e9573e26e5afa05617ed02252544f7b3e6de3e9"},
{file = "asyncpg-0.28.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0c402745185414e4c204a02daca3d22d732b37359db4d2e705172324e2d94e85"},
{file = "asyncpg-0.28.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c88eef5e096296626e9688f00ab627231f709d0e7e3fb84bb4413dff81d996d7"},
{file = "asyncpg-0.28.0-cp310-cp310-win32.whl", hash = "sha256:90a7bae882a9e65a9e448fdad3e090c2609bb4637d2a9c90bfdcebbfc334bf89"},
{file = "asyncpg-0.28.0-cp310-cp310-win_amd64.whl", hash = "sha256:76aacdcd5e2e9999e83c8fbcb748208b60925cc714a578925adcb446d709016c"},
{file = "asyncpg-0.28.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a0e08fe2c9b3618459caaef35979d45f4e4f8d4f79490c9fa3367251366af207"},
{file = "asyncpg-0.28.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b24e521f6060ff5d35f761a623b0042c84b9c9b9fb82786aadca95a9cb4a893b"},
{file = "asyncpg-0.28.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99417210461a41891c4ff301490a8713d1ca99b694fef05dabd7139f9d64bd6c"},
{file = "asyncpg-0.28.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f029c5adf08c47b10bcdc857001bbef551ae51c57b3110964844a9d79ca0f267"},
{file = "asyncpg-0.28.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ad1d6abf6c2f5152f46fff06b0e74f25800ce8ec6c80967f0bc789974de3c652"},
{file = "asyncpg-0.28.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d7fa81ada2807bc50fea1dc741b26a4e99258825ba55913b0ddbf199a10d69d8"},
{file = "asyncpg-0.28.0-cp311-cp311-win32.whl", hash = "sha256:f33c5685e97821533df3ada9384e7784bd1e7865d2b22f153f2e4bd4a083e102"},
{file = "asyncpg-0.28.0-cp311-cp311-win_amd64.whl", hash = "sha256:5e7337c98fb493079d686a4a6965e8bcb059b8e1b8ec42106322fc6c1c889bb0"},
{file = "asyncpg-0.28.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1c56092465e718a9fdcc726cc3d9dcf3a692e4834031c9a9f871d92a75d20d48"},
{file = "asyncpg-0.28.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4acd6830a7da0eb4426249d71353e8895b350daae2380cb26d11e0d4a01c5472"},
{file = "asyncpg-0.28.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63861bb4a540fa033a56db3bb58b0c128c56fad5d24e6d0a8c37cb29b17c1c7d"},
{file = "asyncpg-0.28.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:a93a94ae777c70772073d0512f21c74ac82a8a49be3a1d982e3f259ab5f27307"},
{file = "asyncpg-0.28.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d14681110e51a9bc9c065c4e7944e8139076a778e56d6f6a306a26e740ed86d2"},
{file = "asyncpg-0.28.0-cp37-cp37m-win32.whl", hash = "sha256:8aec08e7310f9ab322925ae5c768532e1d78cfb6440f63c078b8392a38aa636a"},
{file = "asyncpg-0.28.0-cp37-cp37m-win_amd64.whl", hash = "sha256:319f5fa1ab0432bc91fb39b3960b0d591e6b5c7844dafc92c79e3f1bff96abef"},
{file = "asyncpg-0.28.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b337ededaabc91c26bf577bfcd19b5508d879c0ad009722be5bb0a9dd30b85a0"},
{file = "asyncpg-0.28.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4d32b680a9b16d2957a0a3cc6b7fa39068baba8e6b728f2e0a148a67644578f4"},
{file = "asyncpg-0.28.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4f62f04cdf38441a70f279505ef3b4eadf64479b17e707c950515846a2df197"},
{file = "asyncpg-0.28.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f20cac332c2576c79c2e8e6464791c1f1628416d1115935a34ddd7121bfc6a4"},
{file = "asyncpg-0.28.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:59f9712ce01e146ff71d95d561fb68bd2d588a35a187116ef05028675462d5ed"},
{file = "asyncpg-0.28.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fc9e9f9ff1aa0eddcc3247a180ac9e9b51a62311e988809ac6152e8fb8097756"},
{file = "asyncpg-0.28.0-cp38-cp38-win32.whl", hash = "sha256:9e721dccd3838fcff66da98709ed884df1e30a95f6ba19f595a3706b4bc757e3"},
{file = "asyncpg-0.28.0-cp38-cp38-win_amd64.whl", hash = "sha256:8ba7d06a0bea539e0487234511d4adf81dc8762249858ed2a580534e1720db00"},
{file = "asyncpg-0.28.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d009b08602b8b18edef3a731f2ce6d3f57d8dac2a0a4140367e194eabd3de457"},
{file = "asyncpg-0.28.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ec46a58d81446d580fb21b376ec6baecab7288ce5a578943e2fc7ab73bf7eb39"},
{file = "asyncpg-0.28.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b48ceed606cce9e64fd5480a9b0b9a95cea2b798bb95129687abd8599c8b019"},
{file = "asyncpg-0.28.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8858f713810f4fe67876728680f42e93b7e7d5c7b61cf2118ef9153ec16b9423"},
{file = "asyncpg-0.28.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5e18438a0730d1c0c1715016eacda6e9a505fc5aa931b37c97d928d44941b4bf"},
{file = "asyncpg-0.28.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e9c433f6fcdd61c21a715ee9128a3ca48be8ac16fa07be69262f016bb0f4dbd2"},
{file = "asyncpg-0.28.0-cp39-cp39-win32.whl", hash = "sha256:41e97248d9076bc8e4849da9e33e051be7ba37cd507cbd51dfe4b2d99c70e3dc"},
{file = "asyncpg-0.28.0-cp39-cp39-win_amd64.whl", hash = "sha256:3ed77f00c6aacfe9d79e9eff9e21729ce92a4b38e80ea99a58ed382f42ebd55b"},
{file = "asyncpg-0.28.0.tar.gz", hash = "sha256:7252cdc3acb2f52feaa3664280d3bcd78a46bd6c10bfd681acfffefa1120e278"},
]
[package.extras]
docs = ["Sphinx (>=5.3.0,<5.4.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"]
test = ["flake8 (>=5.0,<6.0)", "uvloop (>=0.15.3)"]
[[package]]
name = "atlassian-python-api"
version = "3.41.2"
@@ -685,13 +738,13 @@ pdf = ["pypdf (>=3.3.0,<4.0.0)"]
[[package]]
name = "deepeval"
version = "0.20.1"
version = "0.20.6"
description = "DeepEval provides evaluation and unit testing to accelerate development of LLMs and Agents."
optional = false
python-versions = "*"
files = [
{file = "deepeval-0.20.1-py3-none-any.whl", hash = "sha256:f9880a1246a2a8ba77d88b1d2f977759d34741df6d584bb3c55fadc95c52bc89"},
{file = "deepeval-0.20.1.tar.gz", hash = "sha256:e3e36745f5e77bc6055def0b98e7a3274c87564f498f50337b670a291fde32a5"},
{file = "deepeval-0.20.6-py3-none-any.whl", hash = "sha256:aa0b96fa062f63398858fe2af1c4982ee9e4d53cd3e322c7bbc3812fe1267614"},
{file = "deepeval-0.20.6.tar.gz", hash = "sha256:502c6bb8bc27069d4bbac171c2aac1a760ec8e5c11e3c87a7a8ed2a81ef21db6"},
]
[package.dependencies]
@@ -702,6 +755,7 @@ pytest = "*"
requests = "*"
rich = "*"
sentence-transformers = "*"
sentry-sdk = "*"
tabulate = "*"
tqdm = "*"
transformers = "*"
@@ -3519,6 +3573,51 @@ files = [
{file = "sentencepiece-0.1.99.tar.gz", hash = "sha256:189c48f5cb2949288f97ccdb97f0473098d9c3dcf5a3d99d4eabe719ec27297f"},
]
[[package]]
name = "sentry-sdk"
version = "1.32.0"
description = "Python client for Sentry (https://sentry.io)"
optional = false
python-versions = "*"
files = [
{file = "sentry-sdk-1.32.0.tar.gz", hash = "sha256:935e8fbd7787a3702457393b74b13d89a5afb67185bc0af85c00cb27cbd42e7c"},
{file = "sentry_sdk-1.32.0-py2.py3-none-any.whl", hash = "sha256:eeb0b3550536f3bbc05bb1c7e0feb3a78d74acb43b607159a606ed2ec0a33a4d"},
]
[package.dependencies]
certifi = "*"
urllib3 = {version = ">=1.26.11", markers = "python_version >= \"3.6\""}
[package.extras]
aiohttp = ["aiohttp (>=3.5)"]
arq = ["arq (>=0.23)"]
asyncpg = ["asyncpg (>=0.23)"]
beam = ["apache-beam (>=2.12)"]
bottle = ["bottle (>=0.12.13)"]
celery = ["celery (>=3)"]
chalice = ["chalice (>=1.16.0)"]
clickhouse-driver = ["clickhouse-driver (>=0.2.0)"]
django = ["django (>=1.8)"]
falcon = ["falcon (>=1.4)"]
fastapi = ["fastapi (>=0.79.0)"]
flask = ["blinker (>=1.1)", "flask (>=0.11)", "markupsafe"]
grpcio = ["grpcio (>=1.21.1)"]
httpx = ["httpx (>=0.16.0)"]
huey = ["huey (>=2)"]
loguru = ["loguru (>=0.5)"]
opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"]
pure-eval = ["asttokens", "executing", "pure-eval"]
pymongo = ["pymongo (>=3.1)"]
pyspark = ["pyspark (>=2.4.4)"]
quart = ["blinker (>=1.1)", "quart (>=0.16.1)"]
rq = ["rq (>=0.6)"]
sanic = ["sanic (>=0.8)"]
sqlalchemy = ["sqlalchemy (>=1.2)"]
starlette = ["starlette (>=0.19.1)"]
starlite = ["starlite (>=1.48)"]
tornado = ["tornado (>=5)"]
[[package]]
name = "setuptools"
version = "68.1.2"
@@ -3816,52 +3915,52 @@ files = [
[[package]]
name = "sqlalchemy"
version = "2.0.20"
version = "2.0.21"
description = "Database Abstraction Library"
optional = false
python-versions = ">=3.7"
files = [
{file = "SQLAlchemy-2.0.20-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759b51346aa388c2e606ee206c0bc6f15a5299f6174d1e10cadbe4530d3c7a98"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1506e988ebeaaf316f183da601f24eedd7452e163010ea63dbe52dc91c7fc70e"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5768c268df78bacbde166b48be788b83dddaa2a5974b8810af422ddfe68a9bc8"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3f0dd6d15b6dc8b28a838a5c48ced7455c3e1fb47b89da9c79cc2090b072a50"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:243d0fb261f80a26774829bc2cee71df3222587ac789b7eaf6555c5b15651eed"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6eb6d77c31e1bf4268b4d61b549c341cbff9842f8e115ba6904249c20cb78a61"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-win32.whl", hash = "sha256:bcb04441f370cbe6e37c2b8d79e4af9e4789f626c595899d94abebe8b38f9a4d"},
{file = "SQLAlchemy-2.0.20-cp310-cp310-win_amd64.whl", hash = "sha256:d32b5ffef6c5bcb452723a496bad2d4c52b346240c59b3e6dba279f6dcc06c14"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dd81466bdbc82b060c3c110b2937ab65ace41dfa7b18681fdfad2f37f27acdd7"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6fe7d61dc71119e21ddb0094ee994418c12f68c61b3d263ebaae50ea8399c4d4"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4e571af672e1bb710b3cc1a9794b55bce1eae5aed41a608c0401885e3491179"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3364b7066b3c7f4437dd345d47271f1251e0cfb0aba67e785343cdbdb0fff08c"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1be86ccea0c965a1e8cd6ccf6884b924c319fcc85765f16c69f1ae7148eba64b"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1d35d49a972649b5080557c603110620a86aa11db350d7a7cb0f0a3f611948a0"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-win32.whl", hash = "sha256:27d554ef5d12501898d88d255c54eef8414576f34672e02fe96d75908993cf53"},
{file = "SQLAlchemy-2.0.20-cp311-cp311-win_amd64.whl", hash = "sha256:411e7f140200c02c4b953b3dbd08351c9f9818d2bd591b56d0fa0716bd014f1e"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3c6aceebbc47db04f2d779db03afeaa2c73ea3f8dcd3987eb9efdb987ffa09a3"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d3f175410a6db0ad96b10bfbb0a5530ecd4fcf1e2b5d83d968dd64791f810ed"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea8186be85da6587456c9ddc7bf480ebad1a0e6dcbad3967c4821233a4d4df57"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c3d99ba99007dab8233f635c32b5cd24fb1df8d64e17bc7df136cedbea427897"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:76fdfc0f6f5341987474ff48e7a66c3cd2b8a71ddda01fa82fedb180b961630a"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-win32.whl", hash = "sha256:d3793dcf5bc4d74ae1e9db15121250c2da476e1af8e45a1d9a52b1513a393459"},
{file = "SQLAlchemy-2.0.20-cp37-cp37m-win_amd64.whl", hash = "sha256:79fde625a0a55220d3624e64101ed68a059c1c1f126c74f08a42097a72ff66a9"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:599ccd23a7146e126be1c7632d1d47847fa9f333104d03325c4e15440fc7d927"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1a58052b5a93425f656675673ef1f7e005a3b72e3f2c91b8acca1b27ccadf5f4"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79543f945be7a5ada9943d555cf9b1531cfea49241809dd1183701f94a748624"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63e73da7fb030ae0a46a9ffbeef7e892f5def4baf8064786d040d45c1d6d1dc5"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3ce5e81b800a8afc870bb8e0a275d81957e16f8c4b62415a7b386f29a0cb9763"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cb0d3e94c2a84215532d9bcf10229476ffd3b08f481c53754113b794afb62d14"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-win32.whl", hash = "sha256:8dd77fd6648b677d7742d2c3cc105a66e2681cc5e5fb247b88c7a7b78351cf74"},
{file = "SQLAlchemy-2.0.20-cp38-cp38-win_amd64.whl", hash = "sha256:6f8a934f9dfdf762c844e5164046a9cea25fabbc9ec865c023fe7f300f11ca4a"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:26a3399eaf65e9ab2690c07bd5cf898b639e76903e0abad096cd609233ce5208"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4cde2e1096cbb3e62002efdb7050113aa5f01718035ba9f29f9d89c3758e7e4e"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1b09ba72e4e6d341bb5bdd3564f1cea6095d4c3632e45dc69375a1dbe4e26ec"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b74eeafaa11372627ce94e4dc88a6751b2b4d263015b3523e2b1e57291102f0"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:77d37c1b4e64c926fa3de23e8244b964aab92963d0f74d98cbc0783a9e04f501"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:eefebcc5c555803065128401a1e224a64607259b5eb907021bf9b175f315d2a6"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-win32.whl", hash = "sha256:3423dc2a3b94125094897118b52bdf4d37daf142cbcf26d48af284b763ab90e9"},
{file = "SQLAlchemy-2.0.20-cp39-cp39-win_amd64.whl", hash = "sha256:5ed61e3463021763b853628aef8bc5d469fe12d95f82c74ef605049d810f3267"},
{file = "SQLAlchemy-2.0.20-py3-none-any.whl", hash = "sha256:63a368231c53c93e2b67d0c5556a9836fdcd383f7e3026a39602aad775b14acf"},
{file = "SQLAlchemy-2.0.20.tar.gz", hash = "sha256:ca8a5ff2aa7f3ade6c498aaafce25b1eaeabe4e42b73e25519183e4566a16fc6"},
{file = "SQLAlchemy-2.0.21-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1e7dc99b23e33c71d720c4ae37ebb095bebebbd31a24b7d99dfc4753d2803ede"},
{file = "SQLAlchemy-2.0.21-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7f0c4ee579acfe6c994637527c386d1c22eb60bc1c1d36d940d8477e482095d4"},
{file = "SQLAlchemy-2.0.21-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f7d57a7e140efe69ce2d7b057c3f9a595f98d0bbdfc23fd055efdfbaa46e3a5"},
{file = "SQLAlchemy-2.0.21-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca38746eac23dd7c20bec9278d2058c7ad662b2f1576e4c3dbfcd7c00cc48fa"},
{file = "SQLAlchemy-2.0.21-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3cf229704074bce31f7f47d12883afee3b0a02bb233a0ba45ddbfe542939cca4"},
{file = "SQLAlchemy-2.0.21-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fb87f763b5d04a82ae84ccff25554ffd903baafba6698e18ebaf32561f2fe4aa"},
{file = "SQLAlchemy-2.0.21-cp310-cp310-win32.whl", hash = "sha256:89e274604abb1a7fd5c14867a412c9d49c08ccf6ce3e1e04fffc068b5b6499d4"},
{file = "SQLAlchemy-2.0.21-cp310-cp310-win_amd64.whl", hash = "sha256:e36339a68126ffb708dc6d1948161cea2a9e85d7d7b0c54f6999853d70d44430"},
{file = "SQLAlchemy-2.0.21-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bf8eebccc66829010f06fbd2b80095d7872991bfe8415098b9fe47deaaa58063"},
{file = "SQLAlchemy-2.0.21-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b977bfce15afa53d9cf6a632482d7968477625f030d86a109f7bdfe8ce3c064a"},
{file = "SQLAlchemy-2.0.21-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ff3dc2f60dbf82c9e599c2915db1526d65415be323464f84de8db3e361ba5b9"},
{file = "SQLAlchemy-2.0.21-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44ac5c89b6896f4740e7091f4a0ff2e62881da80c239dd9408f84f75a293dae9"},
{file = "SQLAlchemy-2.0.21-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:87bf91ebf15258c4701d71dcdd9c4ba39521fb6a37379ea68088ce8cd869b446"},
{file = "SQLAlchemy-2.0.21-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b69f1f754d92eb1cc6b50938359dead36b96a1dcf11a8670bff65fd9b21a4b09"},
{file = "SQLAlchemy-2.0.21-cp311-cp311-win32.whl", hash = "sha256:af520a730d523eab77d754f5cf44cc7dd7ad2d54907adeb3233177eeb22f271b"},
{file = "SQLAlchemy-2.0.21-cp311-cp311-win_amd64.whl", hash = "sha256:141675dae56522126986fa4ca713739d00ed3a6f08f3c2eb92c39c6dfec463ce"},
{file = "SQLAlchemy-2.0.21-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7614f1eab4336df7dd6bee05bc974f2b02c38d3d0c78060c5faa4cd1ca2af3b8"},
{file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d59cb9e20d79686aa473e0302e4a82882d7118744d30bb1dfb62d3c47141b3ec"},
{file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a95aa0672e3065d43c8aa80080cdd5cc40fe92dc873749e6c1cf23914c4b83af"},
{file = "SQLAlchemy-2.0.21-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:8c323813963b2503e54d0944813cd479c10c636e3ee223bcbd7bd478bf53c178"},
{file = "SQLAlchemy-2.0.21-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:419b1276b55925b5ac9b4c7044e999f1787c69761a3c9756dec6e5c225ceca01"},
{file = "SQLAlchemy-2.0.21-cp37-cp37m-win32.whl", hash = "sha256:4615623a490e46be85fbaa6335f35cf80e61df0783240afe7d4f544778c315a9"},
{file = "SQLAlchemy-2.0.21-cp37-cp37m-win_amd64.whl", hash = "sha256:cca720d05389ab1a5877ff05af96551e58ba65e8dc65582d849ac83ddde3e231"},
{file = "SQLAlchemy-2.0.21-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b4eae01faee9f2b17f08885e3f047153ae0416648f8e8c8bd9bc677c5ce64be9"},
{file = "SQLAlchemy-2.0.21-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3eb7c03fe1cd3255811cd4e74db1ab8dca22074d50cd8937edf4ef62d758cdf4"},
{file = "SQLAlchemy-2.0.21-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2d494b6a2a2d05fb99f01b84cc9af9f5f93bf3e1e5dbdafe4bed0c2823584c1"},
{file = "SQLAlchemy-2.0.21-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b19ae41ef26c01a987e49e37c77b9ad060c59f94d3b3efdfdbf4f3daaca7b5fe"},
{file = "SQLAlchemy-2.0.21-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:fc6b15465fabccc94bf7e38777d665b6a4f95efd1725049d6184b3a39fd54880"},
{file = "SQLAlchemy-2.0.21-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:014794b60d2021cc8ae0f91d4d0331fe92691ae5467a00841f7130fe877b678e"},
{file = "SQLAlchemy-2.0.21-cp38-cp38-win32.whl", hash = "sha256:0268256a34806e5d1c8f7ee93277d7ea8cc8ae391f487213139018b6805aeaf6"},
{file = "SQLAlchemy-2.0.21-cp38-cp38-win_amd64.whl", hash = "sha256:73c079e21d10ff2be54a4699f55865d4b275fd6c8bd5d90c5b1ef78ae0197301"},
{file = "SQLAlchemy-2.0.21-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:785e2f2c1cb50d0a44e2cdeea5fd36b5bf2d79c481c10f3a88a8be4cfa2c4615"},
{file = "SQLAlchemy-2.0.21-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c111cd40910ffcb615b33605fc8f8e22146aeb7933d06569ac90f219818345ef"},
{file = "SQLAlchemy-2.0.21-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9cba4e7369de663611ce7460a34be48e999e0bbb1feb9130070f0685e9a6b66"},
{file = "SQLAlchemy-2.0.21-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50a69067af86ec7f11a8e50ba85544657b1477aabf64fa447fd3736b5a0a4f67"},
{file = "SQLAlchemy-2.0.21-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ccb99c3138c9bde118b51a289d90096a3791658da9aea1754667302ed6564f6e"},
{file = "SQLAlchemy-2.0.21-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:513fd5b6513d37e985eb5b7ed89da5fd9e72354e3523980ef00d439bc549c9e9"},
{file = "SQLAlchemy-2.0.21-cp39-cp39-win32.whl", hash = "sha256:f9fefd6298433b6e9188252f3bff53b9ff0443c8fde27298b8a2b19f6617eeb9"},
{file = "SQLAlchemy-2.0.21-cp39-cp39-win_amd64.whl", hash = "sha256:2e617727fe4091cedb3e4409b39368f424934c7faa78171749f704b49b4bb4ce"},
{file = "SQLAlchemy-2.0.21-py3-none-any.whl", hash = "sha256:ea7da25ee458d8f404b93eb073116156fd7d8c2a776d8311534851f28277b4ce"},
{file = "SQLAlchemy-2.0.21.tar.gz", hash = "sha256:05b971ab1ac2994a14c56b35eaaa91f86ba080e9ad481b20d99d77f381bb6258"},
]
[package.dependencies]
@@ -4778,4 +4877,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "cf14b96576c57633fea1b0cc1f7266a5fb265d4ff4ce2d3479a036483082b3b4"
content-hash = "b3f795baef70806c8dab1058b644d95f3c36b1c2fae53a7e5784a633731bc268"

View file

@@ -40,11 +40,13 @@ weaviate-client = "^3.22.1"
python-multipart = "^0.0.6"
deep-translator = "^1.11.4"
humanize = "^4.8.0"
deepeval = "^0.20.1"
deepeval = "^0.20.6"
pymupdf = "^1.23.3"
psycopg2 = "^2.9.8"
llama-index = "^0.8.39.post2"
llama-hub = "^0.0.34"
sqlalchemy = "^2.0.21"
asyncpg = "^0.28.0"
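With asyncpg pinned here, the connection string in database/database.py switches to the `postgresql+asyncpg://` scheme so SQLAlchemy selects the async driver.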

View file

@@ -20,6 +20,8 @@ from database.database import engine
from vectorstore_manager import Memory
import uuid
from contextlib import contextmanager
from database.database import AsyncSessionLocal
from database.database_crud import session_scope
import random
import string
@@ -30,103 +32,65 @@ dotenv.load_dotenv()
import openai
logger = logging.getLogger(__name__)
openai.api_key = os.getenv("OPENAI_API_KEY", "")
@contextmanager
def session_scope(session):
"""Provide a transactional scope around a series of operations."""
async def retrieve_latest_test_case(session, user_id, memory_id):
try:
yield session
session.commit()
except Exception as e:
session.rollback()
logger.error(f"Session rollback due to: {str(e)}")
raise
finally:
session.close()
def retrieve_latest_test_case(session, user_id, memory_id):
"""
Retrieve the most recently created test case from the database filtered by user_id and memory_id.
Parameters:
- session (Session): A database session.
- user_id (int/str): The ID of the user to filter test cases by.
- memory_id (int/str): The ID of the memory to filter test cases by.
Returns:
- Object: The most recent test case attributes filtered by user_id and memory_id, or None if an error occurs.
"""
try:
return (
# Use await with session.execute() and row.fetchone() or row.all() for async query execution
result = await session.execute(
session.query(TestSet.attributes_list)
.filter_by(user_id=user_id, memory_id=memory_id)
.order_by(TestSet.created_at.desc())
.first()
.order_by(TestSet.created_at).first()
)
return result.scalar_one_or_none() # scalar_one_or_none() is a non-blocking call
except Exception as e:
logger.error(f"An error occurred while retrieving the latest test case: {str(e)}")
return None
def add_entity(session, entity):
"""
Add an entity (like TestOutput, Session, etc.) to the database.
Parameters:
- session (Session): A database session.
- entity (Base): An instance of an SQLAlchemy model.
Returns:
- str: A message indicating whether the addition was successful.
"""
with session_scope(session):
session.add(entity)
session.commit()
async def add_entity(session, entity):
async with session_scope(session) as s: # Use your async session_scope
s.add(entity) # No need to commit; session_scope takes care of it
return "Successfully added entity"
def retrieve_job_by_id(session, user_id, job_id):
"""
Retrieve a job by user ID and job ID.
Parameters:
- session (Session): A database session.
- user_id (int/str): The ID of the user.
- job_id (int/str): The ID of the job to retrieve.
Returns:
- Object: The job attributes filtered by user_id and job_id, or None if an error occurs.
"""
async def retrieve_job_by_id(session, user_id, job_id):
try:
return (
result = await session.execute(
session.query(Session.id)
.filter_by(user_id=user_id, id=job_id)
.order_by(Session.created_at.desc())
.first()
.order_by(Session.created_at)
)
return result.scalar_one_or_none()
except Exception as e:
logger.error(f"An error occurred while retrieving the job: {str(e)}")
return None
def fetch_job_id(session, user_id=None, memory_id=None, job_id=None):
async def fetch_job_id(session, user_id=None, memory_id=None, job_id=None):
try:
return (
result = await session.execute(
session.query(Session.id)
.filter_by(user_id=user_id, id=job_id)
.order_by(Session.created_at.desc())
.first()
.order_by(Session.created_at).first()
)
return result.scalar_one_or_none()
except Exception as e:
# Handle exceptions as per your application's requirements.
print(f"An error occurred: {str(e)}")
logger.error(f"An error occurred: {str(e)}")
return None
def compare_output(output, expected_output):
"""Compare the output against the expected output."""
pass
async def fetch_test_set_id(session, user_id, id):
try:
# Await the execution of the query and fetching of the result
result = await session.execute(
session.query(TestSet.id)
.filter_by(user_id=user_id, id=id)
.order_by(TestSet.created_at.desc()).first()
)
return result.scalar_one_or_none() # scalar_one_or_none() is a non-blocking call
except Exception as e:
logger.error(f"An error occurred while retrieving the test set: {str(e)}")
return None
# Adding "embeddings" to the parameter variants function
def generate_param_variants(base_params=None, increments=None, ranges=None, included_params=None):
"""Generate parameter variants for testing.
@@ -142,38 +106,39 @@ def generate_param_variants(base_params=None, increments=None, ranges=None, incl
list: A list of dictionaries containing parameter variants.
"""
# Default base values
# Default values
defaults = {
'chunk_size': 500,
'chunk_size': 250,
'chunk_overlap': 20,
'similarity_score': 0.5,
'metadata_variation': 0,
'search_type': 'hybrid'
'search_type': 'hybrid',
'embeddings': 'openai' # Default value added for 'embeddings'
}
# Update defaults with provided base parameters
params = {**defaults, **(base_params if base_params is not None else {})}
params = {**defaults, **(base_params or {})}
default_increments = {
'chunk_size': 500,
'chunk_size': 150,
'chunk_overlap': 10,
'similarity_score': 0.1,
'metadata_variation': 1
}
# Update default increments with provided increments
increments = {**default_increments, **(increments if increments is not None else {})}
increments = {**default_increments, **(increments or {})}
# Default ranges
default_ranges = {
'chunk_size': 3,
'chunk_overlap': 3,
'similarity_score': 3,
'metadata_variation': 3
'chunk_size': 2,
'chunk_overlap': 2,
'similarity_score': 2,
'metadata_variation': 2
}
# Update default ranges with provided ranges
ranges = {**default_ranges, **(ranges if ranges is not None else {})}
ranges = {**default_ranges, **(ranges or {})}
# Generate parameter variant ranges
param_ranges = {
@@ -181,10 +146,9 @@ def generate_param_variants(base_params=None, increments=None, ranges=None, incl
for key in ['chunk_size', 'chunk_overlap', 'similarity_score', 'metadata_variation']
}
param_ranges['cognitive_architecture'] = ["simple_index", "cognitive_architecture"]
# Add search_type with possible values
# Add search_type and embeddings with possible values
param_ranges['search_type'] = ['text', 'hybrid', 'bm25', 'generate', 'generate_grouped']
param_ranges['embeddings'] = ['openai', 'cohere', 'huggingface'] # Added 'embeddings' values
# Filter param_ranges based on included_params
if included_params is not None:
@@ -197,6 +161,10 @@ def generate_param_variants(base_params=None, increments=None, ranges=None, incl
return param_variants
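# For example (as called in start_test below):
#     generate_param_variants(included_params=['chunk_size', 'chunk_overlap', 'search_type'])
# returns one dict per combination drawn from param_ranges, e.g.
#     {'chunk_size': 250, 'chunk_overlap': 20, 'search_type': 'hybrid'}  (illustrative values)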
# Generate parameter variants and display a sample of the generated combinations
async def generate_chatgpt_output(query:str, context:str=None):
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
@@ -278,195 +246,200 @@ def generate_letter_uuid(length=8):
letters = string.ascii_uppercase # A-Z
return ''.join(random.choice(letters) for _ in range(length))
async def start_test(data, test_set=None, user_id=None, params=None, job_id=None, metadata=None):
def fetch_test_set_id(session, user_id, id):
try:
return (
session.query(TestSet.id)
.filter_by(user_id=user_id, id=id)
.order_by(TestSet.created_at)
.desc().first()
async with session_scope(session=AsyncSessionLocal()) as session:
memory = await Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id)
test_set_id = await fetch_test_set_id(session, user_id=user_id, id=job_id)
if job_id is None:
job_id = str(uuid.uuid4())
await add_entity(session, Operation(id=job_id, user_id=user_id))
if test_set_id is None:
test_set_id = str(uuid.uuid4())
await add_entity(session, TestSet(id=test_set_id, user_id=user_id, content=str(test_set)))
if params is None:
data_format = data_format_route(data) # Assume data_format_route is predefined
data_location = data_location_route(data) # Assume data_location_route is predefined
test_params = generate_param_variants(
included_params=['chunk_size', 'chunk_overlap', 'search_type'])
print("Here are the test params", str(test_params))
loader_settings = {
"format": f"{data_format}",
"source": f"{data_location}",
"path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
}
async def run_test(test, loader_settings, metadata):
test_id = str(generate_letter_uuid()) + "_" + "SEMANTICEMEMORY"
await memory.add_memory_instance("ExampleMemory")
existing_user = await Memory.check_existing_user(user_id, session)
await memory.manage_memory_attributes(existing_user)
test_class = test_id + "_class"
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
dynamic_memory_class = getattr(memory, test_class.lower(), None)
# Assuming add_method_to_class and dynamic_method_call are predefined methods in Memory class
if dynamic_memory_class is not None:
await memory.add_method_to_class(dynamic_memory_class, 'add_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
if dynamic_memory_class is not None:
await memory.add_method_to_class(dynamic_memory_class, 'fetch_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
print(f"Trying to access: {test_class.lower()}")
print("Available memory classes:", await memory.list_memory_classes())
print(f"Trying to check: ", test)
loader_settings.update(test)
async def run_load_test_element(test, loader_settings, metadata, test_id):
test_class = test_id + "_class"
# memory.test_class
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
dynamic_memory_class = getattr(memory, test_class.lower(), None)
if dynamic_memory_class is not None:
await memory.add_method_to_class(dynamic_memory_class, 'add_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
if dynamic_memory_class is not None:
await memory.add_method_to_class(dynamic_memory_class, 'fetch_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
print(f"Trying to access: {test_class.lower()}")
print("Available memory classes:", await memory.list_memory_classes())
print(f"Trying to check: ", test)
# print("Here is the loader settings", str(loader_settings))
# print("Here is the medatadata", str(metadata))
load_action = await memory.dynamic_method_call(dynamic_memory_class, 'add_memories',
observation='some_observation', params=metadata,
loader_settings=loader_settings)
async def run_search_eval_element(test_item, test_id):
test_class = test_id + "_class"
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
dynamic_memory_class = getattr(memory, test_class.lower(), None)
retrieve_action = await memory.dynamic_method_call(dynamic_memory_class, 'fetch_memories',
observation=test_item["question"],
search_type=test_item["search_type"])
test_result = await eval_test(query=test_item["question"], expected_output=test_item["answer"],
context=str(retrieve_action))
print(test_result)
delete_mems = await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
namespace=test_id)
return test_result
test_load_pipeline = await asyncio.gather(
*(run_load_test_element(test_item,loader_settings, metadata, test_id) for test_item in test_set)
)
test_eval_pipeline = await asyncio.gather(
*(run_search_eval_element(test_item, test_id) for test_item in test_set)
)
logging.info("Results of the eval pipeline %s", str(test_eval_pipeline))
await add_entity(session, TestOutput(id=test_id, user_id=user_id, test_results=str(test_eval_pipeline)))
# # Gather and run all tests in parallel
results = await asyncio.gather(
*(run_test(test, loader_settings, metadata) for test in test_params)
)
except Exception as e:
logger.error(f"An error occurred while retrieving the job: {str(e)}")
return None
async def start_test(data, test_set=None, user_id=None, params=None, job_id=None, metadata=None):
Session = sessionmaker(bind=engine)
session = Session()
#do i need namespace in memory instance, fix it
memory = Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
job_id = fetch_job_id(session, user_id = user_id,job_id =job_id)
test_set_id = fetch_test_set_id(session, user_id=user_id, id=job_id)
if job_id is None:
job_id = str(uuid.uuid4())
logging.info("we are adding a new job ID")
add_entity(session, Operation(id = job_id, user_id = user_id))
if test_set_id is None:
test_set_id = str(uuid.uuid4())
add_entity(session, TestSet(id = test_set_id, user_id = user_id, content = str(test_set)))
if params is None:
data_format = data_format_route(data)
data_location = data_location_route(data)
test_params = generate_param_variants( included_params=['chunk_size', 'chunk_overlap', 'similarity_score'])
loader_settings = {
"format": f"{data_format}",
"source": f"{data_location}",
"path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
}
for test in test_params:
test_id = str(generate_letter_uuid()) + "_" + "SEMANTICEMEMORY"
# Adding a memory instance
memory.add_memory_instance("ExampleMemory")
# Managing memory attributes
existing_user = Memory.check_existing_user(user_id, session)
print("here is the existing user", existing_user)
memory.manage_memory_attributes(existing_user)
test_class = test_id + "_class"
# memory.test_class
memory.add_dynamic_memory_class(test_id.lower(), test_id)
dynamic_memory_class = getattr(memory, test_class.lower(), None)
if dynamic_memory_class is not None:
memory.add_method_to_class(dynamic_memory_class, 'add_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
if dynamic_memory_class is not None:
memory.add_method_to_class(dynamic_memory_class, 'fetch_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
print(f"Trying to access: {test_class.lower()}")
print("Available memory classes:", memory.list_memory_classes())
print(f"Trying to check: ", test)
loader_settings.update(test)
load_action = await memory.dynamic_method_call(dynamic_memory_class, 'add_memories',
observation='some_observation', params=metadata, loader_settings=loader_settings)
loader_settings = {key: value for key, value in loader_settings.items() if key not in test}
test_result_collection =[]
for test in test_set:
retrieve_action = await memory.dynamic_method_call(dynamic_memory_class, 'fetch_memories',
observation=test["question"])
test_results = await eval_test( query=test["question"], expected_output=test["answer"], context= str(retrieve_action))
test_result_collection.append(test_results)
print(test_results)
if dynamic_memory_class is not None:
memory.add_method_to_class(dynamic_memory_class, 'delete_memories')
else:
print(f"No attribute named {test_class.lower()} in memory.")
delete_mems = await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
namespace =test_id)
print(test_result_collection)
add_entity(session, TestOutput(id=test_id, user_id=user_id, test_results=str(test_result_collection)))
return results
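# In the new start_test above, the per-test work from this older sequential loop
# is wrapped in run_load_test_element / run_search_eval_element and executed
# concurrently with asyncio.gather.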
async def main():
params = {
"version": "1.0",
"agreement_id": "AG123456",
"privacy_policy": "https://example.com/privacy",
"terms_of_service": "https://example.com/terms",
"format": "json",
"schema_version": "1.1",
"checksum": "a1b2c3d4e5f6",
"owner": "John Doe",
"license": "MIT",
"validity_start": "2023-08-01",
"validity_end": "2024-07-31",
}
test_set = [
{
"question": "Who is the main character in 'The Call of the Wild'?",
"answer": "Buck"
},
{
"question": "Who wrote 'The Call of the Wild'?",
"answer": "Jack London"
},
{
"question": "Where does Buck live at the start of the book?",
"answer": "In the Santa Clara Valley, at Judge Millers place."
},
{
"question": "Why is Buck kidnapped?",
"answer": "He is kidnapped to be sold as a sled dog in the Yukon during the Klondike Gold Rush."
},
{
"question": "How does Buck become the leader of the sled dog team?",
"answer": "Buck becomes the leader after defeating the original leader, Spitz, in a fight."
}
]
result = await start_test("https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf", test_set=test_set, user_id="666", params=None, metadata=params)
#
# params = {
# "version": "1.0",
# "agreement_id": "AG123456",
# "privacy_policy": "https://example.com/privacy",
# "terms_of_service": "https://example.com/terms",
# "format": "json",
# "schema_version": "1.1",
# "checksum": "a1b2c3d4e5f6",
# "owner": "John Doe",
# "license": "MIT",
# "validity_start": "2023-08-01",
# "validity_end": "2024-07-31",
# }
# parser = argparse.ArgumentParser(description="Run tests against a document.")
# parser.add_argument("--url", required=True, help="URL of the document to test.")
# parser.add_argument("--test_set", required=True, help="Path to JSON file containing the test set.")
# parser.add_argument("--user_id", required=True, help="User ID.")
# parser.add_argument("--params", help="Additional parameters in JSON format.")
# parser.add_argument("--metadata", required=True, help="Path to JSON file containing metadata.")
#
# test_set = [
# {
# "question": "Who is the main character in 'The Call of the Wild'?",
# "answer": "Buck"
# },
# {
# "question": "Who wrote 'The Call of the Wild'?",
# "answer": "Jack London"
# },
# {
# "question": "Where does Buck live at the start of the book?",
# "answer": "In the Santa Clara Valley, at Judge Millers place."
# },
# {
# "question": "Why is Buck kidnapped?",
# "answer": "He is kidnapped to be sold as a sled dog in the Yukon during the Klondike Gold Rush."
# },
# {
# "question": "How does Buck become the leader of the sled dog team?",
# "answer": "Buck becomes the leader after defeating the original leader, Spitz, in a fight."
# }
# ]
# result = await start_test("https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf", test_set=test_set, user_id="666", params=None, metadata=params)
# args = parser.parse_args()
#
parser = argparse.ArgumentParser(description="Run tests against a document.")
parser.add_argument("--url", required=True, help="URL of the document to test.")
parser.add_argument("--test_set", required=True, help="Path to JSON file containing the test set.")
parser.add_argument("--user_id", required=True, help="User ID.")
parser.add_argument("--params", help="Additional parameters in JSON format.")
parser.add_argument("--metadata", required=True, help="Path to JSON file containing metadata.")
args = parser.parse_args()
try:
with open(args.test_set, "r") as file:
test_set = json.load(file)
if not isinstance(test_set, list): # Expecting a list
raise TypeError("Parsed test_set JSON is not a list.")
except Exception as e:
print(f"Error loading test_set: {str(e)}")
return
try:
with open(args.metadata, "r") as file:
metadata = json.load(file)
if not isinstance(metadata, dict):
raise TypeError("Parsed metadata JSON is not a dictionary.")
except Exception as e:
print(f"Error loading metadata: {str(e)}")
return
if args.params:
try:
params = json.loads(args.params)
if not isinstance(params, dict):
raise TypeError("Parsed params JSON is not a dictionary.")
except json.JSONDecodeError as e:
print(f"Error parsing params: {str(e)}")
return
else:
params = None
#clean up params here
await start_test(args.url, test_set, args.user_id, params=None, metadata=metadata)
# try:
# with open(args.test_set, "r") as file:
# test_set = json.load(file)
# if not isinstance(test_set, list): # Expecting a list
# raise TypeError("Parsed test_set JSON is not a list.")
# except Exception as e:
# print(f"Error loading test_set: {str(e)}")
# return
#
# try:
# with open(args.metadata, "r") as file:
# metadata = json.load(file)
# if not isinstance(metadata, dict):
# raise TypeError("Parsed metadata JSON is not a dictionary.")
# except Exception as e:
# print(f"Error loading metadata: {str(e)}")
# return
#
# if args.params:
# try:
# params = json.loads(args.params)
# if not isinstance(params, dict):
# raise TypeError("Parsed params JSON is not a dictionary.")
# except json.JSONDecodeError as e:
# print(f"Error parsing params: {str(e)}")
# return
# else:
# params = None
# #clean up params here
# await start_test(args.url, test_set, args.user_id, params=None, metadata=metadata)
if __name__ == "__main__":
import asyncio

View file

@@ -179,17 +179,17 @@ class WeaviateVectorDB(VectorDB):
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
).with_where(params_user_id).with_limit(10)
n_of_observations = kwargs.get('n_of_observations', 2)
try:
if search_type == 'text':
query_output = (
base_query
.with_near_text({"concepts": [observation]})
.with_autocut(n_of_observations)
.do()
)
elif search_type == 'hybrid':
n_of_observations = kwargs.get('n_of_observations', 2)
query_output = (
base_query
.with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
@@ -200,6 +200,7 @@ class WeaviateVectorDB(VectorDB):
query_output = (
base_query
.with_bm25(query=observation)
.with_autocut(n_of_observations)
.do()
)
elif search_type == 'generate':
@@ -208,6 +209,7 @@ class WeaviateVectorDB(VectorDB):
base_query
.with_generate(single_prompt=generate_prompt)
.with_near_text({"concepts": [observation]})
.with_autocut(n_of_observations)
.do()
)
elif search_type == 'generate_grouped':
@@ -216,6 +218,7 @@ class WeaviateVectorDB(VectorDB):
base_query
.with_generate(grouped_task=generate_prompt)
.with_near_text({"concepts": [observation]})
.with_autocut(n_of_observations)
.do()
)
else:
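The change in this file threads `with_autocut(n_of_observations)` through the bm25, generate, and generate_grouped branches as well, so every search type caps its result list with Weaviate's autocut, which stops at the first steep drop in result score.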

View file

@@ -1,5 +1,7 @@
import logging
from sqlalchemy.future import select
logging.basicConfig(level=logging.INFO)
import marvin
from dotenv import load_dotenv
@@ -15,6 +17,7 @@ from models.operation import Operation
load_dotenv()
import ast
import tracemalloc
from database.database_crud import session_scope, add_entity
tracemalloc.start()
@@ -48,7 +51,7 @@ class DynamicBaseMemory(BaseMemory):
self.associations = []
def add_method(self, method_name):
async def add_method(self, method_name):
"""
Add a method to the memory class.
@@ -60,7 +63,7 @@ class DynamicBaseMemory(BaseMemory):
"""
self.methods.add(method_name)
def add_attribute(self, attribute_name):
async def add_attribute(self, attribute_name):
"""
Add an attribute to the memory class.
@@ -72,7 +75,7 @@ class DynamicBaseMemory(BaseMemory):
"""
self.attributes.add(attribute_name)
def get_attribute(self, attribute_name):
async def get_attribute(self, attribute_name):
"""
Check if the attribute is in the memory class.
@@ -84,7 +87,7 @@ class DynamicBaseMemory(BaseMemory):
"""
return attribute_name in self.attributes
def add_association(self, associated_memory):
async def add_association(self, associated_memory):
"""
Add an association to another memory class.
@@ -149,26 +152,26 @@ class Memory:
self.OPENAI_TEMPERATURE = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
self.OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
@classmethod
def create_memory(cls, user_id: str, session, **kwargs):
async def create_memory(cls, user_id: str, session, **kwargs):
"""
Class method that acts as a factory method for creating Memory instances.
It performs necessary DB checks or updates before instance creation.
"""
existing_user = cls.check_existing_user(user_id, session)
existing_user = await cls.check_existing_user(user_id, session)
if existing_user:
# Handle existing user scenario...
memory_id = cls.check_existing_memory(user_id, session)
memory_id = await cls.check_existing_memory(user_id, session)
logging.info(f"Existing user {user_id} found in the DB. Memory ID: {memory_id}")
else:
# Handle new user scenario...
memory_id = cls.handle_new_user(user_id, session)
memory_id = await cls.handle_new_user(user_id, session)
logging.info(f"New user {user_id} created in the DB. Memory ID: {memory_id}")
return cls(user_id=user_id, session=session, memory_id=memory_id, **kwargs)
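# Usage elsewhere in this PR (main() below and start_test in rag_test_manager.py):
#     memory = await Memory.create_memory("676", session, namespace='SEMANTICMEMORY')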
def list_memory_classes(self):
async def list_memory_classes(self):
"""
Lists all available memory classes in the memory instance.
"""
@@ -176,32 +179,36 @@ class Memory:
return [attr for attr in dir(self) if attr.endswith("_class")]
@staticmethod
def check_existing_user(user_id: str, session):
async def check_existing_user(user_id: str, session):
"""Check if a user exists in the DB and return it."""
return session.query(User).filter_by(id=user_id).first()
result = await session.execute(
select(User).where(User.id == user_id)
)
return result.scalar_one_or_none()
@staticmethod
def check_existing_memory(user_id: str, session):
async def check_existing_memory(user_id: str, session):
"""Check if a user memory exists in the DB and return it."""
return session.query(MemoryModel.id).filter_by(user_id=user_id).first()
result = await session.execute(
select(MemoryModel.id).where(MemoryModel.user_id == user_id)
)
return result.scalar_one_or_none()
@staticmethod
def handle_new_user(user_id: str, session):
async def handle_new_user(user_id: str, session):
"""Handle new user creation in the DB and return the new memory ID."""
#handle these better in terms of retry and error handling
memory_id = str(uuid.uuid4())
new_user = User(id=user_id)
session.add(new_user)
session.commit()
await add_entity(session, new_user)
memory = MemoryModel(id=memory_id, user_id=user_id, methods_list=str(['Memory', 'SemanticMemory', 'EpisodicMemory']),
attributes_list=str(['user_id', 'index_name', 'db_type', 'knowledge_source', 'knowledge_type', 'memory_id', 'long_term_memory', 'short_term_memory', 'namespace']))
session.add(memory)
session.commit()
await add_entity(session, memory)
return memory_id
def add_memory_instance(self, memory_class_name: str):
async def add_memory_instance(self, memory_class_name: str):
"""Add a new memory instance to the memory_instances list."""
instance = DynamicBaseMemory(memory_class_name, self.user_id,
self.memory_id, self.index_name,
@@ -209,15 +216,26 @@ class Memory:
print("The following instance was defined", instance)
self.memory_instances.append(instance)
def manage_memory_attributes(self, existing_user):
async def query_method(self):
methods_list = await self.session.execute(
select(MemoryModel.methods_list).where(MemoryModel.id == self.memory_id)
)
methods_list = methods_list.scalar_one_or_none()
return methods_list
async def manage_memory_attributes(self, existing_user):
"""Manage memory attributes based on the user existence."""
if existing_user:
print(f"ID before query: {self.memory_id}, type: {type(self.memory_id)}")
attributes_list = self.session.query(MemoryModel.attributes_list).filter_by(id=self.memory_id[0]).scalar()
# attributes_list = await self.session.query(MemoryModel.attributes_list).filter_by(id=self.memory_id[0]).scalar()
attributes_list = await self.query_method()
logging.info(f"Attributes list: {attributes_list}")
if attributes_list is not None:
attributes_list = ast.literal_eval(attributes_list)
self.handle_attributes(attributes_list)
await self.handle_attributes(attributes_list)
else:
logging.warning("attributes_list is None!")
else:
@@ -225,20 +243,25 @@ class Memory:
'knowledge_source', 'knowledge_type',
'memory_id', 'long_term_memory',
'short_term_memory', 'namespace']
self.handle_attributes(attributes_list)
await self.handle_attributes(attributes_list)
def handle_attributes(self, attributes_list):
async def handle_attributes(self, attributes_list):
"""Handle attributes for existing memory instances."""
for attr in attributes_list:
self.memory_class.add_attribute(attr)
await self.memory_class.add_attribute(attr)
def manage_memory_methods(self, existing_user):
async def manage_memory_methods(self, existing_user):
"""
Manage memory methods based on the user existence.
"""
if existing_user:
# Fetch existing methods from the database
methods_list = self.session.query(MemoryModel.methods_list).filter_by(id=self.memory_id).scalar()
# methods_list = await self.session.query(MemoryModel.methods_list).filter_by(id=self.memory_id).scalar()
methods_list = await self.session.execute(
select(MemoryModel.methods_list).where(MemoryModel.id == self.memory_id[0])
)
methods_list = methods_list.scalar_one_or_none()
methods_list = ast.literal_eval(methods_list)
else:
# Define default methods for a new user
@@ -260,20 +283,20 @@ class Memory:
return await method(*args, **kwargs)
raise AttributeError(f"{dynamic_base_memory_instance.name} object has no attribute {method_name}")
def add_dynamic_memory_class(self, class_name: str, namespace: str):
async def add_dynamic_memory_class(self, class_name: str, namespace: str):
logging.info("Here is the memory id %s", self.memory_id[0])
new_memory_class = DynamicBaseMemory(class_name, self.user_id, self.memory_id[0], self.index_name,
self.db_type, namespace)
setattr(self, f"{class_name.lower()}_class", new_memory_class)
return new_memory_class
def add_attribute_to_class(self, class_instance, attribute_name: str):
async def add_attribute_to_class(self, class_instance, attribute_name: str):
#add this to database for a particular user and load under memory id
class_instance.add_attribute(attribute_name)
await class_instance.add_attribute(attribute_name)
def add_method_to_class(self, class_instance, method_name: str):
async def add_method_to_class(self, class_instance, method_name: str):
#add this to database for a particular user and load under memory id
class_instance.add_method(method_name)
await class_instance.add_method(method_name)
@@ -297,35 +320,37 @@ async def main():
}
loader_settings = {
"format": "PDF",
"source": "url",
"source": "URL",
"path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
}
# memory_instance = Memory(namespace='SEMANTICMEMORY')
# sss = await memory_instance.dynamic_method_call(memory_instance.semantic_memory_class, 'fetch_memories', observation='some_observation')
#
# Create a DB session
Session = sessionmaker(bind=engine)
session = Session()
memory = Memory.create_memory("676", session, namespace='SEMANTICMEMORY')
# Adding a memory instance
memory.add_memory_instance("ExampleMemory")
from database.database_crud import session_scope
from database.database import AsyncSessionLocal
async with session_scope(AsyncSessionLocal()) as session:
memory = await Memory.create_memory("676", session, namespace='SEMANTICMEMORY')
# Adding a memory instance
await memory.add_memory_instance("ExampleMemory")
# Managing memory attributes
existing_user = Memory.check_existing_user("676", session)
print("here is the existing user", existing_user)
memory.manage_memory_attributes(existing_user)
# Managing memory attributes
existing_user = await Memory.check_existing_user("676", session)
print("here is the existing user", existing_user)
await memory.manage_memory_attributes(existing_user)
# aeehuvyq_semanticememory_class
memory.add_dynamic_memory_class('SemanticMemory', 'SEMANTICMEMORY')
memory.add_method_to_class(memory.semanticmemory_class, 'add_memories')
memory.add_method_to_class(memory.semanticmemory_class, 'fetch_memories')
sss = await memory.dynamic_method_call(memory.semanticmemory_class, 'add_memories',
observation='some_observation', params=params)
await memory.add_dynamic_memory_class('SemanticMemory', 'SEMANTICMEMORY')
await memory.add_method_to_class(memory.semanticmemory_class, 'add_memories')
await memory.add_method_to_class(memory.semanticmemory_class, 'fetch_memories')
sss = await memory.dynamic_method_call(memory.semanticmemory_class, 'add_memories',
observation='some_observation', params=params, loader_settings=loader_settings)
susu = await memory.dynamic_method_call(memory.semanticmemory_class, 'fetch_memories',
observation='some_observation')
print(susu)
susu = await memory.dynamic_method_call(memory.semanticmemory_class, 'fetch_memories',
observation='some_observation')
print(susu)
# Adding a dynamic memory class
# dynamic_memory = memory.add_dynamic_memory_class("DynamicMemory", "ExampleNamespace")
@@ -337,8 +362,8 @@ async def main():
# print(sss)
# load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params)
# print(load_jack_london)
load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params)
print(load_jack_london)
modulator = {"relevance": 0.1, "frequency": 0.1}