diff --git a/.github/workflows/cli_tests.yml b/.github/workflows/cli_tests.yml index 7086d341f..958d341ae 100644 --- a/.github/workflows/cli_tests.yml +++ b/.github/workflows/cli_tests.yml @@ -60,7 +60,7 @@ jobs: python-version: ${{ inputs.python-version }} - name: Run CLI Unit Tests - run: uv run pytest cognee/tests/unit/cli/ -v + run: uv run pytest cognee/tests/cli_tests/cli_unit_tests/ -v cli-integration-tests: name: CLI Integration Tests @@ -87,7 +87,7 @@ jobs: python-version: ${{ inputs.python-version }} - name: Run CLI Integration Tests - run: uv run pytest cognee/tests/integration/cli/ -v + run: uv run pytest cognee/tests/cli_tests/cli_integration_tests/ -v cli-functionality-tests: name: CLI Functionality Tests @@ -135,12 +135,12 @@ jobs: run: | # Test invalid command (should fail gracefully) ! uv run python -m cognee.cli._cognee invalid_command - + # Test missing required arguments (should fail gracefully) ! uv run python -m cognee.cli._cognee search - + # Test invalid search type (should fail gracefully) ! uv run python -m cognee.cli._cognee search "test" --query-type INVALID_TYPE - + # Test invalid chunker (should fail gracefully) ! uv run python -m cognee.cli._cognee cognify --chunker InvalidChunker diff --git a/.github/workflows/examples_tests.yml b/.github/workflows/examples_tests.yml index 1dc720f8e..f4167a57a 100644 --- a/.github/workflows/examples_tests.yml +++ b/.github/workflows/examples_tests.yml @@ -106,3 +106,28 @@ jobs: EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./examples/python/dynamic_steps_example.py + + test-memify: + name: Run Memify Example + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run Memify Tests + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./examples/python/memify_coding_agent_example.py diff --git a/.github/workflows/test_suites.yml b/.github/workflows/test_suites.yml index f051a31b3..ff18f2962 100644 --- a/.github/workflows/test_suites.yml +++ b/.github/workflows/test_suites.yml @@ -34,49 +34,49 @@ jobs: docker-compose-test: name: Docker Compose Test - needs: [basic-tests, e2e-tests, cli-tests] + needs: [basic-tests, e2e-tests] uses: ./.github/workflows/docker_compose.yml secrets: inherit docker-ci-test: name: Docker CI test - needs: [basic-tests, e2e-tests, cli-tests] + needs: [basic-tests, e2e-tests] uses: ./.github/workflows/backend_docker_build_test.yml secrets: inherit graph-db-tests: name: Graph Database Tests - needs: [basic-tests, e2e-tests, cli-tests] + needs: [basic-tests, e2e-tests] uses: ./.github/workflows/graph_db_tests.yml secrets: inherit temporal-graph-tests: name: Temporal Graph Test - needs: [ basic-tests, e2e-tests, cli-tests, graph-db-tests ] + needs: [ basic-tests, e2e-tests, graph-db-tests ] uses: ./.github/workflows/temporal_graph_tests.yml secrets: inherit search-db-tests: name: Search Test on Different DBs - needs: [basic-tests, e2e-tests, cli-tests, graph-db-tests] + needs: [basic-tests, e2e-tests, graph-db-tests] uses: ./.github/workflows/search_db_tests.yml secrets: inherit relational-db-migration-tests: name: Relational DB Migration Tests - needs: [basic-tests, e2e-tests, cli-tests, graph-db-tests] + needs: [basic-tests, e2e-tests, graph-db-tests] uses: ./.github/workflows/relational_db_migration_tests.yml secrets: inherit notebook-tests: name: Notebook Tests - needs: [basic-tests, e2e-tests, cli-tests] + needs: [basic-tests, e2e-tests] uses: ./.github/workflows/notebooks_tests.yml secrets: inherit different-operating-systems-tests: name: Operating System and Python Tests - needs: [basic-tests, e2e-tests, cli-tests] + needs: [basic-tests, e2e-tests] uses: ./.github/workflows/test_different_operating_systems.yml with: python-versions: '["3.10.x", "3.11.x", "3.12.x"]' @@ -85,20 +85,20 @@ jobs: # Matrix-based vector database tests vector-db-tests: name: Vector DB Tests - needs: [basic-tests, e2e-tests, cli-tests] + needs: [basic-tests, e2e-tests] uses: ./.github/workflows/vector_db_tests.yml secrets: inherit # Matrix-based example tests example-tests: name: Example Tests - needs: [basic-tests, e2e-tests, cli-tests] + needs: [basic-tests, e2e-tests] uses: ./.github/workflows/examples_tests.yml secrets: inherit mcp-test: name: MCP Tests - needs: [basic-tests, e2e-tests, cli-tests] + needs: [basic-tests, e2e-tests] uses: ./.github/workflows/test_mcp.yml secrets: inherit @@ -110,14 +110,14 @@ jobs: s3-file-storage-test: name: S3 File Storage Test - needs: [basic-tests, e2e-tests, cli-tests] + needs: [basic-tests, e2e-tests] uses: ./.github/workflows/test_s3_file_storage.yml secrets: inherit # Additional LLM tests llm-tests: name: LLM Test Suite - needs: [ basic-tests, e2e-tests, cli-tests ] + needs: [ basic-tests, e2e-tests ] uses: ./.github/workflows/test_llms.yml secrets: inherit @@ -127,7 +127,6 @@ jobs: needs: [ basic-tests, e2e-tests, - cli-tests, graph-db-tests, notebook-tests, different-operating-systems-tests, diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py index bd4841d3f..36d1c567e 100644 --- a/cognee/api/v1/search/routers/get_search_router.py +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -1,12 +1,12 @@ from uuid import UUID -from typing import Optional +from typing import Optional, Union, List, Any from datetime import datetime from pydantic import Field from fastapi import Depends, APIRouter from fastapi.responses import JSONResponse from fastapi.encoders import jsonable_encoder -from cognee.modules.search.types import SearchType +from cognee.modules.search.types import SearchType, SearchResult, CombinedSearchResult from cognee.api.DTO import InDTO, OutDTO from cognee.modules.users.exceptions.exceptions import PermissionDeniedError from cognee.modules.users.models import User @@ -73,7 +73,7 @@ def get_search_router() -> APIRouter: except Exception as error: return JSONResponse(status_code=500, content={"error": str(error)}) - @router.post("", response_model=list) + @router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List]) async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)): """ Search for nodes in the graph database. diff --git a/cognee/api/v1/ui/ui.py b/cognee/api/v1/ui/ui.py index aae0ad5a9..6faca19e8 100644 --- a/cognee/api/v1/ui/ui.py +++ b/cognee/api/v1/ui/ui.py @@ -7,7 +7,7 @@ import webbrowser import zipfile import requests from pathlib import Path -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple import tempfile import shutil @@ -326,38 +326,93 @@ def prompt_user_for_download() -> bool: def start_ui( + pid_callback: Callable[[int], None], host: str = "localhost", - port: int = 3001, + port: int = 3000, open_browser: bool = True, auto_download: bool = False, + start_backend: bool = False, + backend_host: str = "localhost", + backend_port: int = 8000, ) -> Optional[subprocess.Popen]: """ - Start the cognee frontend UI server. + Start the cognee frontend UI server, optionally with the backend API server. This function will: - 1. Find the cognee-frontend directory (development) or download it (pip install) - 2. Check if Node.js and npm are available (for development mode) - 3. Install dependencies if needed (development mode) - 4. Start the appropriate server - 5. Optionally open the browser + 1. Optionally start the cognee backend API server + 2. Find the cognee-frontend directory (development) or download it (pip install) + 3. Check if Node.js and npm are available (for development mode) + 4. Install dependencies if needed (development mode) + 5. Start the frontend server + 6. Optionally open the browser Args: - host: Host to bind the server to (default: localhost) - port: Port to run the server on (default: 3001) + pid_callback: Callback to notify with PID of each spawned process + host: Host to bind the frontend server to (default: localhost) + port: Port to run the frontend server on (default: 3000) open_browser: Whether to open the browser automatically (default: True) auto_download: If True, download frontend without prompting (default: False) + start_backend: If True, also start the cognee API backend server (default: False) + backend_host: Host to bind the backend server to (default: localhost) + backend_port: Port to run the backend server on (default: 8000) Returns: - subprocess.Popen object representing the running server, or None if failed + subprocess.Popen object representing the running frontend server, or None if failed + Note: If backend is started, it runs in a separate process that will be cleaned up + when the frontend process is terminated. Example: >>> import cognee + >>> # Start just the frontend >>> server = cognee.start_ui() - >>> # UI will be available at http://localhost:3001 - >>> # To stop the server later: + >>> + >>> # Start both frontend and backend + >>> server = cognee.start_ui(start_backend=True) + >>> # UI will be available at http://localhost:3000 + >>> # API will be available at http://localhost:8000 + >>> # To stop both servers later: >>> server.terminate() """ logger.info("Starting cognee UI...") + backend_process = None + + # Start backend server if requested + if start_backend: + logger.info("Starting cognee backend API server...") + try: + import sys + + backend_process = subprocess.Popen( + [ + sys.executable, + "-m", + "uvicorn", + "cognee.api.client:app", + "--host", + backend_host, + "--port", + str(backend_port), + ], + # Inherit stdout/stderr from parent process to show logs + stdout=None, + stderr=None, + preexec_fn=os.setsid if hasattr(os, "setsid") else None, + ) + + pid_callback(backend_process.pid) + + # Give the backend a moment to start + time.sleep(2) + + if backend_process.poll() is not None: + logger.error("Backend server failed to start - process exited early") + return None + + logger.info(f"✓ Backend API started at http://{backend_host}:{backend_port}") + + except Exception as e: + logger.error(f"Failed to start backend server: {str(e)}") + return None # Find frontend directory frontend_path = find_frontend_path() @@ -406,7 +461,7 @@ def start_ui( logger.info("This may take a moment to compile and start...") try: - # Use process group to ensure all child processes get terminated together + # Create frontend in its own process group for clean termination process = subprocess.Popen( ["npm", "run", "dev"], cwd=frontend_path, @@ -414,11 +469,11 @@ def start_ui( stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - preexec_fn=os.setsid - if hasattr(os, "setsid") - else None, # Create new process group on Unix + preexec_fn=os.setsid if hasattr(os, "setsid") else None, ) + pid_callback(process.pid) + # Give it a moment to start up time.sleep(3) @@ -447,16 +502,32 @@ def start_ui( logger.info(f"✓ Open your browser to: http://{host}:{port}") logger.info("✓ The UI will be available once Next.js finishes compiling") + # Store backend process reference in the frontend process for cleanup + if backend_process: + process._cognee_backend_process = backend_process + return process except Exception as e: logger.error(f"Failed to start frontend server: {str(e)}") + # Clean up backend process if it was started + if backend_process: + logger.info("Cleaning up backend process due to frontend failure...") + try: + backend_process.terminate() + backend_process.wait(timeout=5) + except (subprocess.TimeoutExpired, OSError, ProcessLookupError): + try: + backend_process.kill() + backend_process.wait() + except (OSError, ProcessLookupError): + pass return None def stop_ui(process: subprocess.Popen) -> bool: """ - Stop a running UI server process and all its children. + Stop a running UI server process and backend process (if started), along with all their children. Args: process: The subprocess.Popen object returned by start_ui() @@ -467,7 +538,29 @@ def stop_ui(process: subprocess.Popen) -> bool: if not process: return False + success = True + try: + # First, stop the backend process if it exists + backend_process = getattr(process, "_cognee_backend_process", None) + if backend_process: + logger.info("Stopping backend server...") + try: + backend_process.terminate() + try: + backend_process.wait(timeout=5) + logger.info("Backend server stopped gracefully") + except subprocess.TimeoutExpired: + logger.warning("Backend didn't terminate gracefully, forcing kill") + backend_process.kill() + backend_process.wait() + logger.info("Backend server stopped") + except Exception as e: + logger.error(f"Error stopping backend server: {str(e)}") + success = False + + # Now stop the frontend process + logger.info("Stopping frontend server...") # Try to terminate the process group (includes child processes like Next.js) if hasattr(os, "killpg"): try: @@ -484,9 +577,9 @@ def stop_ui(process: subprocess.Popen) -> bool: try: process.wait(timeout=10) - logger.info("UI server stopped gracefully") + logger.info("Frontend server stopped gracefully") except subprocess.TimeoutExpired: - logger.warning("Process didn't terminate gracefully, forcing kill") + logger.warning("Frontend didn't terminate gracefully, forcing kill") # Force kill the process group if hasattr(os, "killpg"): @@ -502,11 +595,13 @@ def stop_ui(process: subprocess.Popen) -> bool: process.wait() - logger.info("UI server stopped") - return True + if success: + logger.info("UI servers stopped successfully") + + return success except Exception as e: - logger.error(f"Error stopping UI server: {str(e)}") + logger.error(f"Error stopping UI servers: {str(e)}") return False diff --git a/cognee/cli/_cognee.py b/cognee/cli/_cognee.py index 1c9406143..52915594b 100644 --- a/cognee/cli/_cognee.py +++ b/cognee/cli/_cognee.py @@ -174,30 +174,23 @@ def main() -> int: # Handle UI flag if hasattr(args, "start_ui") and args.start_ui: - server_process = None + spawned_pids = [] def signal_handler(signum, frame): """Handle Ctrl+C and other termination signals""" - nonlocal server_process + nonlocal spawned_pids fmt.echo("\nShutting down UI server...") - if server_process: + + for pid in spawned_pids: try: - # Try graceful termination first - server_process.terminate() - try: - server_process.wait(timeout=5) - fmt.success("UI server stopped gracefully.") - except subprocess.TimeoutExpired: - # If graceful termination fails, force kill - fmt.echo("Force stopping UI server...") - server_process.kill() - server_process.wait() - fmt.success("UI server stopped.") - except Exception as e: - fmt.warning(f"Error stopping server: {e}") + pgid = os.getpgid(pid) + os.killpg(pgid, signal.SIGTERM) + fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.") + except (OSError, ProcessLookupError) as e: + fmt.warning(f"Could not terminate process {pid}: {e}") + sys.exit(0) - # Set up signal handlers signal.signal(signal.SIGINT, signal_handler) # Ctrl+C signal.signal(signal.SIGTERM, signal_handler) # Termination request @@ -205,11 +198,25 @@ def main() -> int: from cognee import start_ui fmt.echo("Starting cognee UI...") - server_process = start_ui(host="localhost", port=3001, open_browser=True) + + # Callback to capture PIDs of all spawned processes + def pid_callback(pid): + nonlocal spawned_pids + spawned_pids.append(pid) + + server_process = start_ui( + host="localhost", + port=3000, + open_browser=True, + start_backend=True, + auto_download=True, + pid_callback=pid_callback, + ) if server_process: fmt.success("UI server started successfully!") - fmt.echo("The interface is available at: http://localhost:3001") + fmt.echo("The interface is available at: http://localhost:3000") + fmt.echo("The API backend is available at: http://localhost:8000") fmt.note("Press Ctrl+C to stop the server...") try: @@ -225,10 +232,12 @@ def main() -> int: return 0 else: fmt.error("Failed to start UI server. Check the logs above for details.") + signal_handler(signal.SIGTERM, None) return 1 except Exception as ex: fmt.error(f"Error starting UI: {str(ex)}") + signal_handler(signal.SIGTERM, None) if debug.is_debug_enabled(): raise ex return 1 diff --git a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py index 4f4af1f06..f4c37fc93 100644 --- a/cognee/modules/retrieval/graph_completion_context_extension_retriever.py +++ b/cognee/modules/retrieval/graph_completion_context_extension_retriever.py @@ -128,4 +128,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever): question=query, answer=completion, context=context_text, triplets=triplets ) - return completion + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_cot_retriever.py b/cognee/modules/retrieval/graph_completion_cot_retriever.py index 282c6147e..f51433751 100644 --- a/cognee/modules/retrieval/graph_completion_cot_retriever.py +++ b/cognee/modules/retrieval/graph_completion_cot_retriever.py @@ -138,4 +138,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever): question=query, answer=completion, context=context_text, triplets=triplets ) - return completion + return [completion] diff --git a/cognee/modules/retrieval/graph_completion_retriever.py b/cognee/modules/retrieval/graph_completion_retriever.py index 45e7f85ff..29b1e9d19 100644 --- a/cognee/modules/retrieval/graph_completion_retriever.py +++ b/cognee/modules/retrieval/graph_completion_retriever.py @@ -171,7 +171,7 @@ class GraphCompletionRetriever(BaseGraphRetriever): question=query, answer=completion, context=context_text, triplets=triplets ) - return completion + return [completion] async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: """ diff --git a/cognee/modules/retrieval/insights_retriever.py b/cognee/modules/retrieval/insights_retriever.py index 43b77e951..0b1991e92 100644 --- a/cognee/modules/retrieval/insights_retriever.py +++ b/cognee/modules/retrieval/insights_retriever.py @@ -96,17 +96,18 @@ class InsightsRetriever(BaseGraphRetriever): unique_node_connections_map[unique_id] = True unique_node_connections.append(node_connection) - return [ - Edge( - node1=Node(node_id=connection[0]["id"], attributes=connection[0]), - node2=Node(node_id=connection[2]["id"], attributes=connection[2]), - attributes={ - **connection[1], - "relationship_type": connection[1]["relationship_name"], - }, - ) - for connection in unique_node_connections - ] + return unique_node_connections + # return [ + # Edge( + # node1=Node(node_id=connection[0]["id"], attributes=connection[0]), + # node2=Node(node_id=connection[2]["id"], attributes=connection[2]), + # attributes={ + # **connection[1], + # "relationship_type": connection[1]["relationship_name"], + # }, + # ) + # for connection in unique_node_connections + # ] async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: """ diff --git a/cognee/modules/retrieval/temporal_retriever.py b/cognee/modules/retrieval/temporal_retriever.py index 09f2980dd..36cdbd33f 100644 --- a/cognee/modules/retrieval/temporal_retriever.py +++ b/cognee/modules/retrieval/temporal_retriever.py @@ -149,4 +149,4 @@ class TemporalRetriever(GraphCompletionRetriever): system_prompt_path=self.system_prompt_path, ) - return completion + return [completion] diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 405207114..0c236d896 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -132,14 +132,37 @@ async def search( ], ) else: - return [ - SearchResult( - search_result=result, - dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None, - dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None, - ) - for index, (result, _, datasets) in enumerate(search_results) - ] + # This is for maintaining backwards compatibility + if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": + return_value = [] + for search_result in search_results: + result, context, datasets = search_result + return_value.append( + { + "search_result": result, + "dataset_id": datasets[0].id, + "dataset_name": datasets[0].name, + } + ) + return return_value + else: + return_value = [] + for search_result in search_results: + result, context, datasets = search_result + return_value.append(result) + # For maintaining backwards compatibility + if len(return_value) == 1 and isinstance(return_value[0], list): + return return_value[0] + else: + return return_value + # return [ + # SearchResult( + # search_result=result, + # dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None, + # dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None, + # ) + # for index, (result, _, datasets) in enumerate(search_results) + # ] async def authorized_search( diff --git a/cognee/tests/integration/cli/__init__.py b/cognee/tests/cli_tests/cli_integration_tests/__init__.py similarity index 100% rename from cognee/tests/integration/cli/__init__.py rename to cognee/tests/cli_tests/cli_integration_tests/__init__.py diff --git a/cognee/tests/integration/cli/test_cli_integration.py b/cognee/tests/cli_tests/cli_integration_tests/test_cli_integration.py similarity index 100% rename from cognee/tests/integration/cli/test_cli_integration.py rename to cognee/tests/cli_tests/cli_integration_tests/test_cli_integration.py diff --git a/cognee/tests/unit/cli/__init__.py b/cognee/tests/cli_tests/cli_unit_tests/__init__.py similarity index 100% rename from cognee/tests/unit/cli/__init__.py rename to cognee/tests/cli_tests/cli_unit_tests/__init__.py diff --git a/cognee/tests/unit/cli/test_cli_commands.py b/cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py similarity index 100% rename from cognee/tests/unit/cli/test_cli_commands.py rename to cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py diff --git a/cognee/tests/unit/cli/test_cli_edge_cases.py b/cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py similarity index 100% rename from cognee/tests/unit/cli/test_cli_edge_cases.py rename to cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py diff --git a/cognee/tests/unit/cli/test_cli_main.py b/cognee/tests/cli_tests/cli_unit_tests/test_cli_main.py similarity index 100% rename from cognee/tests/unit/cli/test_cli_main.py rename to cognee/tests/cli_tests/cli_unit_tests/test_cli_main.py diff --git a/cognee/tests/unit/cli/test_cli_runner.py b/cognee/tests/cli_tests/cli_unit_tests/test_cli_runner.py similarity index 100% rename from cognee/tests/unit/cli/test_cli_runner.py rename to cognee/tests/cli_tests/cli_unit_tests/test_cli_runner.py diff --git a/cognee/tests/unit/cli/test_cli_utils.py b/cognee/tests/cli_tests/cli_unit_tests/test_cli_utils.py similarity index 100% rename from cognee/tests/unit/cli/test_cli_utils.py rename to cognee/tests/cli_tests/cli_unit_tests/test_cli_utils.py diff --git a/cognee/tests/test_permissions.py b/cognee/tests/test_permissions.py index cfa3aade2..95f769263 100644 --- a/cognee/tests/test_permissions.py +++ b/cognee/tests/test_permissions.py @@ -79,7 +79,7 @@ async def main(): print("\n\nExtracted sentences are:\n") for result in search_results: print(f"{result}\n") - assert search_results[0].dataset_name == "NLP", ( + assert search_results[0]["dataset_name"] == "NLP", ( f"Dict must contain dataset name 'NLP': {search_results[0]}" ) @@ -93,7 +93,7 @@ async def main(): print("\n\nExtracted sentences are:\n") for result in search_results: print(f"{result}\n") - assert search_results[0].dataset_name == "QUANTUM", ( + assert search_results[0]["dataset_name"] == "QUANTUM", ( f"Dict must contain dataset name 'QUANTUM': {search_results[0]}" ) @@ -170,7 +170,7 @@ async def main(): for result in search_results: print(f"{result}\n") - assert search_results[0].dataset_name == "QUANTUM", ( + assert search_results[0]["dataset_name"] == "QUANTUM", ( f"Dict must contain dataset name 'QUANTUM': {search_results[0]}" ) diff --git a/cognee/tests/test_relational_db_migration.py b/cognee/tests/test_relational_db_migration.py index 49508144f..68b46dbf5 100644 --- a/cognee/tests/test_relational_db_migration.py +++ b/cognee/tests/test_relational_db_migration.py @@ -45,15 +45,13 @@ async def relational_db_migration(): await migrate_relational_database(graph_engine, schema=schema) # 1. Search the graph - search_results: List[SearchResult] = await cognee.search( + search_results = await cognee.search( query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC" - ) # type: ignore + ) print("Search results:", search_results) # 2. Assert that the search results contain "AC/DC" - assert any("AC/DC" in r.search_result for r in search_results), ( - "AC/DC not found in search results!" - ) + assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!" migration_db_provider = migration_engine.engine.dialect.name if migration_db_provider == "postgresql": diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index 62b07f31a..cb4636470 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -144,13 +144,16 @@ async def main(): ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext), ("GRAPH_SUMMARY_COMPLETION", completion_sum), ]: - for search_result in search_results: - completion = search_result.search_result - assert isinstance(completion, str), f"{name}: should return a string" - assert completion.strip(), f"{name}: string should not be empty" - assert "netherlands" in completion.lower(), ( - f"{name}: expected 'netherlands' in result, got: {completion!r}" - ) + assert isinstance(search_results, list), f"{name}: should return a list" + assert len(search_results) == 1, ( + f"{name}: expected single-element list, got {len(search_results)}" + ) + text = search_results[0] + assert isinstance(text, str), f"{name}: element should be a string" + assert text.strip(), f"{name}: string should not be empty" + assert "netherlands" in text.lower(), ( + f"{name}: expected 'netherlands' in result, got: {text!r}" + ) graph_engine = await get_graph_engine() graph = await graph_engine.get_graph_data() diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py index 02e3f73e2..74def2ae7 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py @@ -59,8 +59,10 @@ class TestGraphCompletionWithContextExtensionRetriever: answer = await retriever.get_completion("Who works at Canva?") - assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" - assert answer.strip(), "Answer must contain only non-empty strings" + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) @pytest.mark.asyncio async def test_graph_completion_extension_context_complex(self): @@ -140,8 +142,10 @@ class TestGraphCompletionWithContextExtensionRetriever: answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" - assert answer.strip(), "Answer must contain only non-empty strings" + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) @pytest.mark.asyncio async def test_get_graph_completion_extension_context_on_empty_graph(self): @@ -171,5 +175,7 @@ class TestGraphCompletionWithContextExtensionRetriever: answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" - assert answer.strip(), "Answer must contain only non-empty strings" + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) diff --git a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py index 54fa12f41..9a789a1bd 100644 --- a/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +++ b/cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py @@ -55,8 +55,10 @@ class TestGraphCompletionCoTRetriever: answer = await retriever.get_completion("Who works at Canva?") - assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" - assert answer.strip(), "Answer must contain only non-empty strings" + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) @pytest.mark.asyncio async def test_graph_completion_cot_context_complex(self): @@ -133,8 +135,10 @@ class TestGraphCompletionCoTRetriever: answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" - assert answer.strip(), "Answer must contain only non-empty strings" + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) @pytest.mark.asyncio async def test_get_graph_completion_cot_context_on_empty_graph(self): @@ -164,5 +168,7 @@ class TestGraphCompletionCoTRetriever: answer = await retriever.get_completion("Who works at Figma?") - assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" - assert answer.strip(), "Answer must contain only non-empty strings" + assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}" + assert all(isinstance(item, str) and item.strip() for item in answer), ( + "Answer must contain only non-empty strings" + ) diff --git a/cognee/tests/unit/modules/retrieval/insights_retriever_test.py b/cognee/tests/unit/modules/retrieval/insights_retriever_test.py index a3d9da63a..21dbc98dd 100644 --- a/cognee/tests/unit/modules/retrieval/insights_retriever_test.py +++ b/cognee/tests/unit/modules/retrieval/insights_retriever_test.py @@ -82,7 +82,7 @@ class TestInsightsRetriever: context = await retriever.get_context("Mike") - assert context[0].node1.attributes["name"] == "Mike Broski", "Failed to get Mike Broski" + assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski" @pytest.mark.asyncio async def test_insights_context_complex(self): @@ -222,9 +222,7 @@ class TestInsightsRetriever: context = await retriever.get_context("Christina") - assert context[0].node1.attributes["name"] == "Christina Mayer", ( - "Failed to get Christina Mayer" - ) + assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer" @pytest.mark.asyncio async def test_insights_context_on_empty_graph(self): diff --git a/poetry.lock b/poetry.lock index 64c1bb050..de2be7768 100644 --- a/poetry.lock +++ b/poetry.lock @@ -11728,4 +11728,4 @@ posthog = ["posthog"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<=3.13" -content-hash = "576318d370b89d128a7c3e755fe3c898fef4e359acdd3f05f952ae497751fb04" +content-hash = "1e8cdbf6919cea9657d51b7839630dac7a0d8a2815eca0bd811838a282051625" diff --git a/pyproject.toml b/pyproject.toml index 87b25dc71..6e6d9e1a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "cognee" -version = "0.3.0" +version = "0.3.2" description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning." authors = [ { name = "Vasilije Markovic" }, diff --git a/uv.lock b/uv.lock index fb8ecd9bd..4c9a6da17 100644 --- a/uv.lock +++ b/uv.lock @@ -807,7 +807,7 @@ wheels = [ [[package]] name = "cognee" -version = "0.2.4" +version = "0.3.2" source = { editable = "." } dependencies = [ { name = "aiofiles" }, @@ -1029,7 +1029,7 @@ requires-dist = [ { name = "pylance", specifier = ">=0.22.0,<1.0.0" }, { name = "pylint", marker = "extra == 'dev'", specifier = ">=3.0.3,<4" }, { name = "pympler", specifier = ">=1.1,<2.0.0" }, - { name = "pypdf", specifier = ">=4.1.0,<6.0.0" }, + { name = "pypdf", specifier = ">=4.1.0,<7.0.0" }, { name = "pypika", marker = "extra == 'chromadb'", specifier = "==0.48.8" }, { name = "pyside6", marker = "extra == 'gui'", specifier = ">=6.8.3,<7" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0,<8" },