# LightRAG/service/tests/test_direct_integration.py (159 lines, 5.2 KiB, Python)

import sys
import os
import asyncio
import uuid
from unittest.mock import MagicMock, AsyncMock
# Set SQLite for testing: every run starts from a fresh database file.
db_path = "./test.db"
try:
    # EAFP: delete a leftover database directly instead of checking for it
    # first, so there is no gap between the existence check and the removal.
    os.remove(db_path)
except FileNotFoundError:
    # No database left over from a previous run; nothing to clean up.
    pass
# NOTE(review): presumably the service's database module reads DATABASE_URL
# when it is imported, which is why this is set before the imports below.
os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
# Add project root (two directories above this file) and its "service"
# sub-directory to the import path.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
sys.path.append(project_root)
sys.path.append(os.path.join(project_root, "service"))
# No import-time mocking: the real (modified) router module is loaded.
from lightrag.api.routers.query_routes import (
    QueryRequest,
    create_query_routes,
)

# Service-side database objects and models, used below to check that
# records are created.
from app.core.database import Base, SessionLocal, engine
from app.models.models import ChatMessage, ChatSession

# Create the tables registered on Base's metadata using the service engine.
Base.metadata.create_all(bind=engine)
async def test_direct_integration():
    """Call the ``/query`` endpoint function directly and verify DB history.

    A mocked RAG object whose ``aquery_llm`` returns a canned, non-streaming
    answer is handed to ``create_query_routes``.  The endpoint registered
    under ``/query`` is then looked up on the router and awaited with a
    ``QueryRequest`` whose ``session_id`` is ``None``.  Finally the
    ``ChatMessage`` table is checked for the assistant reply and for the
    user message belonging to the same session.

    Raises:
        AssertionError: if the ``/query`` route is not registered, no
            assistant message was stored, or the stored user/assistant
            messages do not match the request and the mocked response.
    """
    print("Testing Direct Integration...")

    # Mock RAG instance
    mock_rag = MagicMock()
    mock_rag.aquery_llm = AsyncMock(
        return_value={
            "llm_response": {"content": "This is a mocked response."},
            "data": {"references": [{"reference_id": "1", "file_path": "doc1.txt"}]},
        }
    )

    # Create router (this registers the endpoints but we call the function
    # directly for testing).  The endpoint function is recovered by
    # inspecting router.routes rather than by starting FastAPI.
    create_query_routes(mock_rag)
    from lightrag.api.routers.query_routes import router

    # Find the query_text function.  A default is passed to next() because a
    # StopIteration escaping a coroutine is re-raised as an opaque
    # RuntimeError; an explicit failure is easier to diagnose.
    query_route = next((r for r in router.routes if r.path == "/query"), None)
    if query_route is None:
        raise AssertionError("No '/query' route is registered on the router.")
    query_func = query_route.endpoint

    # Prepare Request
    request = QueryRequest(
        query="Test Query Direct",
        mode="hybrid",
        session_id=None,  # Should create new session
    )

    # Call the endpoint function directly
    print("Calling query_text...")
    response = await query_func(request)
    print("Response received:", response)

    # Verify DB.  The session is closed in ``finally`` so that a failing
    # assertion does not leak it.
    db = SessionLocal()
    try:
        messages = db.query(ChatMessage).all()
        print(f"Total messages: {len(messages)}")
        for msg in messages:
            print(f"Msg: {msg.content} ({msg.role}) at {msg.created_at}")

        last_message = (
            db.query(ChatMessage)
            .filter(ChatMessage.role == "assistant")
            .order_by(ChatMessage.created_at.desc())
            .first()
        )
        if last_message is None:
            # Previously this branch only printed, so the test "passed" even
            # though nothing had been logged.
            print("Verification Failed: No message found in DB.")
            raise AssertionError("No assistant message found in DB.")

        print(f"Last Assistant Message: {last_message.content}")
        assert last_message.content == "This is a mocked response."
        assert last_message.role == "assistant"

        # Check user message
        user_msg = (
            db.query(ChatMessage)
            .filter(
                ChatMessage.session_id == last_message.session_id,
                ChatMessage.role == "user",
            )
            .first()
        )
        if user_msg is None:
            raise AssertionError(
                "No user message found for the assistant message's session."
            )
        assert user_msg.content == "Test Query Direct"
        print("Verification Successful: History logged to DB.")
    finally:
        db.close()
async def test_stream_integration():
    """Call the ``/query/stream`` endpoint directly and verify DB history.

    A mocked RAG object whose ``aquery_llm`` returns a streaming result (an
    async iterator yielding ``"Chunk 1 "`` and ``"Chunk 2"``) is handed to
    ``create_query_routes`` after the router's routes have been cleared.
    The endpoint registered under ``/query/stream`` is awaited, its
    ``body_iterator`` is fully consumed, and the ``ChatMessage`` table is
    then checked for an assistant message whose content is the
    concatenation ``"Chunk 1 Chunk 2"``.

    Raises:
        AssertionError: if the ``/query/stream`` route is not registered or
            the expected assistant message is not found in the database.
    """
    print("\nTesting Stream Integration...")

    # Mock RAG instance for streaming
    mock_rag = MagicMock()

    # Mock response iterator
    async def response_iterator():
        yield "Chunk 1 "
        yield "Chunk 2"

    mock_rag.aquery_llm = AsyncMock(
        return_value={
            "llm_response": {
                "is_streaming": True,
                "response_iterator": response_iterator(),
            },
            "data": {"references": [{"reference_id": "2", "file_path": "doc2.txt"}]},
        }
    )

    from lightrag.api.routers.query_routes import router

    router.routes = []  # Clear existing routes to avoid conflict
    create_query_routes(mock_rag)

    # Find the query_text_stream function.  A default is passed to next()
    # because a StopIteration escaping a coroutine is re-raised as an opaque
    # RuntimeError; an explicit failure is easier to diagnose.
    stream_route = next(
        (r for r in router.routes if r.path == "/query/stream"), None
    )
    if stream_route is None:
        raise AssertionError("No '/query/stream' route is registered on the router.")
    stream_func = stream_route.endpoint

    # Prepare Request
    request = QueryRequest(
        query="Test Stream Direct",
        mode="hybrid",
        stream=True,
        session_id=None,
    )

    # Call the endpoint
    print("Calling query_text_stream...")
    response = await stream_func(request)

    # Consume the stream.  The chunks are only printed: the verification
    # below reads the database, so no accumulated copy is kept (which also
    # avoids assuming whether chunks are ``str`` or ``bytes``).
    async for chunk in response.body_iterator:
        print(f"Chunk received: {chunk}")
    print("Stream finished.")

    # Verify DB.  The session is closed in ``finally`` so that a failing
    # check does not leak it.
    db = SessionLocal()
    try:
        # Check for the new message.
        # We expect "Chunk 1 Chunk 2" as content.
        # Note: the chunk in body_iterator is expected to be an NDJSON
        # string, e.g. '{"response": "Chunk 1 "}\n', but the DB should
        # contain the parsed content.
        last_message = (
            db.query(ChatMessage)
            .filter(
                ChatMessage.role == "assistant",
                ChatMessage.content == "Chunk 1 Chunk 2",
            )
            .first()
        )
        if last_message is None:
            print("Verification Failed: Stream message not found in DB.")
            # Print all to debug
            for msg in db.query(ChatMessage).all():
                print(f"Msg: {msg.content} ({msg.role})")
            # Previously this branch only printed, so the test "passed" even
            # though the streamed reply had not been logged.
            raise AssertionError("Stream message not found in DB.")

        print(f"Stream Message in DB: {last_message.content}")
        assert last_message.content == "Chunk 1 Chunk 2"
        print("Verification Successful: Stream History logged to DB.")
    finally:
        db.close()
if __name__ == "__main__":
    # Script entry point: run the direct scenario first, then the streaming
    # one, each through its own asyncio.run() call.
    for scenario in (test_direct_integration, test_stream_integration):
        asyncio.run(scenario())