tests: Resolve failing search tests
This commit is contained in:
parent
52b25882b3
commit
94bc0ef47f
5 changed files with 39 additions and 28 deletions
|
|
@ -45,15 +45,13 @@ async def relational_db_migration():
|
|||
await migrate_relational_database(graph_engine, schema=schema)
|
||||
|
||||
# 1. Search the graph
|
||||
search_results: List[SearchResult] = await cognee.search(
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC"
|
||||
) # type: ignore
|
||||
)
|
||||
print("Search results:", search_results)
|
||||
|
||||
# 2. Assert that the search results contain "AC/DC"
|
||||
assert any("AC/DC" in r.search_result for r in search_results), (
|
||||
"AC/DC not found in search results!"
|
||||
)
|
||||
assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!"
|
||||
|
||||
migration_db_provider = migration_engine.engine.dialect.name
|
||||
if migration_db_provider == "postgresql":
|
||||
|
|
|
|||
|
|
@ -144,13 +144,16 @@ async def main():
|
|||
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
|
||||
("GRAPH_SUMMARY_COMPLETION", completion_sum),
|
||||
]:
|
||||
for search_result in search_results:
|
||||
completion = search_result.search_result
|
||||
assert isinstance(completion, str), f"{name}: should return a string"
|
||||
assert completion.strip(), f"{name}: string should not be empty"
|
||||
assert "netherlands" in completion.lower(), (
|
||||
f"{name}: expected 'netherlands' in result, got: {completion!r}"
|
||||
)
|
||||
assert isinstance(search_results, list), f"{name}: should return a list"
|
||||
assert len(search_results) == 1, (
|
||||
f"{name}: expected single-element list, got {len(search_results)}"
|
||||
)
|
||||
text = search_results[0]
|
||||
assert isinstance(text, str), f"{name}: element should be a string"
|
||||
assert text.strip(), f"{name}: string should not be empty"
|
||||
assert "netherlands" in text.lower(), (
|
||||
f"{name}: expected 'netherlands' in result, got: {text!r}"
|
||||
)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
graph = await graph_engine.get_graph_data()
|
||||
|
|
|
|||
|
|
@ -59,8 +59,10 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Canva?")
|
||||
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_extension_context_complex(self):
|
||||
|
|
@ -140,8 +142,10 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_completion_extension_context_on_empty_graph(self):
|
||||
|
|
@ -171,5 +175,7 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -55,8 +55,10 @@ class TestGraphCompletionCoTRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Canva?")
|
||||
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_cot_context_complex(self):
|
||||
|
|
@ -133,8 +135,10 @@ class TestGraphCompletionCoTRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_completion_cot_context_on_empty_graph(self):
|
||||
|
|
@ -164,5 +168,7 @@ class TestGraphCompletionCoTRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ class TestInsightsRetriever:
|
|||
|
||||
context = await retriever.get_context("Mike")
|
||||
|
||||
assert context[0].node1.attributes["name"] == "Mike Broski", "Failed to get Mike Broski"
|
||||
assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insights_context_complex(self):
|
||||
|
|
@ -222,9 +222,7 @@ class TestInsightsRetriever:
|
|||
|
||||
context = await retriever.get_context("Christina")
|
||||
|
||||
assert context[0].node1.attributes["name"] == "Christina Mayer", (
|
||||
"Failed to get Christina Mayer"
|
||||
)
|
||||
assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insights_context_on_empty_graph(self):
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue