feat: adds adapter level basic operations for session management

This commit is contained in:
hajdul88 2025-10-14 15:04:55 +02:00
parent 5399b54b54
commit 15d7f69af3
2 changed files with 147 additions and 6 deletions

View file

@ -2,6 +2,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from functools import lru_cache
from typing import Optional
class CacheConfig(BaseSettings):
"""
Configuration for distributed cache systems (e.g., Redis), used for locking or coordination.

View file

@ -1,21 +1,37 @@
import asyncio
import redis
import redis.asyncio as aioredis
from contextlib import contextmanager
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
from datetime import datetime
import json
class RedisAdapter(CacheDBInterface):
def __init__(self, host, port, lock_name, username=None, password=None, timeout=240, blocking_timeout=300):
def __init__(
self,
host,
port,
lock_name="default_lock",
username=None,
password=None,
timeout=240,
blocking_timeout=300,
):
super().__init__(host, port, lock_name)
self.redis = redis.Redis(host=host, port=port, username=username, password=password)
self.sync_redis = redis.Redis(host=host, port=port, username=username, password=password)
self.async_redis = aioredis.Redis(
host=host, port=port, username=username, password=password, decode_responses=True
)
self.timeout = timeout
self.blocking_timeout = blocking_timeout
def acquire(self):
"""
Acquire the Redis lock manually. Raises if acquisition fails.
Acquire the Redis lock manually. Raises if acquisition fails. (Sync because of Kuzu)
"""
self.lock = self.redis.lock(
self.lock = self.sync_redis.lock(
name=self.lock_key,
timeout=self.timeout,
blocking_timeout=self.blocking_timeout,
@ -29,7 +45,7 @@ class RedisAdapter(CacheDBInterface):
def release(self):
"""
Release the Redis lock manually, if held.
Release the Redis lock manually, if held. (Sync because of Kuzu)
"""
if self.lock:
try:
@ -41,10 +57,134 @@ class RedisAdapter(CacheDBInterface):
@contextmanager
def hold(self):
"""
Context manager for acquiring and releasing the Redis lock automatically.
Context manager for acquiring and releasing the Redis lock automatically. (Sync because of Kuzu)
"""
self.acquire()
try:
yield
finally:
self.release()
async def add_qa(
self,
user_id: str,
session_id: str,
question: str,
context: str,
answer: str,
ttl: int | None = None,
):
"""
Add a Q/A/context triplet to a Redis list for this session.
Creates the session if it doesn't exist.
Args:
user_id (str): The user ID.
session_id: Unique identifier for the session.
question: User question text.
context: Context used to answer.
answer: Assistant answer text.
ttl: Optional time-to-live (seconds). If provided, the session expires after this time.
"""
session_key = f"agent_sessions:{user_id}:{session_id}"
qa_entry = {
"time": datetime.utcnow().isoformat(),
"question": question,
"context": context,
"answer": answer,
}
await self.async_redis.rpush(session_key, json.dumps(qa_entry))
if ttl is not None:
await self.async_redis.expire(session_key, ttl)
async def get_latest_qa(self, user_id: str, session_id: str, last_n: int = 1):
"""
Retrieve the most recent Q/A/context triplet(s) for the given session.
"""
session_key = f"agent_sessions:{user_id}:{session_id}"
if last_n == 1:
data = await self.async_redis.lindex(session_key, -1)
return json.loads(data) if data else None
else:
data = await self.async_redis.lrange(session_key, -last_n, -1)
return [json.loads(d) for d in data] if data else []
async def get_all_qas(self, user_id: str, session_id: str):
"""
Retrieve all Q/A/context triplets for the given session.
"""
session_key = f"agent_sessions:{user_id}:{session_id}"
entries = await self.async_redis.lrange(session_key, 0, -1)
return [json.loads(e) for e in entries]
async def close(self):
"""
Gracefully close the async Redis connection.
"""
await self.async_redis.aclose()
async def main():
HOST = "localhost"
PORT = 6379
adapter = RedisAdapter(host=HOST, port=PORT)
session_id = "demo_session"
user_id = "demo_user_id"
print("\nAdding sample Q/A pairs...")
await adapter.add_qa(
user_id,
session_id,
"What is Redis?",
"Basic DB context",
"Redis is an in-memory data store.",
ttl=10,
)
await adapter.add_qa(
user_id,
session_id,
"Who created Redis?",
"Historical context",
"Salvatore Sanfilippo (antirez).",
ttl=10,
)
print("\nLatest QA:")
latest = await adapter.get_latest_qa(user_id, session_id)
print(json.dumps(latest, indent=2))
print("\nLast 2 QAs:")
last_two = await adapter.get_latest_qa(user_id, session_id, last_n=2)
print(json.dumps(last_two, indent=2))
session_id = "session_expire_demo"
await adapter.add_qa(
user_id,
session_id,
"What is Redis?",
"Database context",
"Redis is an in-memory data store.",
ttl=3600,
)
await adapter.add_qa(
user_id,
session_id,
"Who created Redis?",
"History context",
"Salvatore Sanfilippo (antirez).",
ttl=3600,
)
print(await adapter.get_all_qas(user_id, session_id))
await adapter.close()
if __name__ == "__main__":
asyncio.run(main())