159 lines
5.2 KiB
Python
159 lines
5.2 KiB
Python
import sys
|
|
import os
|
|
import asyncio
|
|
import uuid
|
|
from unittest.mock import MagicMock, AsyncMock
|
|
|
|
# Set SQLite for testing
|
|
db_path = "./test.db"
|
|
if os.path.exists(db_path):
|
|
os.remove(db_path)
|
|
os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
|
|
|
|
# Add project root to path
|
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
|
|
sys.path.append(project_root)
|
|
sys.path.append(os.path.join(project_root, "service"))
|
|
|
|
# Mocking removed to allow real imports
|
|
|
|
# Now import the modified router
|
|
from lightrag.api.routers.query_routes import create_query_routes, QueryRequest
|
|
|
|
# Import service DB to check if records are created
|
|
from app.core.database import SessionLocal, engine, Base
|
|
from app.models.models import ChatMessage, ChatSession
|
|
|
|
# Create tables
|
|
Base.metadata.create_all(bind=engine)
|
|
|
|
async def test_direct_integration():
|
|
print("Testing Direct Integration...")
|
|
|
|
# Mock RAG instance
|
|
mock_rag = MagicMock()
|
|
mock_rag.aquery_llm = AsyncMock(return_value={
|
|
"llm_response": {"content": "This is a mocked response."},
|
|
"data": {"references": [{"reference_id": "1", "file_path": "doc1.txt"}]}
|
|
})
|
|
|
|
# Create router (this registers the endpoints but we'll call the function directly for testing)
|
|
# We need to access the function decorated by @router.post("/query")
|
|
# Since we can't easily get the route function from the router object without starting FastAPI,
|
|
# we will inspect the router.routes
|
|
|
|
create_query_routes(mock_rag)
|
|
|
|
from lightrag.api.routers.query_routes import router
|
|
|
|
# Find the query_text function
|
|
query_route = next(r for r in router.routes if r.path == "/query")
|
|
query_func = query_route.endpoint
|
|
|
|
# Prepare Request
|
|
request = QueryRequest(
|
|
query="Test Query Direct",
|
|
mode="hybrid",
|
|
session_id=None # Should create new session
|
|
)
|
|
|
|
# Call the endpoint function directly
|
|
print("Calling query_text...")
|
|
response = await query_func(request)
|
|
print("Response received:", response)
|
|
|
|
# Verify DB
|
|
db = SessionLocal()
|
|
messages = db.query(ChatMessage).all()
|
|
print(f"Total messages: {len(messages)}")
|
|
for msg in messages:
|
|
print(f"Msg: {msg.content} ({msg.role}) at {msg.created_at}")
|
|
|
|
last_message = db.query(ChatMessage).filter(ChatMessage.role == "assistant").order_by(ChatMessage.created_at.desc()).first()
|
|
|
|
if last_message:
|
|
print(f"Last Assistant Message: {last_message.content}")
|
|
assert last_message.content == "This is a mocked response."
|
|
assert last_message.role == "assistant"
|
|
|
|
# Check user message
|
|
user_msg = db.query(ChatMessage).filter(ChatMessage.session_id == last_message.session_id, ChatMessage.role == "user").first()
|
|
assert user_msg.content == "Test Query Direct"
|
|
print("Verification Successful: History logged to DB.")
|
|
else:
|
|
print("Verification Failed: No message found in DB.")
|
|
|
|
db.close()
|
|
|
|
async def test_stream_integration():
|
|
print("\nTesting Stream Integration...")
|
|
|
|
# Mock RAG instance for streaming
|
|
mock_rag = MagicMock()
|
|
|
|
# Mock response iterator
|
|
async def response_iterator():
|
|
yield "Chunk 1 "
|
|
yield "Chunk 2"
|
|
|
|
mock_rag.aquery_llm = AsyncMock(return_value={
|
|
"llm_response": {
|
|
"is_streaming": True,
|
|
"response_iterator": response_iterator()
|
|
},
|
|
"data": {"references": [{"reference_id": "2", "file_path": "doc2.txt"}]}
|
|
})
|
|
|
|
from lightrag.api.routers.query_routes import router
|
|
router.routes = [] # Clear existing routes to avoid conflict
|
|
create_query_routes(mock_rag)
|
|
|
|
# Find the query_text_stream function
|
|
stream_route = next(r for r in router.routes if r.path == "/query/stream")
|
|
stream_func = stream_route.endpoint
|
|
|
|
# Prepare Request
|
|
request = QueryRequest(
|
|
query="Test Stream Direct",
|
|
mode="hybrid",
|
|
stream=True,
|
|
session_id=None
|
|
)
|
|
|
|
# Call the endpoint
|
|
print("Calling query_text_stream...")
|
|
response = await stream_func(request)
|
|
|
|
# Consume the stream
|
|
content = ""
|
|
async for chunk in response.body_iterator:
|
|
print(f"Chunk received: {chunk}")
|
|
content += chunk
|
|
|
|
print("Stream finished.")
|
|
|
|
# Verify DB
|
|
db = SessionLocal()
|
|
# Check for the new message
|
|
# We expect "Chunk 1 Chunk 2" as content
|
|
# Note: The chunk in body_iterator is NDJSON string, e.g. '{"response": "Chunk 1 "}\n'
|
|
# But the DB should contain the parsed content.
|
|
|
|
last_message = db.query(ChatMessage).filter(ChatMessage.role == "assistant", ChatMessage.content == "Chunk 1 Chunk 2").first()
|
|
|
|
if last_message:
|
|
print(f"Stream Message in DB: {last_message.content}")
|
|
assert last_message.content == "Chunk 1 Chunk 2"
|
|
print("Verification Successful: Stream History logged to DB.")
|
|
else:
|
|
print("Verification Failed: Stream message not found in DB.")
|
|
# Print all to debug
|
|
messages = db.query(ChatMessage).all()
|
|
for msg in messages:
|
|
print(f"Msg: {msg.content} ({msg.role})")
|
|
|
|
db.close()
|
|
|
|
if __name__ == "__main__":
|
|
asyncio.run(test_direct_integration())
|
|
asyncio.run(test_stream_integration())
|