Merge pull request #23 from topoteretes/add_async_elements
Added async elements
This commit is contained in:
commit
6ba24d162a
9 changed files with 539 additions and 402 deletions
|
|
@ -1,5 +1,5 @@
|
|||
|
||||
FROM python:3.11-slim
|
||||
FROM python:3.11
|
||||
|
||||
# Set build argument
|
||||
ARG API_ENABLED
|
||||
|
|
|
|||
|
|
@ -25,27 +25,37 @@ After that, you can run:
|
|||
|
||||
```docker compose build promethai_mem ```
|
||||
|
||||
## Run
|
||||
## Run the level 3
|
||||
|
||||
Make sure you have Docker, Poetry, and Python 3.11 installed and postgres installed.
|
||||
|
||||
Copy the .env.example to .env and fill the variables
|
||||
|
||||
|
||||
Start the docker:
|
||||
|
||||
```docker compose up promethai_mem ```
|
||||
|
||||
Use the poetry environment:
|
||||
|
||||
``` poetry shell ```
|
||||
|
||||
Make sure to run
|
||||
Make sure to run to initialize DB tables
|
||||
|
||||
``` python scripts/create_database.py ```
|
||||
|
||||
After that, you can run:
|
||||
After that, you can run the RAG test manager.
|
||||
|
||||
``` python rag_test_manager.py \
|
||||
|
||||
```
|
||||
python rag_test_manager.py \
|
||||
--url "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf" \
|
||||
--test_set "example_data/test_set.json" \
|
||||
--user_id "666" \
|
||||
--metadata "example_data/metadata.json"
|
||||
|
||||
```
|
||||
|
||||
To see example of test_set.json and metadata.json, check the files in the folder "example_data"
|
||||
Examples of metadata structure and test set are in the folder "example_data"
|
||||
|
||||
|
||||
## Clean database
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
import os
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from contextlib import contextmanager
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from time import sleep
|
||||
import asyncio
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# this is needed to import classes from other modules
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# Get the parent directory of your script and add it to sys.path
|
||||
|
|
@ -25,35 +26,43 @@ password = os.getenv('POSTGRES_PASSWORD')
|
|||
database_name = os.getenv('POSTGRES_DB')
|
||||
host = os.getenv('POSTGRES_HOST')
|
||||
|
||||
# Use the asyncpg driver for async operation
|
||||
SQLALCHEMY_DATABASE_URL = f"postgresql+asyncpg://{username}:{password}@{host}:5432/{database_name}"
|
||||
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = f"postgresql://{username}:{password}@{host}:5432/{database_name}"
|
||||
|
||||
engine = create_engine(
|
||||
engine = create_async_engine(
|
||||
SQLALCHEMY_DATABASE_URL,
|
||||
pool_recycle=3600, # recycle connections after 1 hour
|
||||
pool_pre_ping=True # test the connection for liveness upon each checkout
|
||||
pool_recycle=3600,
|
||||
echo=True # Enable logging for tutorial purposes
|
||||
)
|
||||
# Use AsyncSession for the session
|
||||
AsyncSessionLocal = sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
@contextmanager
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
# Use asynccontextmanager to define an async context manager
|
||||
@asynccontextmanager
|
||||
async def get_db():
|
||||
db = AsyncSessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
await db.close()
|
||||
|
||||
def safe_db_operation(db_op, *args, **kwargs):
|
||||
# Use async/await syntax for the async function
|
||||
async def safe_db_operation(db_op, *args, **kwargs):
|
||||
for attempt in range(MAX_RETRIES):
|
||||
with get_db() as db:
|
||||
async with get_db() as db:
|
||||
try:
|
||||
return db_op(db, *args, **kwargs)
|
||||
# Ensure your db_op is also async
|
||||
return await db_op(db, *args, **kwargs)
|
||||
except OperationalError as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
if "server closed the connection unexpectedly" in str(e) and attempt < MAX_RETRIES - 1:
|
||||
sleep(RETRY_DELAY)
|
||||
await asyncio.sleep(RETRY_DELAY)
|
||||
else:
|
||||
raise
|
||||
|
|
@ -1,15 +1,31 @@
|
|||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
# from database import AsyncSessionLocal
|
||||
|
||||
@contextmanager
|
||||
def session_scope(session):
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@asynccontextmanager
|
||||
async def session_scope(session):
|
||||
"""Provide a transactional scope around a series of operations."""
|
||||
|
||||
# session = AsyncSessionLocal()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
logger.error(f"Session rollback due to: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
await session.close()
|
||||
|
||||
|
||||
async def add_entity(session, entity):
|
||||
async with session_scope(session) as s: # Use your async session_scope
|
||||
s.add(entity) # No need to commit; session_scope takes care of it
|
||||
s.commit()
|
||||
return "Successfully added entity"
|
||||
191
level_3/poetry.lock
generated
191
level_3/poetry.lock
generated
|
|
@ -169,6 +169,59 @@ files = [
|
|||
{file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asyncpg"
|
||||
version = "0.28.0"
|
||||
description = "An asyncio PostgreSQL driver"
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
{file = "asyncpg-0.28.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a6d1b954d2b296292ddff4e0060f494bb4270d87fb3655dd23c5c6096d16d83"},
|
||||
{file = "asyncpg-0.28.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0740f836985fd2bd73dca42c50c6074d1d61376e134d7ad3ad7566c4f79f8184"},
|
||||
{file = "asyncpg-0.28.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e907cf620a819fab1737f2dd90c0f185e2a796f139ac7de6aa3212a8af96c050"},
|
||||
{file = "asyncpg-0.28.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86b339984d55e8202e0c4b252e9573e26e5afa05617ed02252544f7b3e6de3e9"},
|
||||
{file = "asyncpg-0.28.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0c402745185414e4c204a02daca3d22d732b37359db4d2e705172324e2d94e85"},
|
||||
{file = "asyncpg-0.28.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c88eef5e096296626e9688f00ab627231f709d0e7e3fb84bb4413dff81d996d7"},
|
||||
{file = "asyncpg-0.28.0-cp310-cp310-win32.whl", hash = "sha256:90a7bae882a9e65a9e448fdad3e090c2609bb4637d2a9c90bfdcebbfc334bf89"},
|
||||
{file = "asyncpg-0.28.0-cp310-cp310-win_amd64.whl", hash = "sha256:76aacdcd5e2e9999e83c8fbcb748208b60925cc714a578925adcb446d709016c"},
|
||||
{file = "asyncpg-0.28.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a0e08fe2c9b3618459caaef35979d45f4e4f8d4f79490c9fa3367251366af207"},
|
||||
{file = "asyncpg-0.28.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b24e521f6060ff5d35f761a623b0042c84b9c9b9fb82786aadca95a9cb4a893b"},
|
||||
{file = "asyncpg-0.28.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99417210461a41891c4ff301490a8713d1ca99b694fef05dabd7139f9d64bd6c"},
|
||||
{file = "asyncpg-0.28.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f029c5adf08c47b10bcdc857001bbef551ae51c57b3110964844a9d79ca0f267"},
|
||||
{file = "asyncpg-0.28.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ad1d6abf6c2f5152f46fff06b0e74f25800ce8ec6c80967f0bc789974de3c652"},
|
||||
{file = "asyncpg-0.28.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d7fa81ada2807bc50fea1dc741b26a4e99258825ba55913b0ddbf199a10d69d8"},
|
||||
{file = "asyncpg-0.28.0-cp311-cp311-win32.whl", hash = "sha256:f33c5685e97821533df3ada9384e7784bd1e7865d2b22f153f2e4bd4a083e102"},
|
||||
{file = "asyncpg-0.28.0-cp311-cp311-win_amd64.whl", hash = "sha256:5e7337c98fb493079d686a4a6965e8bcb059b8e1b8ec42106322fc6c1c889bb0"},
|
||||
{file = "asyncpg-0.28.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1c56092465e718a9fdcc726cc3d9dcf3a692e4834031c9a9f871d92a75d20d48"},
|
||||
{file = "asyncpg-0.28.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4acd6830a7da0eb4426249d71353e8895b350daae2380cb26d11e0d4a01c5472"},
|
||||
{file = "asyncpg-0.28.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63861bb4a540fa033a56db3bb58b0c128c56fad5d24e6d0a8c37cb29b17c1c7d"},
|
||||
{file = "asyncpg-0.28.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:a93a94ae777c70772073d0512f21c74ac82a8a49be3a1d982e3f259ab5f27307"},
|
||||
{file = "asyncpg-0.28.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d14681110e51a9bc9c065c4e7944e8139076a778e56d6f6a306a26e740ed86d2"},
|
||||
{file = "asyncpg-0.28.0-cp37-cp37m-win32.whl", hash = "sha256:8aec08e7310f9ab322925ae5c768532e1d78cfb6440f63c078b8392a38aa636a"},
|
||||
{file = "asyncpg-0.28.0-cp37-cp37m-win_amd64.whl", hash = "sha256:319f5fa1ab0432bc91fb39b3960b0d591e6b5c7844dafc92c79e3f1bff96abef"},
|
||||
{file = "asyncpg-0.28.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b337ededaabc91c26bf577bfcd19b5508d879c0ad009722be5bb0a9dd30b85a0"},
|
||||
{file = "asyncpg-0.28.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4d32b680a9b16d2957a0a3cc6b7fa39068baba8e6b728f2e0a148a67644578f4"},
|
||||
{file = "asyncpg-0.28.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4f62f04cdf38441a70f279505ef3b4eadf64479b17e707c950515846a2df197"},
|
||||
{file = "asyncpg-0.28.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f20cac332c2576c79c2e8e6464791c1f1628416d1115935a34ddd7121bfc6a4"},
|
||||
{file = "asyncpg-0.28.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:59f9712ce01e146ff71d95d561fb68bd2d588a35a187116ef05028675462d5ed"},
|
||||
{file = "asyncpg-0.28.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fc9e9f9ff1aa0eddcc3247a180ac9e9b51a62311e988809ac6152e8fb8097756"},
|
||||
{file = "asyncpg-0.28.0-cp38-cp38-win32.whl", hash = "sha256:9e721dccd3838fcff66da98709ed884df1e30a95f6ba19f595a3706b4bc757e3"},
|
||||
{file = "asyncpg-0.28.0-cp38-cp38-win_amd64.whl", hash = "sha256:8ba7d06a0bea539e0487234511d4adf81dc8762249858ed2a580534e1720db00"},
|
||||
{file = "asyncpg-0.28.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d009b08602b8b18edef3a731f2ce6d3f57d8dac2a0a4140367e194eabd3de457"},
|
||||
{file = "asyncpg-0.28.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ec46a58d81446d580fb21b376ec6baecab7288ce5a578943e2fc7ab73bf7eb39"},
|
||||
{file = "asyncpg-0.28.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b48ceed606cce9e64fd5480a9b0b9a95cea2b798bb95129687abd8599c8b019"},
|
||||
{file = "asyncpg-0.28.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8858f713810f4fe67876728680f42e93b7e7d5c7b61cf2118ef9153ec16b9423"},
|
||||
{file = "asyncpg-0.28.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5e18438a0730d1c0c1715016eacda6e9a505fc5aa931b37c97d928d44941b4bf"},
|
||||
{file = "asyncpg-0.28.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e9c433f6fcdd61c21a715ee9128a3ca48be8ac16fa07be69262f016bb0f4dbd2"},
|
||||
{file = "asyncpg-0.28.0-cp39-cp39-win32.whl", hash = "sha256:41e97248d9076bc8e4849da9e33e051be7ba37cd507cbd51dfe4b2d99c70e3dc"},
|
||||
{file = "asyncpg-0.28.0-cp39-cp39-win_amd64.whl", hash = "sha256:3ed77f00c6aacfe9d79e9eff9e21729ce92a4b38e80ea99a58ed382f42ebd55b"},
|
||||
{file = "asyncpg-0.28.0.tar.gz", hash = "sha256:7252cdc3acb2f52feaa3664280d3bcd78a46bd6c10bfd681acfffefa1120e278"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["Sphinx (>=5.3.0,<5.4.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"]
|
||||
test = ["flake8 (>=5.0,<6.0)", "uvloop (>=0.15.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "atlassian-python-api"
|
||||
version = "3.41.2"
|
||||
|
|
@ -685,13 +738,13 @@ pdf = ["pypdf (>=3.3.0,<4.0.0)"]
|
|||
|
||||
[[package]]
|
||||
name = "deepeval"
|
||||
version = "0.20.1"
|
||||
version = "0.20.6"
|
||||
description = "DeepEval provides evaluation and unit testing to accelerate development of LLMs and Agents."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "deepeval-0.20.1-py3-none-any.whl", hash = "sha256:f9880a1246a2a8ba77d88b1d2f977759d34741df6d584bb3c55fadc95c52bc89"},
|
||||
{file = "deepeval-0.20.1.tar.gz", hash = "sha256:e3e36745f5e77bc6055def0b98e7a3274c87564f498f50337b670a291fde32a5"},
|
||||
{file = "deepeval-0.20.6-py3-none-any.whl", hash = "sha256:aa0b96fa062f63398858fe2af1c4982ee9e4d53cd3e322c7bbc3812fe1267614"},
|
||||
{file = "deepeval-0.20.6.tar.gz", hash = "sha256:502c6bb8bc27069d4bbac171c2aac1a760ec8e5c11e3c87a7a8ed2a81ef21db6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -702,6 +755,7 @@ pytest = "*"
|
|||
requests = "*"
|
||||
rich = "*"
|
||||
sentence-transformers = "*"
|
||||
sentry-sdk = "*"
|
||||
tabulate = "*"
|
||||
tqdm = "*"
|
||||
transformers = "*"
|
||||
|
|
@ -3519,6 +3573,51 @@ files = [
|
|||
{file = "sentencepiece-0.1.99.tar.gz", hash = "sha256:189c48f5cb2949288f97ccdb97f0473098d9c3dcf5a3d99d4eabe719ec27297f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sentry-sdk"
|
||||
version = "1.32.0"
|
||||
description = "Python client for Sentry (https://sentry.io)"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "sentry-sdk-1.32.0.tar.gz", hash = "sha256:935e8fbd7787a3702457393b74b13d89a5afb67185bc0af85c00cb27cbd42e7c"},
|
||||
{file = "sentry_sdk-1.32.0-py2.py3-none-any.whl", hash = "sha256:eeb0b3550536f3bbc05bb1c7e0feb3a78d74acb43b607159a606ed2ec0a33a4d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
certifi = "*"
|
||||
urllib3 = {version = ">=1.26.11", markers = "python_version >= \"3.6\""}
|
||||
|
||||
[package.extras]
|
||||
aiohttp = ["aiohttp (>=3.5)"]
|
||||
arq = ["arq (>=0.23)"]
|
||||
asyncpg = ["asyncpg (>=0.23)"]
|
||||
beam = ["apache-beam (>=2.12)"]
|
||||
bottle = ["bottle (>=0.12.13)"]
|
||||
celery = ["celery (>=3)"]
|
||||
chalice = ["chalice (>=1.16.0)"]
|
||||
clickhouse-driver = ["clickhouse-driver (>=0.2.0)"]
|
||||
django = ["django (>=1.8)"]
|
||||
falcon = ["falcon (>=1.4)"]
|
||||
fastapi = ["fastapi (>=0.79.0)"]
|
||||
flask = ["blinker (>=1.1)", "flask (>=0.11)", "markupsafe"]
|
||||
grpcio = ["grpcio (>=1.21.1)"]
|
||||
httpx = ["httpx (>=0.16.0)"]
|
||||
huey = ["huey (>=2)"]
|
||||
loguru = ["loguru (>=0.5)"]
|
||||
opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
|
||||
opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"]
|
||||
pure-eval = ["asttokens", "executing", "pure-eval"]
|
||||
pymongo = ["pymongo (>=3.1)"]
|
||||
pyspark = ["pyspark (>=2.4.4)"]
|
||||
quart = ["blinker (>=1.1)", "quart (>=0.16.1)"]
|
||||
rq = ["rq (>=0.6)"]
|
||||
sanic = ["sanic (>=0.8)"]
|
||||
sqlalchemy = ["sqlalchemy (>=1.2)"]
|
||||
starlette = ["starlette (>=0.19.1)"]
|
||||
starlite = ["starlite (>=1.48)"]
|
||||
tornado = ["tornado (>=5)"]
|
||||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "68.1.2"
|
||||
|
|
@ -3816,52 +3915,52 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "sqlalchemy"
|
||||
version = "2.0.20"
|
||||
version = "2.0.21"
|
||||
description = "Database Abstraction Library"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "SQLAlchemy-2.0.20-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759b51346aa388c2e606ee206c0bc6f15a5299f6174d1e10cadbe4530d3c7a98"},
|
||||
{file = "SQLAlchemy-2.0.20-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1506e988ebeaaf316f183da601f24eedd7452e163010ea63dbe52dc91c7fc70e"},
|
||||
{file = "SQLAlchemy-2.0.20-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5768c268df78bacbde166b48be788b83dddaa2a5974b8810af422ddfe68a9bc8"},
|
||||
{file = "SQLAlchemy-2.0.20-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3f0dd6d15b6dc8b28a838a5c48ced7455c3e1fb47b89da9c79cc2090b072a50"},
|
||||
{file = "SQLAlchemy-2.0.20-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:243d0fb261f80a26774829bc2cee71df3222587ac789b7eaf6555c5b15651eed"},
|
||||
{file = "SQLAlchemy-2.0.20-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6eb6d77c31e1bf4268b4d61b549c341cbff9842f8e115ba6904249c20cb78a61"},
|
||||
{file = "SQLAlchemy-2.0.20-cp310-cp310-win32.whl", hash = "sha256:bcb04441f370cbe6e37c2b8d79e4af9e4789f626c595899d94abebe8b38f9a4d"},
|
||||
{file = "SQLAlchemy-2.0.20-cp310-cp310-win_amd64.whl", hash = "sha256:d32b5ffef6c5bcb452723a496bad2d4c52b346240c59b3e6dba279f6dcc06c14"},
|
||||
{file = "SQLAlchemy-2.0.20-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dd81466bdbc82b060c3c110b2937ab65ace41dfa7b18681fdfad2f37f27acdd7"},
|
||||
{file = "SQLAlchemy-2.0.20-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6fe7d61dc71119e21ddb0094ee994418c12f68c61b3d263ebaae50ea8399c4d4"},
|
||||
{file = "SQLAlchemy-2.0.20-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4e571af672e1bb710b3cc1a9794b55bce1eae5aed41a608c0401885e3491179"},
|
||||
{file = "SQLAlchemy-2.0.20-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3364b7066b3c7f4437dd345d47271f1251e0cfb0aba67e785343cdbdb0fff08c"},
|
||||
{file = "SQLAlchemy-2.0.20-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1be86ccea0c965a1e8cd6ccf6884b924c319fcc85765f16c69f1ae7148eba64b"},
|
||||
{file = "SQLAlchemy-2.0.20-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1d35d49a972649b5080557c603110620a86aa11db350d7a7cb0f0a3f611948a0"},
|
||||
{file = "SQLAlchemy-2.0.20-cp311-cp311-win32.whl", hash = "sha256:27d554ef5d12501898d88d255c54eef8414576f34672e02fe96d75908993cf53"},
|
||||
{file = "SQLAlchemy-2.0.20-cp311-cp311-win_amd64.whl", hash = "sha256:411e7f140200c02c4b953b3dbd08351c9f9818d2bd591b56d0fa0716bd014f1e"},
|
||||
{file = "SQLAlchemy-2.0.20-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3c6aceebbc47db04f2d779db03afeaa2c73ea3f8dcd3987eb9efdb987ffa09a3"},
|
||||
{file = "SQLAlchemy-2.0.20-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d3f175410a6db0ad96b10bfbb0a5530ecd4fcf1e2b5d83d968dd64791f810ed"},
|
||||
{file = "SQLAlchemy-2.0.20-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea8186be85da6587456c9ddc7bf480ebad1a0e6dcbad3967c4821233a4d4df57"},
|
||||
{file = "SQLAlchemy-2.0.20-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c3d99ba99007dab8233f635c32b5cd24fb1df8d64e17bc7df136cedbea427897"},
|
||||
{file = "SQLAlchemy-2.0.20-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:76fdfc0f6f5341987474ff48e7a66c3cd2b8a71ddda01fa82fedb180b961630a"},
|
||||
{file = "SQLAlchemy-2.0.20-cp37-cp37m-win32.whl", hash = "sha256:d3793dcf5bc4d74ae1e9db15121250c2da476e1af8e45a1d9a52b1513a393459"},
|
||||
{file = "SQLAlchemy-2.0.20-cp37-cp37m-win_amd64.whl", hash = "sha256:79fde625a0a55220d3624e64101ed68a059c1c1f126c74f08a42097a72ff66a9"},
|
||||
{file = "SQLAlchemy-2.0.20-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:599ccd23a7146e126be1c7632d1d47847fa9f333104d03325c4e15440fc7d927"},
|
||||
{file = "SQLAlchemy-2.0.20-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1a58052b5a93425f656675673ef1f7e005a3b72e3f2c91b8acca1b27ccadf5f4"},
|
||||
{file = "SQLAlchemy-2.0.20-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79543f945be7a5ada9943d555cf9b1531cfea49241809dd1183701f94a748624"},
|
||||
{file = "SQLAlchemy-2.0.20-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63e73da7fb030ae0a46a9ffbeef7e892f5def4baf8064786d040d45c1d6d1dc5"},
|
||||
{file = "SQLAlchemy-2.0.20-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3ce5e81b800a8afc870bb8e0a275d81957e16f8c4b62415a7b386f29a0cb9763"},
|
||||
{file = "SQLAlchemy-2.0.20-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cb0d3e94c2a84215532d9bcf10229476ffd3b08f481c53754113b794afb62d14"},
|
||||
{file = "SQLAlchemy-2.0.20-cp38-cp38-win32.whl", hash = "sha256:8dd77fd6648b677d7742d2c3cc105a66e2681cc5e5fb247b88c7a7b78351cf74"},
|
||||
{file = "SQLAlchemy-2.0.20-cp38-cp38-win_amd64.whl", hash = "sha256:6f8a934f9dfdf762c844e5164046a9cea25fabbc9ec865c023fe7f300f11ca4a"},
|
||||
{file = "SQLAlchemy-2.0.20-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:26a3399eaf65e9ab2690c07bd5cf898b639e76903e0abad096cd609233ce5208"},
|
||||
{file = "SQLAlchemy-2.0.20-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4cde2e1096cbb3e62002efdb7050113aa5f01718035ba9f29f9d89c3758e7e4e"},
|
||||
{file = "SQLAlchemy-2.0.20-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1b09ba72e4e6d341bb5bdd3564f1cea6095d4c3632e45dc69375a1dbe4e26ec"},
|
||||
{file = "SQLAlchemy-2.0.20-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b74eeafaa11372627ce94e4dc88a6751b2b4d263015b3523e2b1e57291102f0"},
|
||||
{file = "SQLAlchemy-2.0.20-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:77d37c1b4e64c926fa3de23e8244b964aab92963d0f74d98cbc0783a9e04f501"},
|
||||
{file = "SQLAlchemy-2.0.20-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:eefebcc5c555803065128401a1e224a64607259b5eb907021bf9b175f315d2a6"},
|
||||
{file = "SQLAlchemy-2.0.20-cp39-cp39-win32.whl", hash = "sha256:3423dc2a3b94125094897118b52bdf4d37daf142cbcf26d48af284b763ab90e9"},
|
||||
{file = "SQLAlchemy-2.0.20-cp39-cp39-win_amd64.whl", hash = "sha256:5ed61e3463021763b853628aef8bc5d469fe12d95f82c74ef605049d810f3267"},
|
||||
{file = "SQLAlchemy-2.0.20-py3-none-any.whl", hash = "sha256:63a368231c53c93e2b67d0c5556a9836fdcd383f7e3026a39602aad775b14acf"},
|
||||
{file = "SQLAlchemy-2.0.20.tar.gz", hash = "sha256:ca8a5ff2aa7f3ade6c498aaafce25b1eaeabe4e42b73e25519183e4566a16fc6"},
|
||||
{file = "SQLAlchemy-2.0.21-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1e7dc99b23e33c71d720c4ae37ebb095bebebbd31a24b7d99dfc4753d2803ede"},
|
||||
{file = "SQLAlchemy-2.0.21-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7f0c4ee579acfe6c994637527c386d1c22eb60bc1c1d36d940d8477e482095d4"},
|
||||
{file = "SQLAlchemy-2.0.21-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f7d57a7e140efe69ce2d7b057c3f9a595f98d0bbdfc23fd055efdfbaa46e3a5"},
|
||||
{file = "SQLAlchemy-2.0.21-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca38746eac23dd7c20bec9278d2058c7ad662b2f1576e4c3dbfcd7c00cc48fa"},
|
||||
{file = "SQLAlchemy-2.0.21-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3cf229704074bce31f7f47d12883afee3b0a02bb233a0ba45ddbfe542939cca4"},
|
||||
{file = "SQLAlchemy-2.0.21-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fb87f763b5d04a82ae84ccff25554ffd903baafba6698e18ebaf32561f2fe4aa"},
|
||||
{file = "SQLAlchemy-2.0.21-cp310-cp310-win32.whl", hash = "sha256:89e274604abb1a7fd5c14867a412c9d49c08ccf6ce3e1e04fffc068b5b6499d4"},
|
||||
{file = "SQLAlchemy-2.0.21-cp310-cp310-win_amd64.whl", hash = "sha256:e36339a68126ffb708dc6d1948161cea2a9e85d7d7b0c54f6999853d70d44430"},
|
||||
{file = "SQLAlchemy-2.0.21-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bf8eebccc66829010f06fbd2b80095d7872991bfe8415098b9fe47deaaa58063"},
|
||||
{file = "SQLAlchemy-2.0.21-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b977bfce15afa53d9cf6a632482d7968477625f030d86a109f7bdfe8ce3c064a"},
|
||||
{file = "SQLAlchemy-2.0.21-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ff3dc2f60dbf82c9e599c2915db1526d65415be323464f84de8db3e361ba5b9"},
|
||||
{file = "SQLAlchemy-2.0.21-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44ac5c89b6896f4740e7091f4a0ff2e62881da80c239dd9408f84f75a293dae9"},
|
||||
{file = "SQLAlchemy-2.0.21-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:87bf91ebf15258c4701d71dcdd9c4ba39521fb6a37379ea68088ce8cd869b446"},
|
||||
{file = "SQLAlchemy-2.0.21-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b69f1f754d92eb1cc6b50938359dead36b96a1dcf11a8670bff65fd9b21a4b09"},
|
||||
{file = "SQLAlchemy-2.0.21-cp311-cp311-win32.whl", hash = "sha256:af520a730d523eab77d754f5cf44cc7dd7ad2d54907adeb3233177eeb22f271b"},
|
||||
{file = "SQLAlchemy-2.0.21-cp311-cp311-win_amd64.whl", hash = "sha256:141675dae56522126986fa4ca713739d00ed3a6f08f3c2eb92c39c6dfec463ce"},
|
||||
{file = "SQLAlchemy-2.0.21-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7614f1eab4336df7dd6bee05bc974f2b02c38d3d0c78060c5faa4cd1ca2af3b8"},
|
||||
{file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d59cb9e20d79686aa473e0302e4a82882d7118744d30bb1dfb62d3c47141b3ec"},
|
||||
{file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a95aa0672e3065d43c8aa80080cdd5cc40fe92dc873749e6c1cf23914c4b83af"},
|
||||
{file = "SQLAlchemy-2.0.21-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:8c323813963b2503e54d0944813cd479c10c636e3ee223bcbd7bd478bf53c178"},
|
||||
{file = "SQLAlchemy-2.0.21-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:419b1276b55925b5ac9b4c7044e999f1787c69761a3c9756dec6e5c225ceca01"},
|
||||
{file = "SQLAlchemy-2.0.21-cp37-cp37m-win32.whl", hash = "sha256:4615623a490e46be85fbaa6335f35cf80e61df0783240afe7d4f544778c315a9"},
|
||||
{file = "SQLAlchemy-2.0.21-cp37-cp37m-win_amd64.whl", hash = "sha256:cca720d05389ab1a5877ff05af96551e58ba65e8dc65582d849ac83ddde3e231"},
|
||||
{file = "SQLAlchemy-2.0.21-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b4eae01faee9f2b17f08885e3f047153ae0416648f8e8c8bd9bc677c5ce64be9"},
|
||||
{file = "SQLAlchemy-2.0.21-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3eb7c03fe1cd3255811cd4e74db1ab8dca22074d50cd8937edf4ef62d758cdf4"},
|
||||
{file = "SQLAlchemy-2.0.21-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2d494b6a2a2d05fb99f01b84cc9af9f5f93bf3e1e5dbdafe4bed0c2823584c1"},
|
||||
{file = "SQLAlchemy-2.0.21-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b19ae41ef26c01a987e49e37c77b9ad060c59f94d3b3efdfdbf4f3daaca7b5fe"},
|
||||
{file = "SQLAlchemy-2.0.21-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:fc6b15465fabccc94bf7e38777d665b6a4f95efd1725049d6184b3a39fd54880"},
|
||||
{file = "SQLAlchemy-2.0.21-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:014794b60d2021cc8ae0f91d4d0331fe92691ae5467a00841f7130fe877b678e"},
|
||||
{file = "SQLAlchemy-2.0.21-cp38-cp38-win32.whl", hash = "sha256:0268256a34806e5d1c8f7ee93277d7ea8cc8ae391f487213139018b6805aeaf6"},
|
||||
{file = "SQLAlchemy-2.0.21-cp38-cp38-win_amd64.whl", hash = "sha256:73c079e21d10ff2be54a4699f55865d4b275fd6c8bd5d90c5b1ef78ae0197301"},
|
||||
{file = "SQLAlchemy-2.0.21-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:785e2f2c1cb50d0a44e2cdeea5fd36b5bf2d79c481c10f3a88a8be4cfa2c4615"},
|
||||
{file = "SQLAlchemy-2.0.21-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c111cd40910ffcb615b33605fc8f8e22146aeb7933d06569ac90f219818345ef"},
|
||||
{file = "SQLAlchemy-2.0.21-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9cba4e7369de663611ce7460a34be48e999e0bbb1feb9130070f0685e9a6b66"},
|
||||
{file = "SQLAlchemy-2.0.21-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50a69067af86ec7f11a8e50ba85544657b1477aabf64fa447fd3736b5a0a4f67"},
|
||||
{file = "SQLAlchemy-2.0.21-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ccb99c3138c9bde118b51a289d90096a3791658da9aea1754667302ed6564f6e"},
|
||||
{file = "SQLAlchemy-2.0.21-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:513fd5b6513d37e985eb5b7ed89da5fd9e72354e3523980ef00d439bc549c9e9"},
|
||||
{file = "SQLAlchemy-2.0.21-cp39-cp39-win32.whl", hash = "sha256:f9fefd6298433b6e9188252f3bff53b9ff0443c8fde27298b8a2b19f6617eeb9"},
|
||||
{file = "SQLAlchemy-2.0.21-cp39-cp39-win_amd64.whl", hash = "sha256:2e617727fe4091cedb3e4409b39368f424934c7faa78171749f704b49b4bb4ce"},
|
||||
{file = "SQLAlchemy-2.0.21-py3-none-any.whl", hash = "sha256:ea7da25ee458d8f404b93eb073116156fd7d8c2a776d8311534851f28277b4ce"},
|
||||
{file = "SQLAlchemy-2.0.21.tar.gz", hash = "sha256:05b971ab1ac2994a14c56b35eaaa91f86ba080e9ad481b20d99d77f381bb6258"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -4778,4 +4877,4 @@ multidict = ">=4.0"
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "cf14b96576c57633fea1b0cc1f7266a5fb265d4ff4ce2d3479a036483082b3b4"
|
||||
content-hash = "b3f795baef70806c8dab1058b644d95f3c36b1c2fae53a7e5784a633731bc268"
|
||||
|
|
|
|||
|
|
@ -40,11 +40,13 @@ weaviate-client = "^3.22.1"
|
|||
python-multipart = "^0.0.6"
|
||||
deep-translator = "^1.11.4"
|
||||
humanize = "^4.8.0"
|
||||
deepeval = "^0.20.1"
|
||||
deepeval = "^0.20.6"
|
||||
pymupdf = "^1.23.3"
|
||||
psycopg2 = "^2.9.8"
|
||||
llama-index = "^0.8.39.post2"
|
||||
llama-hub = "^0.0.34"
|
||||
sqlalchemy = "^2.0.21"
|
||||
asyncpg = "^0.28.0"
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ from database.database import engine
|
|||
from vectorstore_manager import Memory
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from database.database import AsyncSessionLocal
|
||||
from database.database_crud import session_scope
|
||||
|
||||
import random
|
||||
import string
|
||||
|
|
@ -30,103 +32,65 @@ dotenv.load_dotenv()
|
|||
import openai
|
||||
logger = logging.getLogger(__name__)
|
||||
openai.api_key = os.getenv("OPENAI_API_KEY", "")
|
||||
@contextmanager
|
||||
def session_scope(session):
|
||||
"""Provide a transactional scope around a series of operations."""
|
||||
|
||||
async def retrieve_latest_test_case(session, user_id, memory_id):
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"Session rollback due to: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def retrieve_latest_test_case(session, user_id, memory_id):
|
||||
"""
|
||||
Retrieve the most recently created test case from the database filtered by user_id and memory_id.
|
||||
|
||||
Parameters:
|
||||
- session (Session): A database session.
|
||||
- user_id (int/str): The ID of the user to filter test cases by.
|
||||
- memory_id (int/str): The ID of the memory to filter test cases by.
|
||||
|
||||
Returns:
|
||||
- Object: The most recent test case attributes filtered by user_id and memory_id, or None if an error occurs.
|
||||
"""
|
||||
try:
|
||||
return (
|
||||
# Use await with session.execute() and row.fetchone() or row.all() for async query execution
|
||||
result = await session.execute(
|
||||
session.query(TestSet.attributes_list)
|
||||
.filter_by(user_id=user_id, memory_id=memory_id)
|
||||
.order_by(TestSet.created_at.desc())
|
||||
.first()
|
||||
.order_by(TestSet.created_at).first()
|
||||
)
|
||||
return result.scalar_one_or_none() # scalar_one_or_none() is a non-blocking call
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while retrieving the latest test case: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def add_entity(session, entity):
|
||||
"""
|
||||
Add an entity (like TestOutput, Session, etc.) to the database.
|
||||
|
||||
Parameters:
|
||||
- session (Session): A database session.
|
||||
- entity (Base): An instance of an SQLAlchemy model.
|
||||
|
||||
Returns:
|
||||
- str: A message indicating whether the addition was successful.
|
||||
"""
|
||||
with session_scope(session):
|
||||
session.add(entity)
|
||||
session.commit()
|
||||
async def add_entity(session, entity):
|
||||
async with session_scope(session) as s: # Use your async session_scope
|
||||
s.add(entity) # No need to commit; session_scope takes care of it
|
||||
s.commit()
|
||||
return "Successfully added entity"
|
||||
|
||||
|
||||
def retrieve_job_by_id(session, user_id, job_id):
|
||||
"""
|
||||
Retrieve a job by user ID and job ID.
|
||||
|
||||
Parameters:
|
||||
- session (Session): A database session.
|
||||
- user_id (int/str): The ID of the user.
|
||||
- job_id (int/str): The ID of the job to retrieve.
|
||||
|
||||
Returns:
|
||||
- Object: The job attributes filtered by user_id and job_id, or None if an error occurs.
|
||||
"""
|
||||
async def retrieve_job_by_id(session, user_id, job_id):
|
||||
try:
|
||||
return (
|
||||
result = await session.execute(
|
||||
session.query(Session.id)
|
||||
.filter_by(user_id=user_id, id=job_id)
|
||||
.order_by(Session.created_at.desc())
|
||||
.first()
|
||||
.order_by(Session.created_at)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while retrieving the job: {str(e)}")
|
||||
return None
|
||||
|
||||
def fetch_job_id(session, user_id=None, memory_id=None, job_id=None):
|
||||
async def fetch_job_id(session, user_id=None, memory_id=None, job_id=None):
|
||||
try:
|
||||
return (
|
||||
result = await session.execute(
|
||||
session.query(Session.id)
|
||||
.filter_by(user_id=user_id, id=job_id)
|
||||
.order_by(Session.created_at.desc())
|
||||
.first()
|
||||
.order_by(Session.created_at).first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
# Handle exceptions as per your application's requirements.
|
||||
print(f"An error occurred: {str(e)}")
|
||||
logger.error(f"An error occurred: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def compare_output(output, expected_output):
|
||||
"""Compare the output against the expected output."""
|
||||
pass
|
||||
|
||||
|
||||
async def fetch_test_set_id(session, user_id, id):
|
||||
try:
|
||||
# Await the execution of the query and fetching of the result
|
||||
result = await session.execute(
|
||||
session.query(TestSet.id)
|
||||
.filter_by(user_id=user_id, id=id)
|
||||
.order_by(TestSet.created_at).desc().first()
|
||||
)
|
||||
return result.scalar_one_or_none() # scalar_one_or_none() is a non-blocking call
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while retrieving the test set: {str(e)}")
|
||||
return None
|
||||
|
||||
# Adding "embeddings" to the parameter variants function
|
||||
|
||||
def generate_param_variants(base_params=None, increments=None, ranges=None, included_params=None):
|
||||
"""Generate parameter variants for testing.
|
||||
|
|
@ -142,38 +106,39 @@ def generate_param_variants(base_params=None, increments=None, ranges=None, incl
|
|||
list: A list of dictionaries containing parameter variants.
|
||||
"""
|
||||
|
||||
# Default base values
|
||||
# Default values
|
||||
defaults = {
|
||||
'chunk_size': 500,
|
||||
'chunk_size': 250,
|
||||
'chunk_overlap': 20,
|
||||
'similarity_score': 0.5,
|
||||
'metadata_variation': 0,
|
||||
'search_type': 'hybrid'
|
||||
'search_type': 'hybrid',
|
||||
'embeddings': 'openai' # Default value added for 'embeddings'
|
||||
}
|
||||
|
||||
# Update defaults with provided base parameters
|
||||
params = {**defaults, **(base_params if base_params is not None else {})}
|
||||
params = {**defaults, **(base_params or {})}
|
||||
|
||||
default_increments = {
|
||||
'chunk_size': 500,
|
||||
'chunk_size': 150,
|
||||
'chunk_overlap': 10,
|
||||
'similarity_score': 0.1,
|
||||
'metadata_variation': 1
|
||||
}
|
||||
|
||||
# Update default increments with provided increments
|
||||
increments = {**default_increments, **(increments if increments is not None else {})}
|
||||
increments = {**default_increments, **(increments or {})}
|
||||
|
||||
# Default ranges
|
||||
default_ranges = {
|
||||
'chunk_size': 3,
|
||||
'chunk_overlap': 3,
|
||||
'similarity_score': 3,
|
||||
'metadata_variation': 3
|
||||
'chunk_size': 2,
|
||||
'chunk_overlap': 2,
|
||||
'similarity_score': 2,
|
||||
'metadata_variation': 2
|
||||
}
|
||||
|
||||
# Update default ranges with provided ranges
|
||||
ranges = {**default_ranges, **(ranges if ranges is not None else {})}
|
||||
ranges = {**default_ranges, **(ranges or {})}
|
||||
|
||||
# Generate parameter variant ranges
|
||||
param_ranges = {
|
||||
|
|
@ -181,10 +146,9 @@ def generate_param_variants(base_params=None, increments=None, ranges=None, incl
|
|||
for key in ['chunk_size', 'chunk_overlap', 'similarity_score', 'metadata_variation']
|
||||
}
|
||||
|
||||
|
||||
param_ranges['cognitive_architecture'] = ["simple_index", "cognitive_architecture"]
|
||||
# Add search_type with possible values
|
||||
# Add search_type and embeddings with possible values
|
||||
param_ranges['search_type'] = ['text', 'hybrid', 'bm25', 'generate', 'generate_grouped']
|
||||
param_ranges['embeddings'] = ['openai', 'cohere', 'huggingface'] # Added 'embeddings' values
|
||||
|
||||
# Filter param_ranges based on included_params
|
||||
if included_params is not None:
|
||||
|
|
@ -197,6 +161,10 @@ def generate_param_variants(base_params=None, increments=None, ranges=None, incl
|
|||
|
||||
return param_variants
|
||||
|
||||
|
||||
# Generate parameter variants and display a sample of the generated combinations
|
||||
|
||||
|
||||
async def generate_chatgpt_output(query:str, context:str=None):
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
|
|
@ -278,195 +246,200 @@ def generate_letter_uuid(length=8):
|
|||
letters = string.ascii_uppercase # A-Z
|
||||
return ''.join(random.choice(letters) for _ in range(length))
|
||||
|
||||
async def start_test(data, test_set=None, user_id=None, params=None, job_id=None, metadata=None):
|
||||
|
||||
def fetch_test_set_id(session, user_id, id):
|
||||
try:
|
||||
return (
|
||||
session.query(TestSet.id)
|
||||
.filter_by(user_id=user_id, id=id)
|
||||
.order_by(TestSet.created_at)
|
||||
.desc().first()
|
||||
|
||||
async with session_scope(session=AsyncSessionLocal()) as session:
|
||||
memory = await Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
|
||||
job_id = await fetch_job_id(session, user_id=user_id, job_id=job_id)
|
||||
test_set_id = await fetch_test_set_id(session, user_id=user_id, id=job_id)
|
||||
|
||||
if job_id is None:
|
||||
job_id = str(uuid.uuid4())
|
||||
await add_entity(session, Operation(id=job_id, user_id=user_id))
|
||||
|
||||
if test_set_id is None:
|
||||
test_set_id = str(uuid.uuid4())
|
||||
await add_entity(session, TestSet(id=test_set_id, user_id=user_id, content=str(test_set)))
|
||||
|
||||
if params is None:
|
||||
data_format = data_format_route(data) # Assume data_format_route is predefined
|
||||
data_location = data_location_route(data) # Assume data_location_route is predefined
|
||||
test_params = generate_param_variants(
|
||||
included_params=['chunk_size', 'chunk_overlap', 'search_type'])
|
||||
|
||||
print("Here are the test params", str(test_params))
|
||||
|
||||
loader_settings = {
|
||||
"format": f"{data_format}",
|
||||
"source": f"{data_location}",
|
||||
"path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
|
||||
}
|
||||
|
||||
async def run_test(test, loader_settings, metadata):
|
||||
test_id = str(generate_letter_uuid()) + "_" + "SEMANTICEMEMORY"
|
||||
await memory.add_memory_instance("ExampleMemory")
|
||||
existing_user = await Memory.check_existing_user(user_id, session)
|
||||
await memory.manage_memory_attributes(existing_user)
|
||||
test_class = test_id + "_class"
|
||||
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
|
||||
dynamic_memory_class = getattr(memory, test_class.lower(), None)
|
||||
|
||||
# Assuming add_method_to_class and dynamic_method_call are predefined methods in Memory class
|
||||
|
||||
if dynamic_memory_class is not None:
|
||||
await memory.add_method_to_class(dynamic_memory_class, 'add_memories')
|
||||
else:
|
||||
print(f"No attribute named {test_class.lower()} in memory.")
|
||||
|
||||
if dynamic_memory_class is not None:
|
||||
await memory.add_method_to_class(dynamic_memory_class, 'fetch_memories')
|
||||
else:
|
||||
print(f"No attribute named {test_class.lower()} in memory.")
|
||||
|
||||
print(f"Trying to access: {test_class.lower()}")
|
||||
print("Available memory classes:", await memory.list_memory_classes())
|
||||
|
||||
print(f"Trying to check: ", test)
|
||||
loader_settings.update(test)
|
||||
|
||||
|
||||
async def run_load_test_element(test, loader_settings, metadata, test_id):
|
||||
|
||||
test_class = test_id + "_class"
|
||||
# memory.test_class
|
||||
|
||||
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
|
||||
dynamic_memory_class = getattr(memory, test_class.lower(), None)
|
||||
if dynamic_memory_class is not None:
|
||||
await memory.add_method_to_class(dynamic_memory_class, 'add_memories')
|
||||
else:
|
||||
print(f"No attribute named {test_class.lower()} in memory.")
|
||||
|
||||
if dynamic_memory_class is not None:
|
||||
await memory.add_method_to_class(dynamic_memory_class, 'fetch_memories')
|
||||
else:
|
||||
print(f"No attribute named {test_class.lower()} in memory.")
|
||||
|
||||
print(f"Trying to access: {test_class.lower()}")
|
||||
print("Available memory classes:", await memory.list_memory_classes())
|
||||
|
||||
print(f"Trying to check: ", test)
|
||||
# print("Here is the loader settings", str(loader_settings))
|
||||
# print("Here is the medatadata", str(metadata))
|
||||
load_action = await memory.dynamic_method_call(dynamic_memory_class, 'add_memories',
|
||||
observation='some_observation', params=metadata,
|
||||
loader_settings=loader_settings)
|
||||
async def run_search_eval_element(test_item, test_id):
|
||||
|
||||
test_class = test_id + "_class"
|
||||
await memory.add_dynamic_memory_class(test_id.lower(), test_id)
|
||||
dynamic_memory_class = getattr(memory, test_class.lower(), None)
|
||||
|
||||
retrieve_action = await memory.dynamic_method_call(dynamic_memory_class, 'fetch_memories',
|
||||
observation=test_item["question"],
|
||||
search_type=test_item["search_type"])
|
||||
test_result = await eval_test(query=test_item["question"], expected_output=test_item["answer"],
|
||||
context=str(retrieve_action))
|
||||
print(test_result)
|
||||
delete_mems = await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
|
||||
namespace=test_id)
|
||||
return test_result
|
||||
test_load_pipeline = await asyncio.gather(
|
||||
*(run_load_test_element(test_item,loader_settings, metadata, test_id) for test_item in test_set)
|
||||
)
|
||||
|
||||
test_eval_pipeline = await asyncio.gather(
|
||||
*(run_search_eval_element(test_item, test_id) for test_item in test_set)
|
||||
)
|
||||
logging.info("Results of the eval pipeline %s", str(test_eval_pipeline))
|
||||
await add_entity(session, TestOutput(id=test_id, user_id=user_id, test_results=str(test_eval_pipeline)))
|
||||
|
||||
# # Gather and run all tests in parallel
|
||||
results = await asyncio.gather(
|
||||
*(run_test(test, loader_settings, metadata) for test in test_params)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while retrieving the job: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def start_test(data, test_set=None, user_id=None, params=None, job_id=None ,metadata=None):
|
||||
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
|
||||
|
||||
#do i need namespace in memory instance, fix it
|
||||
memory = Memory.create_memory(user_id, session, namespace="SEMANTICMEMORY")
|
||||
|
||||
job_id = fetch_job_id(session, user_id = user_id,job_id =job_id)
|
||||
test_set_id = fetch_test_set_id(session, user_id=user_id, id=job_id)
|
||||
if job_id is None:
|
||||
job_id = str(uuid.uuid4())
|
||||
logging.info("we are adding a new job ID")
|
||||
add_entity(session, Operation(id = job_id, user_id = user_id))
|
||||
if test_set_id is None:
|
||||
test_set_id = str(uuid.uuid4())
|
||||
add_entity(session, TestSet(id = test_set_id, user_id = user_id, content = str(test_set)))
|
||||
|
||||
|
||||
|
||||
if params is None:
|
||||
|
||||
data_format = data_format_route(data)
|
||||
data_location = data_location_route(data)
|
||||
test_params = generate_param_variants( included_params=['chunk_size', 'chunk_overlap', 'similarity_score'])
|
||||
|
||||
|
||||
|
||||
loader_settings = {
|
||||
"format": f"{data_format}",
|
||||
"source": f"{data_location}",
|
||||
"path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
|
||||
}
|
||||
|
||||
for test in test_params:
|
||||
test_id = str(generate_letter_uuid()) + "_" + "SEMANTICEMEMORY"
|
||||
|
||||
# Adding a memory instance
|
||||
memory.add_memory_instance("ExampleMemory")
|
||||
|
||||
# Managing memory attributes
|
||||
existing_user = Memory.check_existing_user(user_id, session)
|
||||
print("here is the existing user", existing_user)
|
||||
memory.manage_memory_attributes(existing_user)
|
||||
|
||||
test_class = test_id + "_class"
|
||||
# memory.test_class
|
||||
|
||||
memory.add_dynamic_memory_class(test_id.lower(), test_id)
|
||||
dynamic_memory_class = getattr(memory, test_class.lower(), None)
|
||||
|
||||
|
||||
|
||||
if dynamic_memory_class is not None:
|
||||
memory.add_method_to_class(dynamic_memory_class, 'add_memories')
|
||||
else:
|
||||
print(f"No attribute named {test_class.lower()} in memory.")
|
||||
|
||||
if dynamic_memory_class is not None:
|
||||
memory.add_method_to_class(dynamic_memory_class, 'fetch_memories')
|
||||
else:
|
||||
print(f"No attribute named {test_class.lower()} in memory.")
|
||||
|
||||
print(f"Trying to access: {test_class.lower()}")
|
||||
print("Available memory classes:", memory.list_memory_classes())
|
||||
|
||||
print(f"Trying to check: ", test)
|
||||
loader_settings.update(test)
|
||||
load_action = await memory.dynamic_method_call(dynamic_memory_class, 'add_memories',
|
||||
observation='some_observation', params=metadata, loader_settings=loader_settings)
|
||||
loader_settings = {key: value for key, value in loader_settings.items() if key not in test}
|
||||
|
||||
|
||||
|
||||
test_result_collection =[]
|
||||
|
||||
for test in test_set:
|
||||
retrieve_action = await memory.dynamic_method_call(dynamic_memory_class, 'fetch_memories',
|
||||
observation=test["question"])
|
||||
|
||||
test_results = await eval_test( query=test["question"], expected_output=test["answer"], context= str(retrieve_action))
|
||||
test_result_collection.append(test_results)
|
||||
|
||||
print(test_results)
|
||||
if dynamic_memory_class is not None:
|
||||
memory.add_method_to_class(dynamic_memory_class, 'delete_memories')
|
||||
else:
|
||||
print(f"No attribute named {test_class.lower()} in memory.")
|
||||
delete_mems = await memory.dynamic_method_call(dynamic_memory_class, 'delete_memories',
|
||||
namespace =test_id)
|
||||
|
||||
print(test_result_collection)
|
||||
|
||||
add_entity(session, TestOutput(id=test_id, user_id=user_id, test_results=str(test_result_collection)))
|
||||
return results
|
||||
|
||||
async def main():
|
||||
|
||||
params = {
|
||||
"version": "1.0",
|
||||
"agreement_id": "AG123456",
|
||||
"privacy_policy": "https://example.com/privacy",
|
||||
"terms_of_service": "https://example.com/terms",
|
||||
"format": "json",
|
||||
"schema_version": "1.1",
|
||||
"checksum": "a1b2c3d4e5f6",
|
||||
"owner": "John Doe",
|
||||
"license": "MIT",
|
||||
"validity_start": "2023-08-01",
|
||||
"validity_end": "2024-07-31",
|
||||
}
|
||||
|
||||
test_set = [
|
||||
{
|
||||
"question": "Who is the main character in 'The Call of the Wild'?",
|
||||
"answer": "Buck"
|
||||
},
|
||||
{
|
||||
"question": "Who wrote 'The Call of the Wild'?",
|
||||
"answer": "Jack London"
|
||||
},
|
||||
{
|
||||
"question": "Where does Buck live at the start of the book?",
|
||||
"answer": "In the Santa Clara Valley, at Judge Miller’s place."
|
||||
},
|
||||
{
|
||||
"question": "Why is Buck kidnapped?",
|
||||
"answer": "He is kidnapped to be sold as a sled dog in the Yukon during the Klondike Gold Rush."
|
||||
},
|
||||
{
|
||||
"question": "How does Buck become the leader of the sled dog team?",
|
||||
"answer": "Buck becomes the leader after defeating the original leader, Spitz, in a fight."
|
||||
}
|
||||
]
|
||||
result = await start_test("https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf", test_set=test_set, user_id="666", params=None, metadata=params)
|
||||
#
|
||||
# params = {
|
||||
# "version": "1.0",
|
||||
# "agreement_id": "AG123456",
|
||||
# "privacy_policy": "https://example.com/privacy",
|
||||
# "terms_of_service": "https://example.com/terms",
|
||||
# "format": "json",
|
||||
# "schema_version": "1.1",
|
||||
# "checksum": "a1b2c3d4e5f6",
|
||||
# "owner": "John Doe",
|
||||
# "license": "MIT",
|
||||
# "validity_start": "2023-08-01",
|
||||
# "validity_end": "2024-07-31",
|
||||
# }
|
||||
# parser = argparse.ArgumentParser(description="Run tests against a document.")
|
||||
# parser.add_argument("--url", required=True, help="URL of the document to test.")
|
||||
# parser.add_argument("--test_set", required=True, help="Path to JSON file containing the test set.")
|
||||
# parser.add_argument("--user_id", required=True, help="User ID.")
|
||||
# parser.add_argument("--params", help="Additional parameters in JSON format.")
|
||||
# parser.add_argument("--metadata", required=True, help="Path to JSON file containing metadata.")
|
||||
#
|
||||
# test_set = [
|
||||
# {
|
||||
# "question": "Who is the main character in 'The Call of the Wild'?",
|
||||
# "answer": "Buck"
|
||||
# },
|
||||
# {
|
||||
# "question": "Who wrote 'The Call of the Wild'?",
|
||||
# "answer": "Jack London"
|
||||
# },
|
||||
# {
|
||||
# "question": "Where does Buck live at the start of the book?",
|
||||
# "answer": "In the Santa Clara Valley, at Judge Miller’s place."
|
||||
# },
|
||||
# {
|
||||
# "question": "Why is Buck kidnapped?",
|
||||
# "answer": "He is kidnapped to be sold as a sled dog in the Yukon during the Klondike Gold Rush."
|
||||
# },
|
||||
# {
|
||||
# "question": "How does Buck become the leader of the sled dog team?",
|
||||
# "answer": "Buck becomes the leader after defeating the original leader, Spitz, in a fight."
|
||||
# }
|
||||
# ]
|
||||
# result = await start_test("https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf", test_set=test_set, user_id="666", params=None, metadata=params)
|
||||
# args = parser.parse_args()
|
||||
#
|
||||
parser = argparse.ArgumentParser(description="Run tests against a document.")
|
||||
parser.add_argument("--url", required=True, help="URL of the document to test.")
|
||||
parser.add_argument("--test_set", required=True, help="Path to JSON file containing the test set.")
|
||||
parser.add_argument("--user_id", required=True, help="User ID.")
|
||||
parser.add_argument("--params", help="Additional parameters in JSON format.")
|
||||
parser.add_argument("--metadata", required=True, help="Path to JSON file containing metadata.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
with open(args.test_set, "r") as file:
|
||||
test_set = json.load(file)
|
||||
if not isinstance(test_set, list): # Expecting a list
|
||||
raise TypeError("Parsed test_set JSON is not a list.")
|
||||
except Exception as e:
|
||||
print(f"Error loading test_set: {str(e)}")
|
||||
return
|
||||
|
||||
try:
|
||||
with open(args.metadata, "r") as file:
|
||||
metadata = json.load(file)
|
||||
if not isinstance(metadata, dict):
|
||||
raise TypeError("Parsed metadata JSON is not a dictionary.")
|
||||
except Exception as e:
|
||||
print(f"Error loading metadata: {str(e)}")
|
||||
return
|
||||
|
||||
if args.params:
|
||||
try:
|
||||
params = json.loads(args.params)
|
||||
if not isinstance(params, dict):
|
||||
raise TypeError("Parsed params JSON is not a dictionary.")
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing params: {str(e)}")
|
||||
return
|
||||
else:
|
||||
params = None
|
||||
#clean up params here
|
||||
await start_test(args.url, test_set, args.user_id, params=None, metadata=metadata)
|
||||
# try:
|
||||
# with open(args.test_set, "r") as file:
|
||||
# test_set = json.load(file)
|
||||
# if not isinstance(test_set, list): # Expecting a list
|
||||
# raise TypeError("Parsed test_set JSON is not a list.")
|
||||
# except Exception as e:
|
||||
# print(f"Error loading test_set: {str(e)}")
|
||||
# return
|
||||
#
|
||||
# try:
|
||||
# with open(args.metadata, "r") as file:
|
||||
# metadata = json.load(file)
|
||||
# if not isinstance(metadata, dict):
|
||||
# raise TypeError("Parsed metadata JSON is not a dictionary.")
|
||||
# except Exception as e:
|
||||
# print(f"Error loading metadata: {str(e)}")
|
||||
# return
|
||||
#
|
||||
# if args.params:
|
||||
# try:
|
||||
# params = json.loads(args.params)
|
||||
# if not isinstance(params, dict):
|
||||
# raise TypeError("Parsed params JSON is not a dictionary.")
|
||||
# except json.JSONDecodeError as e:
|
||||
# print(f"Error parsing params: {str(e)}")
|
||||
# return
|
||||
# else:
|
||||
# params = None
|
||||
# #clean up params here
|
||||
# await start_test(args.url, test_set, args.user_id, params=None, metadata=metadata)
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
|
|
|
|||
|
|
@ -179,17 +179,17 @@ class WeaviateVectorDB(VectorDB):
|
|||
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", 'distance']
|
||||
).with_where(params_user_id).with_limit(10)
|
||||
|
||||
n_of_observations = kwargs.get('n_of_observations', 2)
|
||||
|
||||
try:
|
||||
if search_type == 'text':
|
||||
query_output = (
|
||||
base_query
|
||||
.with_near_text({"concepts": [observation]})
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
)
|
||||
elif search_type == 'hybrid':
|
||||
n_of_observations = kwargs.get('n_of_observations', 2)
|
||||
|
||||
|
||||
query_output = (
|
||||
base_query
|
||||
.with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
|
||||
|
|
@ -200,6 +200,7 @@ class WeaviateVectorDB(VectorDB):
|
|||
query_output = (
|
||||
base_query
|
||||
.with_bm25(query=observation)
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
)
|
||||
elif search_type == 'generate':
|
||||
|
|
@ -208,6 +209,7 @@ class WeaviateVectorDB(VectorDB):
|
|||
base_query
|
||||
.with_generate(single_prompt=generate_prompt)
|
||||
.with_near_text({"concepts": [observation]})
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
)
|
||||
elif search_type == 'generate_grouped':
|
||||
|
|
@ -216,6 +218,7 @@ class WeaviateVectorDB(VectorDB):
|
|||
base_query
|
||||
.with_generate(grouped_task=generate_prompt)
|
||||
.with_near_text({"concepts": [observation]})
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import logging
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
import marvin
|
||||
from dotenv import load_dotenv
|
||||
|
|
@ -15,6 +17,7 @@ from models.operation import Operation
|
|||
load_dotenv()
|
||||
import ast
|
||||
import tracemalloc
|
||||
from database.database_crud import session_scope, add_entity
|
||||
|
||||
tracemalloc.start()
|
||||
|
||||
|
|
@ -48,7 +51,7 @@ class DynamicBaseMemory(BaseMemory):
|
|||
self.associations = []
|
||||
|
||||
|
||||
def add_method(self, method_name):
|
||||
async def add_method(self, method_name):
|
||||
"""
|
||||
Add a method to the memory class.
|
||||
|
||||
|
|
@ -60,7 +63,7 @@ class DynamicBaseMemory(BaseMemory):
|
|||
"""
|
||||
self.methods.add(method_name)
|
||||
|
||||
def add_attribute(self, attribute_name):
|
||||
async def add_attribute(self, attribute_name):
|
||||
"""
|
||||
Add an attribute to the memory class.
|
||||
|
||||
|
|
@ -72,7 +75,7 @@ class DynamicBaseMemory(BaseMemory):
|
|||
"""
|
||||
self.attributes.add(attribute_name)
|
||||
|
||||
def get_attribute(self, attribute_name):
|
||||
async def get_attribute(self, attribute_name):
|
||||
"""
|
||||
Check if the attribute is in the memory class.
|
||||
|
||||
|
|
@ -84,7 +87,7 @@ class DynamicBaseMemory(BaseMemory):
|
|||
"""
|
||||
return attribute_name in self.attributes
|
||||
|
||||
def add_association(self, associated_memory):
|
||||
async def add_association(self, associated_memory):
|
||||
"""
|
||||
Add an association to another memory class.
|
||||
|
||||
|
|
@ -149,26 +152,26 @@ class Memory:
|
|||
self.OPENAI_TEMPERATURE = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
|
||||
self.OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||
@classmethod
|
||||
def create_memory(cls, user_id: str, session, **kwargs):
|
||||
async def create_memory(cls, user_id: str, session, **kwargs):
|
||||
"""
|
||||
Class method that acts as a factory method for creating Memory instances.
|
||||
It performs necessary DB checks or updates before instance creation.
|
||||
"""
|
||||
existing_user = cls.check_existing_user(user_id, session)
|
||||
existing_user = await cls.check_existing_user(user_id, session)
|
||||
|
||||
if existing_user:
|
||||
|
||||
# Handle existing user scenario...
|
||||
memory_id = cls.check_existing_memory(user_id, session)
|
||||
memory_id = await cls.check_existing_memory(user_id, session)
|
||||
logging.info(f"Existing user {user_id} found in the DB. Memory ID: {memory_id}")
|
||||
else:
|
||||
# Handle new user scenario...
|
||||
memory_id = cls.handle_new_user(user_id, session)
|
||||
memory_id = await cls.handle_new_user(user_id, session)
|
||||
logging.info(f"New user {user_id} created in the DB. Memory ID: {memory_id}")
|
||||
|
||||
return cls(user_id=user_id, session=session, memory_id=memory_id, **kwargs)
|
||||
|
||||
def list_memory_classes(self):
|
||||
async def list_memory_classes(self):
|
||||
"""
|
||||
Lists all available memory classes in the memory instance.
|
||||
"""
|
||||
|
|
@ -176,32 +179,36 @@ class Memory:
|
|||
return [attr for attr in dir(self) if attr.endswith("_class")]
|
||||
|
||||
@staticmethod
|
||||
def check_existing_user(user_id: str, session):
|
||||
async def check_existing_user(user_id: str, session):
|
||||
"""Check if a user exists in the DB and return it."""
|
||||
return session.query(User).filter_by(id=user_id).first()
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == user_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
def check_existing_memory(user_id: str, session):
|
||||
async def check_existing_memory(user_id: str, session):
|
||||
"""Check if a user memory exists in the DB and return it."""
|
||||
return session.query(MemoryModel.id).filter_by(user_id=user_id).first()
|
||||
result = await session.execute(
|
||||
select(MemoryModel.id).where(MemoryModel.user_id == user_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
def handle_new_user(user_id: str, session):
|
||||
async def handle_new_user(user_id: str, session):
|
||||
"""Handle new user creation in the DB and return the new memory ID."""
|
||||
|
||||
#handle these better in terms of retry and error handling
|
||||
memory_id = str(uuid.uuid4())
|
||||
new_user = User(id=user_id)
|
||||
session.add(new_user)
|
||||
session.commit()
|
||||
await add_entity(session, new_user)
|
||||
|
||||
memory = MemoryModel(id=memory_id, user_id=user_id, methods_list=str(['Memory', 'SemanticMemory', 'EpisodicMemory']),
|
||||
attributes_list=str(['user_id', 'index_name', 'db_type', 'knowledge_source', 'knowledge_type', 'memory_id', 'long_term_memory', 'short_term_memory', 'namespace']))
|
||||
session.add(memory)
|
||||
session.commit()
|
||||
await add_entity(session, memory)
|
||||
return memory_id
|
||||
|
||||
def add_memory_instance(self, memory_class_name: str):
|
||||
async def add_memory_instance(self, memory_class_name: str):
|
||||
"""Add a new memory instance to the memory_instances list."""
|
||||
instance = DynamicBaseMemory(memory_class_name, self.user_id,
|
||||
self.memory_id, self.index_name,
|
||||
|
|
@ -209,15 +216,26 @@ class Memory:
|
|||
print("The following instance was defined", instance)
|
||||
self.memory_instances.append(instance)
|
||||
|
||||
def manage_memory_attributes(self, existing_user):
|
||||
async def query_method(self):
|
||||
methods_list = await self.session.execute(
|
||||
select(MemoryModel.methods_list).where(MemoryModel.id == self.memory_id)
|
||||
)
|
||||
methods_list = methods_list.scalar_one_or_none()
|
||||
return methods_list
|
||||
|
||||
async def manage_memory_attributes(self, existing_user):
|
||||
"""Manage memory attributes based on the user existence."""
|
||||
if existing_user:
|
||||
print(f"ID before query: {self.memory_id}, type: {type(self.memory_id)}")
|
||||
attributes_list = self.session.query(MemoryModel.attributes_list).filter_by(id=self.memory_id[0]).scalar()
|
||||
|
||||
|
||||
|
||||
# attributes_list = await self.session.query(MemoryModel.attributes_list).filter_by(id=self.memory_id[0]).scalar()
|
||||
attributes_list = await self.query_method()
|
||||
logging.info(f"Attributes list: {attributes_list}")
|
||||
if attributes_list is not None:
|
||||
attributes_list = ast.literal_eval(attributes_list)
|
||||
self.handle_attributes(attributes_list)
|
||||
await self.handle_attributes(attributes_list)
|
||||
else:
|
||||
logging.warning("attributes_list is None!")
|
||||
else:
|
||||
|
|
@ -225,20 +243,25 @@ class Memory:
|
|||
'knowledge_source', 'knowledge_type',
|
||||
'memory_id', 'long_term_memory',
|
||||
'short_term_memory', 'namespace']
|
||||
self.handle_attributes(attributes_list)
|
||||
await self.handle_attributes(attributes_list)
|
||||
|
||||
def handle_attributes(self, attributes_list):
|
||||
async def handle_attributes(self, attributes_list):
|
||||
"""Handle attributes for existing memory instances."""
|
||||
for attr in attributes_list:
|
||||
self.memory_class.add_attribute(attr)
|
||||
await self.memory_class.add_attribute(attr)
|
||||
|
||||
def manage_memory_methods(self, existing_user):
|
||||
async def manage_memory_methods(self, existing_user):
|
||||
"""
|
||||
Manage memory methods based on the user existence.
|
||||
"""
|
||||
if existing_user:
|
||||
# Fetch existing methods from the database
|
||||
methods_list = self.session.query(MemoryModel.methods_list).filter_by(id=self.memory_id).scalar()
|
||||
# methods_list = await self.session.query(MemoryModel.methods_list).filter_by(id=self.memory_id).scalar()
|
||||
|
||||
methods_list = await self.session.execute(
|
||||
select(MemoryModel.methods_list).where(MemoryModel.id == self.memory_id[0])
|
||||
)
|
||||
methods_list = methods_list.scalar_one_or_none()
|
||||
methods_list = ast.literal_eval(methods_list)
|
||||
else:
|
||||
# Define default methods for a new user
|
||||
|
|
@ -260,20 +283,20 @@ class Memory:
|
|||
return await method(*args, **kwargs)
|
||||
raise AttributeError(f"{dynamic_base_memory_instance.name} object has no attribute {method_name}")
|
||||
|
||||
def add_dynamic_memory_class(self, class_name: str, namespace: str):
|
||||
async def add_dynamic_memory_class(self, class_name: str, namespace: str):
|
||||
logging.info("Here is the memory id %s", self.memory_id[0])
|
||||
new_memory_class = DynamicBaseMemory(class_name, self.user_id, self.memory_id[0], self.index_name,
|
||||
self.db_type, namespace)
|
||||
setattr(self, f"{class_name.lower()}_class", new_memory_class)
|
||||
return new_memory_class
|
||||
|
||||
def add_attribute_to_class(self, class_instance, attribute_name: str):
|
||||
async def add_attribute_to_class(self, class_instance, attribute_name: str):
|
||||
#add this to database for a particular user and load under memory id
|
||||
class_instance.add_attribute(attribute_name)
|
||||
await class_instance.add_attribute(attribute_name)
|
||||
|
||||
def add_method_to_class(self, class_instance, method_name: str):
|
||||
async def add_method_to_class(self, class_instance, method_name: str):
|
||||
#add this to database for a particular user and load under memory id
|
||||
class_instance.add_method(method_name)
|
||||
await class_instance.add_method(method_name)
|
||||
|
||||
|
||||
|
||||
|
|
@ -297,35 +320,37 @@ async def main():
|
|||
}
|
||||
loader_settings = {
|
||||
"format": "PDF",
|
||||
"source": "url",
|
||||
"source": "URL",
|
||||
"path": "https://www.ibiblio.org/ebooks/London/Call%20of%20Wild.pdf"
|
||||
}
|
||||
# memory_instance = Memory(namespace='SEMANTICMEMORY')
|
||||
# sss = await memory_instance.dynamic_method_call(memory_instance.semantic_memory_class, 'fetch_memories', observation='some_observation')
|
||||
#
|
||||
# Create a DB session
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
memory = Memory.create_memory("676", session, namespace='SEMANTICMEMORY')
|
||||
|
||||
# Adding a memory instance
|
||||
memory.add_memory_instance("ExampleMemory")
|
||||
from database.database_crud import session_scope
|
||||
from database.database import AsyncSessionLocal
|
||||
|
||||
async with session_scope(AsyncSessionLocal()) as session:
|
||||
memory = await Memory.create_memory("676", session, namespace='SEMANTICMEMORY')
|
||||
|
||||
# Adding a memory instance
|
||||
await memory.add_memory_instance("ExampleMemory")
|
||||
|
||||
|
||||
# Managing memory attributes
|
||||
existing_user = Memory.check_existing_user("676", session)
|
||||
print("here is the existing user", existing_user)
|
||||
memory.manage_memory_attributes(existing_user)
|
||||
# Managing memory attributes
|
||||
existing_user = await Memory.check_existing_user("676", session)
|
||||
print("here is the existing user", existing_user)
|
||||
await memory.manage_memory_attributes(existing_user)
|
||||
# aeehuvyq_semanticememory_class
|
||||
|
||||
memory.add_dynamic_memory_class('SemanticMemory', 'SEMANTICMEMORY')
|
||||
memory.add_method_to_class(memory.semanticmemory_class, 'add_memories')
|
||||
memory.add_method_to_class(memory.semanticmemory_class, 'fetch_memories')
|
||||
sss = await memory.dynamic_method_call(memory.semanticmemory_class, 'add_memories',
|
||||
observation='some_observation', params=params)
|
||||
await memory.add_dynamic_memory_class('SemanticMemory', 'SEMANTICMEMORY')
|
||||
await memory.add_method_to_class(memory.semanticmemory_class, 'add_memories')
|
||||
await memory.add_method_to_class(memory.semanticmemory_class, 'fetch_memories')
|
||||
sss = await memory.dynamic_method_call(memory.semanticmemory_class, 'add_memories',
|
||||
observation='some_observation', params=params, loader_settings=loader_settings)
|
||||
|
||||
susu = await memory.dynamic_method_call(memory.semanticmemory_class, 'fetch_memories',
|
||||
observation='some_observation')
|
||||
print(susu)
|
||||
susu = await memory.dynamic_method_call(memory.semanticmemory_class, 'fetch_memories',
|
||||
observation='some_observation')
|
||||
print(susu)
|
||||
|
||||
# Adding a dynamic memory class
|
||||
# dynamic_memory = memory.add_dynamic_memory_class("DynamicMemory", "ExampleNamespace")
|
||||
|
|
@ -337,8 +362,8 @@ async def main():
|
|||
|
||||
|
||||
# print(sss)
|
||||
# load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params)
|
||||
# print(load_jack_london)
|
||||
load_jack_london = await memory._add_semantic_memory(observation = "bla", loader_settings=loader_settings, params=params)
|
||||
print(load_jack_london)
|
||||
|
||||
modulator = {"relevance": 0.1, "frequency": 0.1}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue