feat: add adapter-level basic operations for session management
This commit is contained in:
parent
5399b54b54
commit
15d7f69af3
2 changed files with 147 additions and 6 deletions
|
|
@ -2,6 +2,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class CacheConfig(BaseSettings):
|
class CacheConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
Configuration for distributed cache systems (e.g., Redis), used for locking or coordination.
|
Configuration for distributed cache systems (e.g., Redis), used for locking or coordination.
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,37 @@
|
||||||
|
import asyncio
|
||||||
import redis
|
import redis
|
||||||
|
import redis.asyncio as aioredis
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
|
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
|
||||||
|
from datetime import datetime
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
class RedisAdapter(CacheDBInterface):
    """Redis-backed cache adapter providing a distributed lock plus
    per-session Q/A storage.

    Keeps two clients: a synchronous one used only for the lock API
    (sync because of Kuzu, per the lock-method docs) and an asynchronous
    one used for all session read/write operations.
    """

    def __init__(
        self,
        host,
        port,
        lock_name="default_lock",
        username=None,
        password=None,
        timeout=240,
        blocking_timeout=300,
    ):
        """Initialize the sync and async Redis clients.

        Args:
            host: Redis server hostname.
            port: Redis server port.
            lock_name: Name used to derive the distributed lock key.
            username: Optional Redis username.
            password: Optional Redis password.
            timeout: Lock lifetime in seconds once acquired.
            blocking_timeout: Max seconds to block waiting for the lock.
        """
        super().__init__(host, port, lock_name)

        # Sync client: backs the blocking lock API (acquire/release/hold).
        self.sync_redis = redis.Redis(host=host, port=port, username=username, password=password)
        # Async client: backs the session Q/A operations. decode_responses=True
        # so lindex/lrange return str values that json.loads accepts directly.
        self.async_redis = aioredis.Redis(
            host=host, port=port, username=username, password=password, decode_responses=True
        )
        self.timeout = timeout
        self.blocking_timeout = blocking_timeout
        # Fix: release() checks `if self.lock:`; without this default, calling
        # release() before acquire() would raise AttributeError.
        # NOTE(review): assumes CacheDBInterface.__init__ does not set
        # self.lock itself — confirm against the base class.
        self.lock = None
||||||
def acquire(self):
|
def acquire(self):
|
||||||
"""
|
"""
|
||||||
Acquire the Redis lock manually. Raises if acquisition fails.
|
Acquire the Redis lock manually. Raises if acquisition fails. (Sync because of Kuzu)
|
||||||
"""
|
"""
|
||||||
self.lock = self.redis.lock(
|
self.lock = self.sync_redis.lock(
|
||||||
name=self.lock_key,
|
name=self.lock_key,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
blocking_timeout=self.blocking_timeout,
|
blocking_timeout=self.blocking_timeout,
|
||||||
|
|
@ -29,7 +45,7 @@ class RedisAdapter(CacheDBInterface):
|
||||||
|
|
||||||
def release(self):
|
def release(self):
|
||||||
"""
|
"""
|
||||||
Release the Redis lock manually, if held.
|
Release the Redis lock manually, if held. (Sync because of Kuzu)
|
||||||
"""
|
"""
|
||||||
if self.lock:
|
if self.lock:
|
||||||
try:
|
try:
|
||||||
|
|
@ -41,10 +57,134 @@ class RedisAdapter(CacheDBInterface):
|
||||||
@contextmanager
def hold(self):
    """Context manager wrapping the lock lifecycle: acquires on entry,
    always releases on exit — even when the body raises.

    (Sync because of Kuzu.)
    """
    self.acquire()
    try:
        yield
    finally:
        self.release()
||||||
|
|
||||||
|
async def add_qa(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
question: str,
|
||||||
|
context: str,
|
||||||
|
answer: str,
|
||||||
|
ttl: int | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add a Q/A/context triplet to a Redis list for this session.
|
||||||
|
Creates the session if it doesn't exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The user ID.
|
||||||
|
session_id: Unique identifier for the session.
|
||||||
|
question: User question text.
|
||||||
|
context: Context used to answer.
|
||||||
|
answer: Assistant answer text.
|
||||||
|
ttl: Optional time-to-live (seconds). If provided, the session expires after this time.
|
||||||
|
"""
|
||||||
|
session_key = f"agent_sessions:{user_id}:{session_id}"
|
||||||
|
|
||||||
|
qa_entry = {
|
||||||
|
"time": datetime.utcnow().isoformat(),
|
||||||
|
"question": question,
|
||||||
|
"context": context,
|
||||||
|
"answer": answer,
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.async_redis.rpush(session_key, json.dumps(qa_entry))
|
||||||
|
|
||||||
|
if ttl is not None:
|
||||||
|
await self.async_redis.expire(session_key, ttl)
|
||||||
|
|
||||||
|
async def get_latest_qa(self, user_id: str, session_id: str, last_n: int = 1):
    """
    Retrieve the most recent Q/A/context triplet(s) for the given session.

    Args:
        user_id: The user ID.
        session_id: Unique identifier for the session.
        last_n: Number of most-recent entries to return.

    Returns:
        For ``last_n == 1``: the latest entry as a dict, or ``None`` when
        the session is empty. For ``last_n > 1``: a list of up to
        ``last_n`` entries, oldest first. For ``last_n <= 0``: ``[]``.
    """
    session_key = f"agent_sessions:{user_id}:{session_id}"
    if last_n <= 0:
        # Fix: LRANGE(key, -0, -1) equals LRANGE(key, 0, -1), which would
        # return the ENTIRE session instead of nothing.
        return []
    if last_n == 1:
        data = await self.async_redis.lindex(session_key, -1)
        return json.loads(data) if data else None
    data = await self.async_redis.lrange(session_key, -last_n, -1)
    return [json.loads(d) for d in data] if data else []
|
||||||
|
|
||||||
|
async def get_all_qas(self, user_id: str, session_id: str):
    """
    Retrieve all Q/A/context triplets for the given session, oldest first.
    """
    key = f"agent_sessions:{user_id}:{session_id}"
    raw_entries = await self.async_redis.lrange(key, 0, -1)
    decoded = []
    for raw in raw_entries:
        decoded.append(json.loads(raw))
    return decoded
|
||||||
|
|
||||||
|
async def close(self):
    """Gracefully shut down the async Redis client and release its connections."""
    await self.async_redis.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Manual smoke test: exercise RedisAdapter session operations against a
    local Redis instance (requires a server on localhost:6379)."""
    adapter = RedisAdapter(host="localhost", port=6379)
    user_id = "demo_user_id"
    session_id = "demo_session"

    print("\nAdding sample Q/A pairs...")
    # Short-lived session: entries expire after 10 seconds.
    short_lived = [
        ("What is Redis?", "Basic DB context", "Redis is an in-memory data store."),
        ("Who created Redis?", "Historical context", "Salvatore Sanfilippo (antirez)."),
    ]
    for question, context, answer in short_lived:
        await adapter.add_qa(user_id, session_id, question, context, answer, ttl=10)

    print("\nLatest QA:")
    latest = await adapter.get_latest_qa(user_id, session_id)
    print(json.dumps(latest, indent=2))

    print("\nLast 2 QAs:")
    last_two = await adapter.get_latest_qa(user_id, session_id, last_n=2)
    print(json.dumps(last_two, indent=2))

    # Longer-lived session demonstrating a 1-hour TTL.
    session_id = "session_expire_demo"
    long_lived = [
        ("What is Redis?", "Database context", "Redis is an in-memory data store."),
        ("Who created Redis?", "History context", "Salvatore Sanfilippo (antirez)."),
    ]
    for question, context, answer in long_lived:
        await adapter.add_qa(user_id, session_id, question, context, answer, ttl=3600)

    print(await adapter.get_all_qas(user_id, session_id))

    await adapter.close()


if __name__ == "__main__":
    asyncio.run(main())
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue