improving search screen

This commit is contained in:
rajeevrajeshuni 2025-12-29 20:12:50 +05:30
parent c7ee3c37da
commit 7d3586e1b2
2 changed files with 235 additions and 23 deletions

View file

@ -1,7 +1,7 @@
import asyncio
from textual.app import ComposeResult
from textual.widgets import Input, Label, Static, Select
from textual.containers import Container, Vertical, ScrollableContainer
from textual.widgets import Input, Label, Static, Select, ListView, ListItem
from textual.containers import Container, Vertical
from textual.binding import Binding
from cognee.cli.tui.base_screen import BaseTUIScreen
@ -46,7 +46,7 @@ class SearchTUIScreen(BaseTUIScreen):
margin-bottom: 1;
}
#results-content {
#results-list {
height: 1fr;
overflow-y: auto;
}
@ -78,19 +78,22 @@ class SearchTUIScreen(BaseTUIScreen):
)
with Container(id="results-container"):
yield Static("Results", id="results-title")
with ScrollableContainer(id="results-content"):
yield Static(
"Enter a query and click Search to see results.", id="results-text"
)
yield ListView(id="results-list")
def compose_footer(self) -> ComposeResult:
yield Static("Ctrl+S: Search • Esc: Back • q: Quit", classes="tui-footer")
def on_mount(self) -> None:
"""Focus the query input on mount."""
"""Focus the query input on mount and show initial help text."""
query_input = self.query_one("#query-input", Input)
query_input.focus()
# Add initial help text to list
results_list = self.query_one("#results-list", ListView)
results_list.mount(
ListItem(Label("Enter a query and click Search to see results."))
)
def action_back(self) -> None:
"""Go back to home screen."""
self.app.pop_screen()
@ -128,17 +131,23 @@ class SearchTUIScreen(BaseTUIScreen):
self.notify(f"Searching for: {query_text}", severity="information")
# Update results to show loading
results_text = self.query_one("#results-text", Static)
results_text.update("🔍 Searching...")
results_list = self.query_one("#results-list", ListView)
results_list.clear()
results_list.mount(ListItem(Label("🔍 Searching...")))
# Run async search
asyncio.create_task(self._async_search(query_text, query_type))
async def _async_search(self, query_text: str, query_type: str) -> None:
"""Async search operation."""
results_list = self.query_one("#results-list", ListView)
try:
import cognee
from cognee.modules.search.types import SearchType
from cognee.infrastructure.databases.exceptions.exceptions import (
EntityNotFoundError,
)
# Convert string to SearchType enum
search_type = SearchType[query_type]
@ -150,29 +159,43 @@ class SearchTUIScreen(BaseTUIScreen):
top_k=10,
)
# Update results display
results_text = self.query_one("#results-text", Static)
# Clear loading message
results_list.clear()
if not results:
results_text.update("No results found for your query.")
results_list.mount(
ListItem(Label("No results found for your query."))
)
else:
# Format results based on type
if query_type in ["GRAPH_COMPLETION", "RAG_COMPLETION"]:
formatted = "\n\n".join([f"📝 {result}" for result in results])
for result in results:
results_list.mount(ListItem(Label(f"📝 {result}")))
elif query_type == "CHUNKS":
formatted = "\n\n".join(
[f"📄 Chunk {i + 1}:\n{result}" for i, result in enumerate(results)]
)
for i, result in enumerate(results):
results_list.mount(
ListItem(Label(f"📄 Chunk {i + 1}:\n{result}"))
)
else:
formatted = "\n\n".join([f"{result}" for result in results])
for result in results:
results_list.mount(ListItem(Label(f"{result}")))
self.notify(f"✓ Found {len(results)} result(s)", severity="information")
results_text.update(formatted)
self.notify(f"✓ Found {len(results)} result(s)", severity="information")
except EntityNotFoundError:
results_list.clear()
results_list.mount(
ListItem(
Label(
"No data found. Please run 'cognee cognify' to process your data first."
)
)
)
self.notify("Knowledge graph is empty", severity="warning")
except Exception as e:
results_text = self.query_one("#results-text", Static)
results_text.update(f"❌ Error: {str(e)}")
results_list.clear()
results_list.mount(ListItem(Label(f"❌ Error: {str(e)}")))
self.notify(f"Search failed: {str(e)}", severity="error")
finally:

View file

@ -0,0 +1,189 @@
import sys
import asyncio
from unittest.mock import MagicMock, AsyncMock
# 1. Setup Mocks for cognee dependencies
mock_cognee = MagicMock()
sys.modules["cognee"] = mock_cognee
sys.modules["cognee.version"] = MagicMock()
sys.modules["cognee.cli.tui.common_styles"] = MagicMock()
sys.modules["cognee.cli.tui.common_styles"].COMMON_STYLES = ""
# Define the exception we want to catch
class EntityNotFoundError(Exception):
pass
# Setup the exception in the mocked module structure
sys.modules["cognee.infrastructure"] = MagicMock()
sys.modules["cognee.infrastructure.databases"] = MagicMock()
sys.modules["cognee.infrastructure.databases.exceptions"] = MagicMock()
exceptions_mock = MagicMock()
exceptions_mock.EntityNotFoundError = EntityNotFoundError
sys.modules["cognee.infrastructure.databases.exceptions.exceptions"] = exceptions_mock
# Setup search types
sys.modules["cognee.modules"] = MagicMock()
sys.modules["cognee.modules.search"] = MagicMock()
types_mock = MagicMock()
# Mock SearchType to support item access
class MockSearchTypeMeta(type):
def __getitem__(cls, key):
return key
class MockSearchType(metaclass=MockSearchTypeMeta):
pass
types_mock.SearchType = MockSearchType
sys.modules["cognee.modules.search.types"] = types_mock
import importlib.util
import os
# ... existing mocks ...
# Ensure we have deep mocks for structure
sys.modules["cognee.cli"] = MagicMock()
sys.modules["cognee.cli.tui"] = MagicMock()
# Mock BaseTUIScreen specifically
base_screen_mock = MagicMock()
class MockBaseScreen:
CSS = ""
def __init__(self):
pass
def compose_header(self): yield from ()
def compose_footer(self): yield from ()
base_screen_mock.BaseTUIScreen = MockBaseScreen
# Crucial: Mock the specific module path search_screen tries to import from
sys.modules["cognee.cli.tui.base_screen"] = base_screen_mock
# Also mock textual.binding which is imported at top level
sys.modules["textual.binding"] = MagicMock()
# Now load the file directly
module_path = os.path.join(os.getcwd(), "cognee/cli/tui/search_screen.py")
spec = importlib.util.spec_from_file_location("search_screen_mod", module_path)
search_screen_mod = importlib.util.module_from_spec(spec)
# Before executing, ensure imports in that file will resolve to our mocks
# The file does: from textual...
# Real Textual is explicitly NOT mocked in sys.modules so it loads real textual (if installed)
# But we mocked textual.binding above?
# Actually, let's NOT mock textual.binding if we can avoid it, or mock it if it's simple.
# Real code: from textual.binding import Binding.
# If textual is installed, we should leverage it. If not, mock it.
try:
import textual
except ImportError:
# If textual not installed/available in this step runner, we must mock it all
sys.modules["textual"] = MagicMock()
sys.modules["textual.app"] = MagicMock()
sys.modules["textual.widgets"] = MagicMock()
sys.modules["textual.containers"] = MagicMock()
sys.modules["textual.binding"] = MagicMock()
# We need to provide Widget classes that search_screen inherits/uses
# It imports: Input, Label, Static, Select, ListView, ListItem
# It uses: ComposeResult (type)
# Simple Mock widgets
class MockWidget:
def __init__(self, *args, **kwargs): pass
def focus(self): pass
def mount(self, *args): pass
def clear(self): pass
def update(self, *args): pass
sys.modules["textual.widgets"].Input = MockWidget
sys.modules["textual.widgets"].Label = MockWidget
sys.modules["textual.widgets"].Static = MockWidget
sys.modules["textual.widgets"].Select = MockWidget
sys.modules["textual.widgets"].ListView = MockWidget
sys.modules["textual.widgets"].ListItem = MockWidget
sys.modules["textual.containers"].Container = MagicMock()
sys.modules["textual.containers"].Vertical = MagicMock()
sys.modules["textual.app"].ComposeResult = MagicMock()
# Execute the module
spec.loader.exec_module(search_screen_mod)
SearchTUIScreen = search_screen_mod.SearchTUIScreen
async def test_empty_graph_handling():
print("Testing Empty Graph (EntityNotFoundError) Handling...")
# Instantiate screen
screen = SearchTUIScreen()
# Mock query_one to return our list view mock
results_list_mock = MagicMock()
def query_one_side_effect(selector, type_cls=None):
if "ListView" in str(type_cls) or "list" in str(selector):
return results_list_mock
return MagicMock() # For other queries like Static
screen.query_one = MagicMock(side_effect=query_one_side_effect)
screen.notify = MagicMock()
# Configure cognee.search to raise EntityNotFoundError
mock_cognee.search = AsyncMock(side_effect=EntityNotFoundError("Graph is empty"))
# Run the method
# Note: query_type needs to be a valid key in our MockSearchType or just a string if we mocked it right
await screen._async_search("test query", "GRAPH_COMPLETION")
# Verification
# 1. usage of clear()
results_list_mock.clear.assert_called()
# 2. usage of mount() with correct message
assert results_list_mock.mount.called
# Check that notify was called with warning
# We allow flexible matching for the exact message but check 'warning' severity
args = screen.notify.call_args
if args:
assert args[1].get('severity') == 'warning' or 'Knowledge graph is empty' in args[0][0]
else:
print("FAIL: notify was not called")
print("SUCCESS: EntityNotFoundError was caught and handled correctly.")
async def test_generic_error_handling():
print("\nTesting Generic Error Handling...")
screen = SearchTUIScreen()
results_list_mock = MagicMock()
screen.query_one = MagicMock(return_value=results_list_mock)
screen.notify = MagicMock()
# Configure generic error
mock_cognee.search = AsyncMock(side_effect=Exception("Something bad happened"))
await screen._async_search("test query", "GRAPH_COMPLETION")
screen.notify.assert_called()
args = screen.notify.call_args
# Check for error severity or message
assert args[1].get('severity') == 'error' or "Search failed" in args[0][0]
print("SUCCESS: Generic Exception was caught and handled correctly.")
async def test_success_path():
print("\nTesting Success Path...")
screen = SearchTUIScreen()
results_list_mock = MagicMock()
screen.query_one = MagicMock(return_value=results_list_mock)
screen.notify = MagicMock()
# Configure success
mock_cognee.search = AsyncMock(return_value=["Result 1", "Result 2"])
await screen._async_search("test query", "GRAPH_COMPLETION")
assert results_list_mock.clear.called
assert results_list_mock.mount.called
# We expect 2 mount calls for results
assert results_list_mock.mount.call_count == 2
print("SUCCESS: Search results were displayed.")
if __name__ == "__main__":
asyncio.run(test_empty_graph_handling())
asyncio.run(test_generic_error_handling())
asyncio.run(test_success_path())