Mcp test (#980)
<!-- .github/pull_request_template.md --> ## Description Add test of MCP functionality and starting of MCP server, fix some MCP and LanceDB issues ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
parent
21a4217301
commit
456f3b58c0
5 changed files with 588 additions and 1 deletions
52
.github/workflows/test_mcp.yml
vendored
Normal file
52
.github/workflows/test_mcp.yml
vendored
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
name: test | mcp
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
|
||||
jobs:
|
||||
test-mcp:
|
||||
name: Run MCP Test
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Install UV
|
||||
shell: bash
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install uv
|
||||
|
||||
# This will install all dependencies along with Cognee version deployed on PIP
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
working-directory: cognee-mcp
|
||||
run: uv sync
|
||||
|
||||
# NEW: swap in current local cognee branch version
|
||||
- name: Override with cognee branch checkout
|
||||
working-directory: cognee-mcp
|
||||
run: |
|
||||
# Remove Cognee wheel that came from PyPI
|
||||
uv pip uninstall cognee
|
||||
# Install of the freshly-checked-out Cognee branch
|
||||
uv pip install --no-deps --force-reinstall -e ../
|
||||
|
||||
- name: Run MCP test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
working-directory: cognee-mcp
|
||||
run: uv run --no-sync python ./src/test_client.py
|
||||
8
.github/workflows/test_suites.yml
vendored
8
.github/workflows/test_suites.yml
vendored
|
|
@ -77,6 +77,12 @@ jobs:
|
|||
uses: ./.github/workflows/examples_tests.yml
|
||||
secrets: inherit
|
||||
|
||||
mcp-test:
|
||||
name: Example Tests
|
||||
needs: [ basic-tests, e2e-tests ]
|
||||
uses: ./.github/workflows/test_mcp.yml
|
||||
secrets: inherit
|
||||
|
||||
db-examples-tests:
|
||||
name: DB Examples Tests
|
||||
needs: [vector-db-tests, graph-db-tests, relational-db-migration-tests]
|
||||
|
|
@ -102,6 +108,7 @@ jobs:
|
|||
vector-db-tests,
|
||||
example-tests,
|
||||
gemini-tests,
|
||||
mcp-test,
|
||||
relational-db-migration-tests,
|
||||
docker-compose-test,
|
||||
docker-ci-test,
|
||||
|
|
@ -120,6 +127,7 @@ jobs:
|
|||
vector-db-tests,
|
||||
example-tests,
|
||||
db-examples-tests,
|
||||
mcp-test,
|
||||
gemini-tests,
|
||||
ollama-tests,
|
||||
relational-db-migration-tests,
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ async def cognee_add_developer_rules(
|
|||
with redirect_stdout(sys.stderr):
|
||||
logger.info(f"Starting cognify for: {file_path}")
|
||||
try:
|
||||
await cognee.add(file_path, nodeset="developer_rules")
|
||||
await cognee.add(file_path, node_set=["developer_rules"])
|
||||
model = KnowledgeGraph
|
||||
if graph_model_file and graph_model_name:
|
||||
model = load_class(graph_model_file, graph_model_name)
|
||||
|
|
@ -82,6 +82,7 @@ async def cognee_add_developer_rules(
|
|||
logger.info(f"Cognify finished for: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Cognify failed for {file_path}: {str(e)}")
|
||||
raise ValueError(f"Failed to cognify: {str(e)}")
|
||||
|
||||
tasks = []
|
||||
for rel_path in developer_rule_paths:
|
||||
|
|
|
|||
522
cognee-mcp/src/test_client.py
Executable file
522
cognee-mcp/src/test_client.py
Executable file
|
|
@ -0,0 +1,522 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test client for Cognee MCP Server functionality.
|
||||
|
||||
This script tests all the tools and functions available in the Cognee MCP server,
|
||||
including cognify, codify, search, prune, status checks, and utility functions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from cognee.shared.logging_utils import setup_logging
|
||||
|
||||
|
||||
from cognee.modules.pipelines.models.PipelineRun import PipelineRunStatus
|
||||
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
|
||||
from src.server import (
|
||||
cognify,
|
||||
codify,
|
||||
search,
|
||||
prune,
|
||||
cognify_status,
|
||||
codify_status,
|
||||
cognee_add_developer_rules,
|
||||
node_to_string,
|
||||
retrieved_edges_to_string,
|
||||
load_class,
|
||||
)
|
||||
|
||||
# Import MCP client functionality for server testing
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
|
||||
# Set timeout for cognify/codify to complete in
|
||||
TIMEOUT = 5 * 60 # 5 min in seconds
|
||||
|
||||
|
||||
class CogneeTestClient:
|
||||
"""Test client for Cognee MCP Server functionality."""
|
||||
|
||||
def __init__(self):
|
||||
self.test_results = {}
|
||||
self.temp_files = []
|
||||
|
||||
async def setup(self):
|
||||
"""Setup test environment."""
|
||||
print("🔧 Setting up test environment...")
|
||||
|
||||
# Create temporary test files
|
||||
self.test_data_dir = tempfile.mkdtemp(prefix="cognee_test_")
|
||||
|
||||
# Create a test text file
|
||||
self.test_text_file = os.path.join(self.test_data_dir, "test.txt")
|
||||
with open(self.test_text_file, "w") as f:
|
||||
f.write(
|
||||
"This is a test document for Cognee testing. It contains information about AI and knowledge graphs."
|
||||
)
|
||||
|
||||
# Create a test code repository structure
|
||||
self.test_repo_dir = os.path.join(self.test_data_dir, "test_repo")
|
||||
os.makedirs(self.test_repo_dir)
|
||||
|
||||
# Create test Python files
|
||||
test_py_file = os.path.join(self.test_repo_dir, "main.py")
|
||||
with open(test_py_file, "w") as f:
|
||||
f.write("""
|
||||
def hello_world():
|
||||
'''A simple hello world function.'''
|
||||
return "Hello, World!"
|
||||
|
||||
class TestClass:
|
||||
'''A test class for demonstration.'''
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def greet(self):
|
||||
return f"Hello, {self.name}!"
|
||||
""")
|
||||
|
||||
# Create a test configuration file
|
||||
config_file = os.path.join(self.test_repo_dir, "config.py")
|
||||
with open(config_file, "w") as f:
|
||||
f.write("""
|
||||
# Configuration settings
|
||||
DATABASE_URL = "sqlite:///test.db"
|
||||
DEBUG = True
|
||||
""")
|
||||
|
||||
# Create test developer rules files
|
||||
cursorrules_file = os.path.join(self.test_data_dir, ".cursorrules")
|
||||
with open(cursorrules_file, "w") as f:
|
||||
f.write("# Test cursor rules\nUse Python best practices.")
|
||||
|
||||
self.temp_files.extend([self.test_text_file, test_py_file, config_file, cursorrules_file])
|
||||
print(f"✅ Test environment created at: {self.test_data_dir}")
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up test environment."""
|
||||
print("🧹 Cleaning up test environment...")
|
||||
import shutil
|
||||
|
||||
if os.path.exists(self.test_data_dir):
|
||||
shutil.rmtree(self.test_data_dir)
|
||||
print("✅ Cleanup completed")
|
||||
|
||||
@asynccontextmanager
|
||||
async def mcp_server_session(self):
|
||||
"""Context manager to start and manage MCP server session."""
|
||||
# Get the path to the server script
|
||||
server_script = os.path.join(os.path.dirname(__file__), "server.py")
|
||||
|
||||
# Start the server process
|
||||
server_params = StdioServerParameters(
|
||||
command="python",
|
||||
args=[server_script, "--transport", "stdio"],
|
||||
env=None,
|
||||
)
|
||||
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the session
|
||||
await session.initialize()
|
||||
yield session
|
||||
|
||||
async def test_mcp_server_startup_and_tools(self):
|
||||
"""Test that the MCP server starts properly and returns tool results."""
|
||||
print("\n🧪 Testing MCP server startup and tool execution...")
|
||||
|
||||
try:
|
||||
async with self.mcp_server_session() as session:
|
||||
# Test 1: List available tools
|
||||
print(" 🔍 Testing tool discovery...")
|
||||
tools_result = await session.list_tools()
|
||||
|
||||
expected_tools = {
|
||||
"cognify",
|
||||
"codify",
|
||||
"search",
|
||||
"prune",
|
||||
"cognify_status",
|
||||
"codify_status",
|
||||
"cognee_add_developer_rules",
|
||||
}
|
||||
available_tools = {tool.name for tool in tools_result.tools}
|
||||
|
||||
if not expected_tools.issubset(available_tools):
|
||||
missing_tools = expected_tools - available_tools
|
||||
raise AssertionError(f"Missing expected tools: {missing_tools}")
|
||||
|
||||
print(
|
||||
f" ✅ Found {len(available_tools)} tools: {', '.join(sorted(available_tools))}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.test_results["mcp_server_integration"] = {
|
||||
"status": "FAIL",
|
||||
"error": str(e),
|
||||
"message": "MCP server integration test failed",
|
||||
}
|
||||
print(f"❌ MCP server integration test failed: {e}")
|
||||
|
||||
async def test_prune(self):
|
||||
"""Test the prune functionality."""
|
||||
print("\n🧪 Testing prune functionality...")
|
||||
try:
|
||||
result = await prune()
|
||||
self.test_results["prune"] = {
|
||||
"status": "PASS",
|
||||
"result": result,
|
||||
"message": "Prune executed successfully",
|
||||
}
|
||||
print("✅ Prune test passed")
|
||||
except Exception as e:
|
||||
self.test_results["prune"] = {
|
||||
"status": "FAIL",
|
||||
"error": str(e),
|
||||
"message": "Prune test failed",
|
||||
}
|
||||
print(f"❌ Prune test failed: {e}")
|
||||
raise e
|
||||
|
||||
async def test_cognify(self, test_text, test_name):
|
||||
"""Test the cognify functionality."""
|
||||
print("\n🧪 Testing cognify functionality...")
|
||||
try:
|
||||
# Test with simple text
|
||||
cognify_result = await cognify(test_text)
|
||||
|
||||
start = time.time() # mark the start
|
||||
while True:
|
||||
try:
|
||||
# Wait a moment
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Check if cognify processing is finished
|
||||
status_result = await cognify_status()
|
||||
if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_result[0].text:
|
||||
break
|
||||
elif time.time() - start > TIMEOUT:
|
||||
raise TimeoutError("Cognify did not complete in 5min")
|
||||
except DatabaseNotCreatedError:
|
||||
if time.time() - start > TIMEOUT:
|
||||
raise TimeoutError("Database was not created in 5min")
|
||||
|
||||
self.test_results[test_name] = {
|
||||
"status": "PASS",
|
||||
"result": cognify_result,
|
||||
"message": f"{test_name} executed successfully",
|
||||
}
|
||||
print(f"✅ {test_name} test passed")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results[test_name] = {
|
||||
"status": "FAIL",
|
||||
"error": str(e),
|
||||
"message": f"{test_name} test failed",
|
||||
}
|
||||
print(f"❌ {test_name} test failed: {e}")
|
||||
|
||||
async def test_codify(self):
|
||||
"""Test the codify functionality."""
|
||||
print("\n🧪 Testing codify functionality...")
|
||||
try:
|
||||
codify_result = await codify(self.test_repo_dir)
|
||||
|
||||
start = time.time() # mark the start
|
||||
while True:
|
||||
try:
|
||||
# Wait a moment
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Check if codify processing is finished
|
||||
status_result = await codify_status()
|
||||
if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_result[0].text:
|
||||
break
|
||||
elif time.time() - start > TIMEOUT:
|
||||
raise TimeoutError("Codify did not complete in 5min")
|
||||
except DatabaseNotCreatedError:
|
||||
if time.time() - start > TIMEOUT:
|
||||
raise TimeoutError("Database was not created in 5min")
|
||||
|
||||
self.test_results["codify"] = {
|
||||
"status": "PASS",
|
||||
"result": codify_result,
|
||||
"message": "Codify executed successfully",
|
||||
}
|
||||
print("✅ Codify test passed")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results["codify"] = {
|
||||
"status": "FAIL",
|
||||
"error": str(e),
|
||||
"message": "Codify test failed",
|
||||
}
|
||||
print(f"❌ Codify test failed: {e}")
|
||||
|
||||
async def test_cognee_add_developer_rules(self):
|
||||
"""Test the cognee_add_developer_rules functionality."""
|
||||
print("\n🧪 Testing cognee_add_developer_rules functionality...")
|
||||
try:
|
||||
result = await cognee_add_developer_rules(base_path=self.test_data_dir)
|
||||
|
||||
start = time.time() # mark the start
|
||||
while True:
|
||||
try:
|
||||
# Wait a moment
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Check if developer rule cognify processing is finished
|
||||
status_result = await cognify_status()
|
||||
if str(PipelineRunStatus.DATASET_PROCESSING_COMPLETED) in status_result[0].text:
|
||||
break
|
||||
elif time.time() - start > TIMEOUT:
|
||||
raise TimeoutError("Cognify of developer rules did not complete in 5min")
|
||||
except DatabaseNotCreatedError:
|
||||
if time.time() - start > TIMEOUT:
|
||||
raise TimeoutError("Database was not created in 5min")
|
||||
|
||||
self.test_results["cognee_add_developer_rules"] = {
|
||||
"status": "PASS",
|
||||
"result": result,
|
||||
"message": "Developer rules addition executed successfully",
|
||||
}
|
||||
print("✅ Developer rules test passed")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results["cognee_add_developer_rules"] = {
|
||||
"status": "FAIL",
|
||||
"error": str(e),
|
||||
"message": "Developer rules test failed",
|
||||
}
|
||||
print(f"❌ Developer rules test failed: {e}")
|
||||
|
||||
async def test_search_functionality(self):
|
||||
"""Test the search functionality with different search types."""
|
||||
print("\n🧪 Testing search functionality...")
|
||||
|
||||
search_query = "What is artificial intelligence?"
|
||||
|
||||
# Test if all search types will execute
|
||||
from cognee import SearchType
|
||||
|
||||
# Go through all Cognee search types
|
||||
for search_type in SearchType:
|
||||
# Don't test these search types
|
||||
if search_type in [SearchType.NATURAL_LANGUAGE, SearchType.CYPHER]:
|
||||
break
|
||||
try:
|
||||
result = await search(search_query, search_type.value)
|
||||
self.test_results[f"search_{search_type}"] = {
|
||||
"status": "PASS",
|
||||
"result": result,
|
||||
"message": f"Search with {search_type} successful",
|
||||
}
|
||||
print(f"✅ Search {search_type} test passed")
|
||||
except Exception as e:
|
||||
self.test_results[f"search_{search_type}"] = {
|
||||
"status": "FAIL",
|
||||
"error": str(e),
|
||||
"message": f"Search with {search_type} failed",
|
||||
}
|
||||
print(f"❌ Search {search_type} test failed: {e}")
|
||||
|
||||
def test_utility_functions(self):
|
||||
"""Test utility functions."""
|
||||
print("\n🧪 Testing utility functions...")
|
||||
|
||||
# Test node_to_string
|
||||
try:
|
||||
test_node = {"id": "test_id", "name": "test_name", "type": "test_type"}
|
||||
result = node_to_string(test_node)
|
||||
expected = 'Node(id: "test_id", name: "test_name")'
|
||||
|
||||
if result == expected:
|
||||
self.test_results["node_to_string"] = {
|
||||
"status": "PASS",
|
||||
"result": result,
|
||||
"message": "node_to_string function works correctly",
|
||||
}
|
||||
print("✅ node_to_string test passed")
|
||||
else:
|
||||
self.test_results["node_to_string"] = {
|
||||
"status": "FAIL",
|
||||
"result": result,
|
||||
"expected": expected,
|
||||
"message": "node_to_string function output mismatch",
|
||||
}
|
||||
print(f"❌ node_to_string test failed: expected {expected}, got {result}")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results["node_to_string"] = {
|
||||
"status": "FAIL",
|
||||
"error": str(e),
|
||||
"message": "node_to_string test failed",
|
||||
}
|
||||
print(f"❌ node_to_string test failed: {e}")
|
||||
|
||||
# Test retrieved_edges_to_string
|
||||
try:
|
||||
test_triplet = [
|
||||
(
|
||||
{"id": "node1", "name": "Node1"},
|
||||
{"relationship_name": "CONNECTS_TO"},
|
||||
{"id": "node2", "name": "Node2"},
|
||||
)
|
||||
]
|
||||
result = retrieved_edges_to_string(test_triplet)
|
||||
expected = (
|
||||
'Node(id: "node1", name: "Node1") CONNECTS_TO Node(id: "node2", name: "Node2")'
|
||||
)
|
||||
if result == expected:
|
||||
self.test_results["retrieved_edges_to_string"] = {
|
||||
"status": "PASS",
|
||||
"result": result,
|
||||
"message": "retrieved_edges_to_string function works correctly",
|
||||
}
|
||||
print("✅ retrieved_edges_to_string test passed")
|
||||
else:
|
||||
self.test_results["retrieved_edges_to_string"] = {
|
||||
"status": "FAIL",
|
||||
"result": result,
|
||||
"expected": expected,
|
||||
"message": "retrieved_edges_to_string function output mismatch",
|
||||
}
|
||||
print(
|
||||
f"❌ retrieved_edges_to_string test failed: expected {expected}, got {result}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.test_results["retrieved_edges_to_string"] = {
|
||||
"status": "FAIL",
|
||||
"error": str(e),
|
||||
"message": "retrieved_edges_to_string test failed",
|
||||
}
|
||||
print(f"❌ retrieved_edges_to_string test failed: {e}")
|
||||
|
||||
def test_load_class_function(self):
|
||||
"""Test load_class function."""
|
||||
print("\n🧪 Testing load_class function...")
|
||||
|
||||
try:
|
||||
# Create a temporary Python file with a test class
|
||||
test_module_file = os.path.join(self.test_data_dir, "test_model.py")
|
||||
with open(test_module_file, "w") as f:
|
||||
f.write("""
|
||||
class TestModel:
|
||||
def __init__(self):
|
||||
self.name = "TestModel"
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
||||
""")
|
||||
|
||||
# Test loading the class
|
||||
loaded_class = load_class(test_module_file, "TestModel")
|
||||
instance = loaded_class()
|
||||
|
||||
if hasattr(instance, "get_name") and instance.get_name() == "TestModel":
|
||||
self.test_results["load_class"] = {
|
||||
"status": "PASS",
|
||||
"message": "load_class function works correctly",
|
||||
}
|
||||
print("✅ load_class test passed")
|
||||
else:
|
||||
self.test_results["load_class"] = {
|
||||
"status": "FAIL",
|
||||
"message": "load_class function did not load class correctly",
|
||||
}
|
||||
print("❌ load_class test failed: class not loaded correctly")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results["load_class"] = {
|
||||
"status": "FAIL",
|
||||
"error": str(e),
|
||||
"message": "load_class test failed",
|
||||
}
|
||||
print(f"❌ load_class test failed: {e}")
|
||||
|
||||
async def run_all_tests(self):
|
||||
"""Run all tests."""
|
||||
print("🚀 Starting Cognee MCP Server Test Suite")
|
||||
print("=" * 50)
|
||||
|
||||
await self.setup()
|
||||
|
||||
# Test MCP server integration first
|
||||
await self.test_mcp_server_startup_and_tools()
|
||||
|
||||
# Run tests in logical order
|
||||
await self.test_prune() # Start with clean slate
|
||||
|
||||
# Test cognify twice to make sure updating a dataset with new docs is working as expected
|
||||
await self.test_cognify(
|
||||
test_text="Artificial Intelligence is transforming the world through machine learning and deep learning technologies.",
|
||||
test_name="Cognify1",
|
||||
)
|
||||
await self.test_cognify(
|
||||
test_text="Natural language processing (NLP) is an interdisciplinary subfield of computer science and information retrieval.",
|
||||
test_name="Cognify2",
|
||||
)
|
||||
|
||||
await self.test_codify()
|
||||
await self.test_cognee_add_developer_rules()
|
||||
|
||||
await self.test_search_functionality()
|
||||
|
||||
# Test utility functions (synchronous)
|
||||
self.test_utility_functions()
|
||||
self.test_load_class_function()
|
||||
|
||||
await self.cleanup()
|
||||
|
||||
# Print summary
|
||||
self.print_test_summary()
|
||||
|
||||
def print_test_summary(self):
|
||||
"""Print test results summary."""
|
||||
print("\n" + "=" * 50)
|
||||
print("📊 TEST RESULTS SUMMARY")
|
||||
print("=" * 50)
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for test_name, result in self.test_results.items():
|
||||
if result["status"] == "PASS":
|
||||
status_emoji = "✅"
|
||||
passed += 1
|
||||
else:
|
||||
status_emoji = "❌"
|
||||
failed += 1
|
||||
|
||||
print(f"{status_emoji} {test_name}: {result['status']}")
|
||||
|
||||
if result["status"] == "FAIL" and "error" in result:
|
||||
print(f" Error: {result['error']}")
|
||||
|
||||
print("\n" + "-" * 50)
|
||||
total_tests = passed + failed
|
||||
print(f"Total Tests: {total_tests}")
|
||||
print(f"Passed: {passed}")
|
||||
print(f"Failed: {failed}")
|
||||
print(f"Success Rate: {(passed / total_tests * 100):.1f}%")
|
||||
|
||||
assert failed == 0, "\n ⚠️ Number of tests didn't pass!"
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function to run the test suite."""
|
||||
client = CogneeTestClient()
|
||||
await client.run_all_tests()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from logging import ERROR
|
||||
|
||||
logger = setup_logging(log_level=ERROR)
|
||||
asyncio.run(main())
|
||||
|
|
@ -231,6 +231,10 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
if limit == 0:
|
||||
limit = await collection.count_rows()
|
||||
|
||||
# LanceDB search will break if limit is 0 so we must return
|
||||
if limit == 0:
|
||||
return []
|
||||
|
||||
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
|
||||
|
||||
result_values = list(results.to_dict("index").values())
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue