improving search screen
This commit is contained in:
parent
c7ee3c37da
commit
7d3586e1b2
2 changed files with 235 additions and 23 deletions
|
|
@ -1,7 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from textual.app import ComposeResult
|
from textual.app import ComposeResult
|
||||||
from textual.widgets import Input, Label, Static, Select
|
from textual.widgets import Input, Label, Static, Select, ListView, ListItem
|
||||||
from textual.containers import Container, Vertical, ScrollableContainer
|
from textual.containers import Container, Vertical
|
||||||
from textual.binding import Binding
|
from textual.binding import Binding
|
||||||
from cognee.cli.tui.base_screen import BaseTUIScreen
|
from cognee.cli.tui.base_screen import BaseTUIScreen
|
||||||
|
|
||||||
|
|
@ -46,7 +46,7 @@ class SearchTUIScreen(BaseTUIScreen):
|
||||||
margin-bottom: 1;
|
margin-bottom: 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
#results-content {
|
#results-list {
|
||||||
height: 1fr;
|
height: 1fr;
|
||||||
overflow-y: auto;
|
overflow-y: auto;
|
||||||
}
|
}
|
||||||
|
|
@ -78,19 +78,22 @@ class SearchTUIScreen(BaseTUIScreen):
|
||||||
)
|
)
|
||||||
with Container(id="results-container"):
|
with Container(id="results-container"):
|
||||||
yield Static("Results", id="results-title")
|
yield Static("Results", id="results-title")
|
||||||
with ScrollableContainer(id="results-content"):
|
yield ListView(id="results-list")
|
||||||
yield Static(
|
|
||||||
"Enter a query and click Search to see results.", id="results-text"
|
|
||||||
)
|
|
||||||
|
|
||||||
def compose_footer(self) -> ComposeResult:
|
def compose_footer(self) -> ComposeResult:
|
||||||
yield Static("Ctrl+S: Search • Esc: Back • q: Quit", classes="tui-footer")
|
yield Static("Ctrl+S: Search • Esc: Back • q: Quit", classes="tui-footer")
|
||||||
|
|
||||||
def on_mount(self) -> None:
|
def on_mount(self) -> None:
|
||||||
"""Focus the query input on mount."""
|
"""Focus the query input on mount and show initial help text."""
|
||||||
query_input = self.query_one("#query-input", Input)
|
query_input = self.query_one("#query-input", Input)
|
||||||
query_input.focus()
|
query_input.focus()
|
||||||
|
|
||||||
|
# Add initial help text to list
|
||||||
|
results_list = self.query_one("#results-list", ListView)
|
||||||
|
results_list.mount(
|
||||||
|
ListItem(Label("Enter a query and click Search to see results."))
|
||||||
|
)
|
||||||
|
|
||||||
def action_back(self) -> None:
|
def action_back(self) -> None:
|
||||||
"""Go back to home screen."""
|
"""Go back to home screen."""
|
||||||
self.app.pop_screen()
|
self.app.pop_screen()
|
||||||
|
|
@ -128,17 +131,23 @@ class SearchTUIScreen(BaseTUIScreen):
|
||||||
self.notify(f"Searching for: {query_text}", severity="information")
|
self.notify(f"Searching for: {query_text}", severity="information")
|
||||||
|
|
||||||
# Update results to show loading
|
# Update results to show loading
|
||||||
results_text = self.query_one("#results-text", Static)
|
results_list = self.query_one("#results-list", ListView)
|
||||||
results_text.update("🔍 Searching...")
|
results_list.clear()
|
||||||
|
results_list.mount(ListItem(Label("🔍 Searching...")))
|
||||||
|
|
||||||
# Run async search
|
# Run async search
|
||||||
asyncio.create_task(self._async_search(query_text, query_type))
|
asyncio.create_task(self._async_search(query_text, query_type))
|
||||||
|
|
||||||
async def _async_search(self, query_text: str, query_type: str) -> None:
|
async def _async_search(self, query_text: str, query_type: str) -> None:
|
||||||
"""Async search operation."""
|
"""Async search operation."""
|
||||||
|
results_list = self.query_one("#results-list", ListView)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cognee
|
import cognee
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchType
|
||||||
|
from cognee.infrastructure.databases.exceptions.exceptions import (
|
||||||
|
EntityNotFoundError,
|
||||||
|
)
|
||||||
|
|
||||||
# Convert string to SearchType enum
|
# Convert string to SearchType enum
|
||||||
search_type = SearchType[query_type]
|
search_type = SearchType[query_type]
|
||||||
|
|
@ -150,29 +159,43 @@ class SearchTUIScreen(BaseTUIScreen):
|
||||||
top_k=10,
|
top_k=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update results display
|
# Clear loading message
|
||||||
results_text = self.query_one("#results-text", Static)
|
results_list.clear()
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
results_text.update("No results found for your query.")
|
results_list.mount(
|
||||||
|
ListItem(Label("No results found for your query."))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Format results based on type
|
# Format results based on type
|
||||||
if query_type in ["GRAPH_COMPLETION", "RAG_COMPLETION"]:
|
if query_type in ["GRAPH_COMPLETION", "RAG_COMPLETION"]:
|
||||||
formatted = "\n\n".join([f"📝 {result}" for result in results])
|
for result in results:
|
||||||
|
results_list.mount(ListItem(Label(f"📝 {result}")))
|
||||||
elif query_type == "CHUNKS":
|
elif query_type == "CHUNKS":
|
||||||
formatted = "\n\n".join(
|
for i, result in enumerate(results):
|
||||||
[f"📄 Chunk {i + 1}:\n{result}" for i, result in enumerate(results)]
|
results_list.mount(
|
||||||
)
|
ListItem(Label(f"📄 Chunk {i + 1}:\n{result}"))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
formatted = "\n\n".join([f"• {result}" for result in results])
|
for result in results:
|
||||||
|
results_list.mount(ListItem(Label(f"• {result}")))
|
||||||
|
|
||||||
|
self.notify(f"✓ Found {len(results)} result(s)", severity="information")
|
||||||
|
|
||||||
results_text.update(formatted)
|
except EntityNotFoundError:
|
||||||
|
results_list.clear()
|
||||||
self.notify(f"✓ Found {len(results)} result(s)", severity="information")
|
results_list.mount(
|
||||||
|
ListItem(
|
||||||
|
Label(
|
||||||
|
"No data found. Please run 'cognee cognify' to process your data first."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.notify("Knowledge graph is empty", severity="warning")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
results_text = self.query_one("#results-text", Static)
|
results_list.clear()
|
||||||
results_text.update(f"❌ Error: {str(e)}")
|
results_list.mount(ListItem(Label(f"❌ Error: {str(e)}")))
|
||||||
self.notify(f"Search failed: {str(e)}", severity="error")
|
self.notify(f"Search failed: {str(e)}", severity="error")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
|
|
||||||
189
tests/reproduce_tui_search_logic.py
Normal file
189
tests/reproduce_tui_search_logic.py
Normal file
|
|
@ -0,0 +1,189 @@
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock, AsyncMock
|
||||||
|
|
||||||
|
# 1. Setup Mocks for cognee dependencies
|
||||||
|
mock_cognee = MagicMock()
|
||||||
|
sys.modules["cognee"] = mock_cognee
|
||||||
|
sys.modules["cognee.version"] = MagicMock()
|
||||||
|
sys.modules["cognee.cli.tui.common_styles"] = MagicMock()
|
||||||
|
sys.modules["cognee.cli.tui.common_styles"].COMMON_STYLES = ""
|
||||||
|
|
||||||
|
# Define the exception we want to catch
|
||||||
|
class EntityNotFoundError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Setup the exception in the mocked module structure
|
||||||
|
sys.modules["cognee.infrastructure"] = MagicMock()
|
||||||
|
sys.modules["cognee.infrastructure.databases"] = MagicMock()
|
||||||
|
sys.modules["cognee.infrastructure.databases.exceptions"] = MagicMock()
|
||||||
|
exceptions_mock = MagicMock()
|
||||||
|
exceptions_mock.EntityNotFoundError = EntityNotFoundError
|
||||||
|
sys.modules["cognee.infrastructure.databases.exceptions.exceptions"] = exceptions_mock
|
||||||
|
|
||||||
|
# Setup search types
|
||||||
|
sys.modules["cognee.modules"] = MagicMock()
|
||||||
|
sys.modules["cognee.modules.search"] = MagicMock()
|
||||||
|
types_mock = MagicMock()
|
||||||
|
|
||||||
|
# Mock SearchType to support item access
|
||||||
|
class MockSearchTypeMeta(type):
|
||||||
|
def __getitem__(cls, key):
|
||||||
|
return key
|
||||||
|
class MockSearchType(metaclass=MockSearchTypeMeta):
|
||||||
|
pass
|
||||||
|
types_mock.SearchType = MockSearchType
|
||||||
|
sys.modules["cognee.modules.search.types"] = types_mock
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
|
||||||
|
# ... existing mocks ...
|
||||||
|
# Ensure we have deep mocks for structure
|
||||||
|
sys.modules["cognee.cli"] = MagicMock()
|
||||||
|
sys.modules["cognee.cli.tui"] = MagicMock()
|
||||||
|
|
||||||
|
# Mock BaseTUIScreen specifically
|
||||||
|
base_screen_mock = MagicMock()
|
||||||
|
class MockBaseScreen:
|
||||||
|
CSS = ""
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
def compose_header(self): yield from ()
|
||||||
|
def compose_footer(self): yield from ()
|
||||||
|
base_screen_mock.BaseTUIScreen = MockBaseScreen
|
||||||
|
# Crucial: Mock the specific module path search_screen tries to import from
|
||||||
|
sys.modules["cognee.cli.tui.base_screen"] = base_screen_mock
|
||||||
|
|
||||||
|
# Also mock textual.binding which is imported at top level
|
||||||
|
sys.modules["textual.binding"] = MagicMock()
|
||||||
|
|
||||||
|
# Now load the file directly
|
||||||
|
module_path = os.path.join(os.getcwd(), "cognee/cli/tui/search_screen.py")
|
||||||
|
spec = importlib.util.spec_from_file_location("search_screen_mod", module_path)
|
||||||
|
search_screen_mod = importlib.util.module_from_spec(spec)
|
||||||
|
|
||||||
|
# Before executing, ensure imports in that file will resolve to our mocks
|
||||||
|
# The file does: from textual...
|
||||||
|
# Real Textual is explicitly NOT mocked in sys.modules so it loads real textual (if installed)
|
||||||
|
# But we mocked textual.binding above?
|
||||||
|
# Actually, let's NOT mock textual.binding if we can avoid it, or mock it if it's simple.
|
||||||
|
# Real code: from textual.binding import Binding.
|
||||||
|
# If textual is installed, we should leverage it. If not, mock it.
|
||||||
|
try:
|
||||||
|
import textual
|
||||||
|
except ImportError:
|
||||||
|
# If textual not installed/available in this step runner, we must mock it all
|
||||||
|
sys.modules["textual"] = MagicMock()
|
||||||
|
sys.modules["textual.app"] = MagicMock()
|
||||||
|
sys.modules["textual.widgets"] = MagicMock()
|
||||||
|
sys.modules["textual.containers"] = MagicMock()
|
||||||
|
sys.modules["textual.binding"] = MagicMock()
|
||||||
|
# We need to provide Widget classes that search_screen inherits/uses
|
||||||
|
# It imports: Input, Label, Static, Select, ListView, ListItem
|
||||||
|
# It uses: ComposeResult (type)
|
||||||
|
|
||||||
|
# Simple Mock widgets
|
||||||
|
class MockWidget:
|
||||||
|
def __init__(self, *args, **kwargs): pass
|
||||||
|
def focus(self): pass
|
||||||
|
def mount(self, *args): pass
|
||||||
|
def clear(self): pass
|
||||||
|
def update(self, *args): pass
|
||||||
|
|
||||||
|
sys.modules["textual.widgets"].Input = MockWidget
|
||||||
|
sys.modules["textual.widgets"].Label = MockWidget
|
||||||
|
sys.modules["textual.widgets"].Static = MockWidget
|
||||||
|
sys.modules["textual.widgets"].Select = MockWidget
|
||||||
|
sys.modules["textual.widgets"].ListView = MockWidget
|
||||||
|
sys.modules["textual.widgets"].ListItem = MockWidget
|
||||||
|
|
||||||
|
sys.modules["textual.containers"].Container = MagicMock()
|
||||||
|
sys.modules["textual.containers"].Vertical = MagicMock()
|
||||||
|
sys.modules["textual.app"].ComposeResult = MagicMock()
|
||||||
|
|
||||||
|
# Execute the module
|
||||||
|
spec.loader.exec_module(search_screen_mod)
|
||||||
|
SearchTUIScreen = search_screen_mod.SearchTUIScreen
|
||||||
|
|
||||||
|
async def test_empty_graph_handling():
|
||||||
|
print("Testing Empty Graph (EntityNotFoundError) Handling...")
|
||||||
|
|
||||||
|
# Instantiate screen
|
||||||
|
screen = SearchTUIScreen()
|
||||||
|
|
||||||
|
# Mock query_one to return our list view mock
|
||||||
|
results_list_mock = MagicMock()
|
||||||
|
|
||||||
|
def query_one_side_effect(selector, type_cls=None):
|
||||||
|
if "ListView" in str(type_cls) or "list" in str(selector):
|
||||||
|
return results_list_mock
|
||||||
|
return MagicMock() # For other queries like Static
|
||||||
|
|
||||||
|
screen.query_one = MagicMock(side_effect=query_one_side_effect)
|
||||||
|
screen.notify = MagicMock()
|
||||||
|
|
||||||
|
# Configure cognee.search to raise EntityNotFoundError
|
||||||
|
mock_cognee.search = AsyncMock(side_effect=EntityNotFoundError("Graph is empty"))
|
||||||
|
|
||||||
|
# Run the method
|
||||||
|
# Note: query_type needs to be a valid key in our MockSearchType or just a string if we mocked it right
|
||||||
|
await screen._async_search("test query", "GRAPH_COMPLETION")
|
||||||
|
|
||||||
|
# Verification
|
||||||
|
# 1. usage of clear()
|
||||||
|
results_list_mock.clear.assert_called()
|
||||||
|
|
||||||
|
# 2. usage of mount() with correct message
|
||||||
|
assert results_list_mock.mount.called
|
||||||
|
|
||||||
|
# Check that notify was called with warning
|
||||||
|
# We allow flexible matching for the exact message but check 'warning' severity
|
||||||
|
args = screen.notify.call_args
|
||||||
|
if args:
|
||||||
|
assert args[1].get('severity') == 'warning' or 'Knowledge graph is empty' in args[0][0]
|
||||||
|
else:
|
||||||
|
print("FAIL: notify was not called")
|
||||||
|
|
||||||
|
print("SUCCESS: EntityNotFoundError was caught and handled correctly.")
|
||||||
|
|
||||||
|
async def test_generic_error_handling():
|
||||||
|
print("\nTesting Generic Error Handling...")
|
||||||
|
screen = SearchTUIScreen()
|
||||||
|
results_list_mock = MagicMock()
|
||||||
|
screen.query_one = MagicMock(return_value=results_list_mock)
|
||||||
|
screen.notify = MagicMock()
|
||||||
|
|
||||||
|
# Configure generic error
|
||||||
|
mock_cognee.search = AsyncMock(side_effect=Exception("Something bad happened"))
|
||||||
|
|
||||||
|
await screen._async_search("test query", "GRAPH_COMPLETION")
|
||||||
|
|
||||||
|
screen.notify.assert_called()
|
||||||
|
args = screen.notify.call_args
|
||||||
|
# Check for error severity or message
|
||||||
|
assert args[1].get('severity') == 'error' or "Search failed" in args[0][0]
|
||||||
|
print("SUCCESS: Generic Exception was caught and handled correctly.")
|
||||||
|
|
||||||
|
async def test_success_path():
|
||||||
|
print("\nTesting Success Path...")
|
||||||
|
screen = SearchTUIScreen()
|
||||||
|
results_list_mock = MagicMock()
|
||||||
|
screen.query_one = MagicMock(return_value=results_list_mock)
|
||||||
|
screen.notify = MagicMock()
|
||||||
|
|
||||||
|
# Configure success
|
||||||
|
mock_cognee.search = AsyncMock(return_value=["Result 1", "Result 2"])
|
||||||
|
|
||||||
|
await screen._async_search("test query", "GRAPH_COMPLETION")
|
||||||
|
|
||||||
|
assert results_list_mock.clear.called
|
||||||
|
assert results_list_mock.mount.called
|
||||||
|
# We expect 2 mount calls for results
|
||||||
|
assert results_list_mock.mount.call_count == 2
|
||||||
|
print("SUCCESS: Search results were displayed.")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_empty_graph_handling())
|
||||||
|
asyncio.run(test_generic_error_handling())
|
||||||
|
asyncio.run(test_success_path())
|
||||||
Loading…
Add table
Reference in a new issue