Merge branch 'dev' into feature/cog-2923-create-ci-test-for-fastembed

This commit is contained in:
Igor Ilic 2025-09-12 13:33:15 +02:00
commit 8d7738d713
29 changed files with 285 additions and 122 deletions

View file

@ -60,7 +60,7 @@ jobs:
python-version: ${{ inputs.python-version }} python-version: ${{ inputs.python-version }}
- name: Run CLI Unit Tests - name: Run CLI Unit Tests
run: uv run pytest cognee/tests/unit/cli/ -v run: uv run pytest cognee/tests/cli_tests/cli_unit_tests/ -v
cli-integration-tests: cli-integration-tests:
name: CLI Integration Tests name: CLI Integration Tests
@ -87,7 +87,7 @@ jobs:
python-version: ${{ inputs.python-version }} python-version: ${{ inputs.python-version }}
- name: Run CLI Integration Tests - name: Run CLI Integration Tests
run: uv run pytest cognee/tests/integration/cli/ -v run: uv run pytest cognee/tests/cli_tests/cli_integration_tests/ -v
cli-functionality-tests: cli-functionality-tests:
name: CLI Functionality Tests name: CLI Functionality Tests
@ -135,12 +135,12 @@ jobs:
run: | run: |
# Test invalid command (should fail gracefully) # Test invalid command (should fail gracefully)
! uv run python -m cognee.cli._cognee invalid_command ! uv run python -m cognee.cli._cognee invalid_command
# Test missing required arguments (should fail gracefully) # Test missing required arguments (should fail gracefully)
! uv run python -m cognee.cli._cognee search ! uv run python -m cognee.cli._cognee search
# Test invalid search type (should fail gracefully) # Test invalid search type (should fail gracefully)
! uv run python -m cognee.cli._cognee search "test" --query-type INVALID_TYPE ! uv run python -m cognee.cli._cognee search "test" --query-type INVALID_TYPE
# Test invalid chunker (should fail gracefully) # Test invalid chunker (should fail gracefully)
! uv run python -m cognee.cli._cognee cognify --chunker InvalidChunker ! uv run python -m cognee.cli._cognee cognify --chunker InvalidChunker

View file

@ -106,3 +106,28 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./examples/python/dynamic_steps_example.py run: uv run python ./examples/python/dynamic_steps_example.py
test-memify:
name: Run Memify Example
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Run Memify Tests
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./examples/python/memify_coding_agent_example.py

View file

@ -34,49 +34,49 @@ jobs:
docker-compose-test: docker-compose-test:
name: Docker Compose Test name: Docker Compose Test
needs: [basic-tests, e2e-tests, cli-tests] needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/docker_compose.yml uses: ./.github/workflows/docker_compose.yml
secrets: inherit secrets: inherit
docker-ci-test: docker-ci-test:
name: Docker CI test name: Docker CI test
needs: [basic-tests, e2e-tests, cli-tests] needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/backend_docker_build_test.yml uses: ./.github/workflows/backend_docker_build_test.yml
secrets: inherit secrets: inherit
graph-db-tests: graph-db-tests:
name: Graph Database Tests name: Graph Database Tests
needs: [basic-tests, e2e-tests, cli-tests] needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/graph_db_tests.yml uses: ./.github/workflows/graph_db_tests.yml
secrets: inherit secrets: inherit
temporal-graph-tests: temporal-graph-tests:
name: Temporal Graph Test name: Temporal Graph Test
needs: [ basic-tests, e2e-tests, cli-tests, graph-db-tests ] needs: [ basic-tests, e2e-tests, graph-db-tests ]
uses: ./.github/workflows/temporal_graph_tests.yml uses: ./.github/workflows/temporal_graph_tests.yml
secrets: inherit secrets: inherit
search-db-tests: search-db-tests:
name: Search Test on Different DBs name: Search Test on Different DBs
needs: [basic-tests, e2e-tests, cli-tests, graph-db-tests] needs: [basic-tests, e2e-tests, graph-db-tests]
uses: ./.github/workflows/search_db_tests.yml uses: ./.github/workflows/search_db_tests.yml
secrets: inherit secrets: inherit
relational-db-migration-tests: relational-db-migration-tests:
name: Relational DB Migration Tests name: Relational DB Migration Tests
needs: [basic-tests, e2e-tests, cli-tests, graph-db-tests] needs: [basic-tests, e2e-tests, graph-db-tests]
uses: ./.github/workflows/relational_db_migration_tests.yml uses: ./.github/workflows/relational_db_migration_tests.yml
secrets: inherit secrets: inherit
notebook-tests: notebook-tests:
name: Notebook Tests name: Notebook Tests
needs: [basic-tests, e2e-tests, cli-tests] needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/notebooks_tests.yml uses: ./.github/workflows/notebooks_tests.yml
secrets: inherit secrets: inherit
different-operating-systems-tests: different-operating-systems-tests:
name: Operating System and Python Tests name: Operating System and Python Tests
needs: [basic-tests, e2e-tests, cli-tests] needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/test_different_operating_systems.yml uses: ./.github/workflows/test_different_operating_systems.yml
with: with:
python-versions: '["3.10.x", "3.11.x", "3.12.x"]' python-versions: '["3.10.x", "3.11.x", "3.12.x"]'
@ -85,20 +85,20 @@ jobs:
# Matrix-based vector database tests # Matrix-based vector database tests
vector-db-tests: vector-db-tests:
name: Vector DB Tests name: Vector DB Tests
needs: [basic-tests, e2e-tests, cli-tests] needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/vector_db_tests.yml uses: ./.github/workflows/vector_db_tests.yml
secrets: inherit secrets: inherit
# Matrix-based example tests # Matrix-based example tests
example-tests: example-tests:
name: Example Tests name: Example Tests
needs: [basic-tests, e2e-tests, cli-tests] needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/examples_tests.yml uses: ./.github/workflows/examples_tests.yml
secrets: inherit secrets: inherit
mcp-test: mcp-test:
name: MCP Tests name: MCP Tests
needs: [basic-tests, e2e-tests, cli-tests] needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/test_mcp.yml uses: ./.github/workflows/test_mcp.yml
secrets: inherit secrets: inherit
@ -110,14 +110,14 @@ jobs:
s3-file-storage-test: s3-file-storage-test:
name: S3 File Storage Test name: S3 File Storage Test
needs: [basic-tests, e2e-tests, cli-tests] needs: [basic-tests, e2e-tests]
uses: ./.github/workflows/test_s3_file_storage.yml uses: ./.github/workflows/test_s3_file_storage.yml
secrets: inherit secrets: inherit
# Additional LLM tests # Additional LLM tests
llm-tests: llm-tests:
name: LLM Test Suite name: LLM Test Suite
needs: [ basic-tests, e2e-tests, cli-tests ] needs: [ basic-tests, e2e-tests ]
uses: ./.github/workflows/test_llms.yml uses: ./.github/workflows/test_llms.yml
secrets: inherit secrets: inherit
@ -127,7 +127,6 @@ jobs:
needs: [ needs: [
basic-tests, basic-tests,
e2e-tests, e2e-tests,
cli-tests,
graph-db-tests, graph-db-tests,
notebook-tests, notebook-tests,
different-operating-systems-tests, different-operating-systems-tests,

View file

@ -1,12 +1,12 @@
from uuid import UUID from uuid import UUID
from typing import Optional from typing import Optional, Union, List, Any
from datetime import datetime from datetime import datetime
from pydantic import Field from pydantic import Field
from fastapi import Depends, APIRouter from fastapi import Depends, APIRouter
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from cognee.modules.search.types import SearchType from cognee.modules.search.types import SearchType, SearchResult, CombinedSearchResult
from cognee.api.DTO import InDTO, OutDTO from cognee.api.DTO import InDTO, OutDTO
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError from cognee.modules.users.exceptions.exceptions import PermissionDeniedError
from cognee.modules.users.models import User from cognee.modules.users.models import User
@ -73,7 +73,7 @@ def get_search_router() -> APIRouter:
except Exception as error: except Exception as error:
return JSONResponse(status_code=500, content={"error": str(error)}) return JSONResponse(status_code=500, content={"error": str(error)})
@router.post("", response_model=list) @router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)): async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
""" """
Search for nodes in the graph database. Search for nodes in the graph database.

View file

@ -7,7 +7,7 @@ import webbrowser
import zipfile import zipfile
import requests import requests
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple from typing import Callable, Optional, Tuple
import tempfile import tempfile
import shutil import shutil
@ -326,38 +326,93 @@ def prompt_user_for_download() -> bool:
def start_ui( def start_ui(
pid_callback: Callable[[int], None],
host: str = "localhost", host: str = "localhost",
port: int = 3001, port: int = 3000,
open_browser: bool = True, open_browser: bool = True,
auto_download: bool = False, auto_download: bool = False,
start_backend: bool = False,
backend_host: str = "localhost",
backend_port: int = 8000,
) -> Optional[subprocess.Popen]: ) -> Optional[subprocess.Popen]:
""" """
Start the cognee frontend UI server. Start the cognee frontend UI server, optionally with the backend API server.
This function will: This function will:
1. Find the cognee-frontend directory (development) or download it (pip install) 1. Optionally start the cognee backend API server
2. Check if Node.js and npm are available (for development mode) 2. Find the cognee-frontend directory (development) or download it (pip install)
3. Install dependencies if needed (development mode) 3. Check if Node.js and npm are available (for development mode)
4. Start the appropriate server 4. Install dependencies if needed (development mode)
5. Optionally open the browser 5. Start the frontend server
6. Optionally open the browser
Args: Args:
host: Host to bind the server to (default: localhost) pid_callback: Callback to notify with PID of each spawned process
port: Port to run the server on (default: 3001) host: Host to bind the frontend server to (default: localhost)
port: Port to run the frontend server on (default: 3000)
open_browser: Whether to open the browser automatically (default: True) open_browser: Whether to open the browser automatically (default: True)
auto_download: If True, download frontend without prompting (default: False) auto_download: If True, download frontend without prompting (default: False)
start_backend: If True, also start the cognee API backend server (default: False)
backend_host: Host to bind the backend server to (default: localhost)
backend_port: Port to run the backend server on (default: 8000)
Returns: Returns:
subprocess.Popen object representing the running server, or None if failed subprocess.Popen object representing the running frontend server, or None if failed
Note: If backend is started, it runs in a separate process that will be cleaned up
when the frontend process is terminated.
Example: Example:
>>> import cognee >>> import cognee
>>> # Start just the frontend
>>> server = cognee.start_ui() >>> server = cognee.start_ui()
>>> # UI will be available at http://localhost:3001 >>>
>>> # To stop the server later: >>> # Start both frontend and backend
>>> server = cognee.start_ui(start_backend=True)
>>> # UI will be available at http://localhost:3000
>>> # API will be available at http://localhost:8000
>>> # To stop both servers later:
>>> server.terminate() >>> server.terminate()
""" """
logger.info("Starting cognee UI...") logger.info("Starting cognee UI...")
backend_process = None
# Start backend server if requested
if start_backend:
logger.info("Starting cognee backend API server...")
try:
import sys
backend_process = subprocess.Popen(
[
sys.executable,
"-m",
"uvicorn",
"cognee.api.client:app",
"--host",
backend_host,
"--port",
str(backend_port),
],
# Inherit stdout/stderr from parent process to show logs
stdout=None,
stderr=None,
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
)
pid_callback(backend_process.pid)
# Give the backend a moment to start
time.sleep(2)
if backend_process.poll() is not None:
logger.error("Backend server failed to start - process exited early")
return None
logger.info(f"✓ Backend API started at http://{backend_host}:{backend_port}")
except Exception as e:
logger.error(f"Failed to start backend server: {str(e)}")
return None
# Find frontend directory # Find frontend directory
frontend_path = find_frontend_path() frontend_path = find_frontend_path()
@ -406,7 +461,7 @@ def start_ui(
logger.info("This may take a moment to compile and start...") logger.info("This may take a moment to compile and start...")
try: try:
# Use process group to ensure all child processes get terminated together # Create frontend in its own process group for clean termination
process = subprocess.Popen( process = subprocess.Popen(
["npm", "run", "dev"], ["npm", "run", "dev"],
cwd=frontend_path, cwd=frontend_path,
@ -414,11 +469,11 @@ def start_ui(
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
text=True, text=True,
preexec_fn=os.setsid preexec_fn=os.setsid if hasattr(os, "setsid") else None,
if hasattr(os, "setsid")
else None, # Create new process group on Unix
) )
pid_callback(process.pid)
# Give it a moment to start up # Give it a moment to start up
time.sleep(3) time.sleep(3)
@ -447,16 +502,32 @@ def start_ui(
logger.info(f"✓ Open your browser to: http://{host}:{port}") logger.info(f"✓ Open your browser to: http://{host}:{port}")
logger.info("✓ The UI will be available once Next.js finishes compiling") logger.info("✓ The UI will be available once Next.js finishes compiling")
# Store backend process reference in the frontend process for cleanup
if backend_process:
process._cognee_backend_process = backend_process
return process return process
except Exception as e: except Exception as e:
logger.error(f"Failed to start frontend server: {str(e)}") logger.error(f"Failed to start frontend server: {str(e)}")
# Clean up backend process if it was started
if backend_process:
logger.info("Cleaning up backend process due to frontend failure...")
try:
backend_process.terminate()
backend_process.wait(timeout=5)
except (subprocess.TimeoutExpired, OSError, ProcessLookupError):
try:
backend_process.kill()
backend_process.wait()
except (OSError, ProcessLookupError):
pass
return None return None
def stop_ui(process: subprocess.Popen) -> bool: def stop_ui(process: subprocess.Popen) -> bool:
""" """
Stop a running UI server process and all its children. Stop a running UI server process and backend process (if started), along with all their children.
Args: Args:
process: The subprocess.Popen object returned by start_ui() process: The subprocess.Popen object returned by start_ui()
@ -467,7 +538,29 @@ def stop_ui(process: subprocess.Popen) -> bool:
if not process: if not process:
return False return False
success = True
try: try:
# First, stop the backend process if it exists
backend_process = getattr(process, "_cognee_backend_process", None)
if backend_process:
logger.info("Stopping backend server...")
try:
backend_process.terminate()
try:
backend_process.wait(timeout=5)
logger.info("Backend server stopped gracefully")
except subprocess.TimeoutExpired:
logger.warning("Backend didn't terminate gracefully, forcing kill")
backend_process.kill()
backend_process.wait()
logger.info("Backend server stopped")
except Exception as e:
logger.error(f"Error stopping backend server: {str(e)}")
success = False
# Now stop the frontend process
logger.info("Stopping frontend server...")
# Try to terminate the process group (includes child processes like Next.js) # Try to terminate the process group (includes child processes like Next.js)
if hasattr(os, "killpg"): if hasattr(os, "killpg"):
try: try:
@ -484,9 +577,9 @@ def stop_ui(process: subprocess.Popen) -> bool:
try: try:
process.wait(timeout=10) process.wait(timeout=10)
logger.info("UI server stopped gracefully") logger.info("Frontend server stopped gracefully")
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
logger.warning("Process didn't terminate gracefully, forcing kill") logger.warning("Frontend didn't terminate gracefully, forcing kill")
# Force kill the process group # Force kill the process group
if hasattr(os, "killpg"): if hasattr(os, "killpg"):
@ -502,11 +595,13 @@ def stop_ui(process: subprocess.Popen) -> bool:
process.wait() process.wait()
logger.info("UI server stopped") if success:
return True logger.info("UI servers stopped successfully")
return success
except Exception as e: except Exception as e:
logger.error(f"Error stopping UI server: {str(e)}") logger.error(f"Error stopping UI servers: {str(e)}")
return False return False

View file

@ -174,30 +174,23 @@ def main() -> int:
# Handle UI flag # Handle UI flag
if hasattr(args, "start_ui") and args.start_ui: if hasattr(args, "start_ui") and args.start_ui:
server_process = None spawned_pids = []
def signal_handler(signum, frame): def signal_handler(signum, frame):
"""Handle Ctrl+C and other termination signals""" """Handle Ctrl+C and other termination signals"""
nonlocal server_process nonlocal spawned_pids
fmt.echo("\nShutting down UI server...") fmt.echo("\nShutting down UI server...")
if server_process:
for pid in spawned_pids:
try: try:
# Try graceful termination first pgid = os.getpgid(pid)
server_process.terminate() os.killpg(pgid, signal.SIGTERM)
try: fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
server_process.wait(timeout=5) except (OSError, ProcessLookupError) as e:
fmt.success("UI server stopped gracefully.") fmt.warning(f"Could not terminate process {pid}: {e}")
except subprocess.TimeoutExpired:
# If graceful termination fails, force kill
fmt.echo("Force stopping UI server...")
server_process.kill()
server_process.wait()
fmt.success("UI server stopped.")
except Exception as e:
fmt.warning(f"Error stopping server: {e}")
sys.exit(0) sys.exit(0)
# Set up signal handlers
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Termination request signal.signal(signal.SIGTERM, signal_handler) # Termination request
@ -205,11 +198,25 @@ def main() -> int:
from cognee import start_ui from cognee import start_ui
fmt.echo("Starting cognee UI...") fmt.echo("Starting cognee UI...")
server_process = start_ui(host="localhost", port=3001, open_browser=True)
# Callback to capture PIDs of all spawned processes
def pid_callback(pid):
nonlocal spawned_pids
spawned_pids.append(pid)
server_process = start_ui(
host="localhost",
port=3000,
open_browser=True,
start_backend=True,
auto_download=True,
pid_callback=pid_callback,
)
if server_process: if server_process:
fmt.success("UI server started successfully!") fmt.success("UI server started successfully!")
fmt.echo("The interface is available at: http://localhost:3001") fmt.echo("The interface is available at: http://localhost:3000")
fmt.echo("The API backend is available at: http://localhost:8000")
fmt.note("Press Ctrl+C to stop the server...") fmt.note("Press Ctrl+C to stop the server...")
try: try:
@ -225,10 +232,12 @@ def main() -> int:
return 0 return 0
else: else:
fmt.error("Failed to start UI server. Check the logs above for details.") fmt.error("Failed to start UI server. Check the logs above for details.")
signal_handler(signal.SIGTERM, None)
return 1 return 1
except Exception as ex: except Exception as ex:
fmt.error(f"Error starting UI: {str(ex)}") fmt.error(f"Error starting UI: {str(ex)}")
signal_handler(signal.SIGTERM, None)
if debug.is_debug_enabled(): if debug.is_debug_enabled():
raise ex raise ex
return 1 return 1

View file

@ -128,4 +128,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context_text, triplets=triplets question=query, answer=completion, context=context_text, triplets=triplets
) )
return completion return [completion]

View file

@ -138,4 +138,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
question=query, answer=completion, context=context_text, triplets=triplets question=query, answer=completion, context=context_text, triplets=triplets
) )
return completion return [completion]

View file

@ -171,7 +171,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
question=query, answer=completion, context=context_text, triplets=triplets question=query, answer=completion, context=context_text, triplets=triplets
) )
return completion return [completion]
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None: async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
""" """

View file

@ -96,17 +96,18 @@ class InsightsRetriever(BaseGraphRetriever):
unique_node_connections_map[unique_id] = True unique_node_connections_map[unique_id] = True
unique_node_connections.append(node_connection) unique_node_connections.append(node_connection)
return [ return unique_node_connections
Edge( # return [
node1=Node(node_id=connection[0]["id"], attributes=connection[0]), # Edge(
node2=Node(node_id=connection[2]["id"], attributes=connection[2]), # node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
attributes={ # node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
**connection[1], # attributes={
"relationship_type": connection[1]["relationship_name"], # **connection[1],
}, # "relationship_type": connection[1]["relationship_name"],
) # },
for connection in unique_node_connections # )
] # for connection in unique_node_connections
# ]
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any: async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
""" """

View file

@ -149,4 +149,4 @@ class TemporalRetriever(GraphCompletionRetriever):
system_prompt_path=self.system_prompt_path, system_prompt_path=self.system_prompt_path,
) )
return completion return [completion]

View file

@ -132,14 +132,37 @@ async def search(
], ],
) )
else: else:
return [ # This is for maintaining backwards compatibility
SearchResult( if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
search_result=result, return_value = []
dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None, for search_result in search_results:
dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None, result, context, datasets = search_result
) return_value.append(
for index, (result, _, datasets) in enumerate(search_results) {
] "search_result": result,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
}
)
return return_value
else:
return_value = []
for search_result in search_results:
result, context, datasets = search_result
return_value.append(result)
# For maintaining backwards compatibility
if len(return_value) == 1 and isinstance(return_value[0], list):
return return_value[0]
else:
return return_value
# return [
# SearchResult(
# search_result=result,
# dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
# dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
# )
# for index, (result, _, datasets) in enumerate(search_results)
# ]
async def authorized_search( async def authorized_search(

View file

@ -79,7 +79,7 @@ async def main():
print("\n\nExtracted sentences are:\n") print("\n\nExtracted sentences are:\n")
for result in search_results: for result in search_results:
print(f"{result}\n") print(f"{result}\n")
assert search_results[0].dataset_name == "NLP", ( assert search_results[0]["dataset_name"] == "NLP", (
f"Dict must contain dataset name 'NLP': {search_results[0]}" f"Dict must contain dataset name 'NLP': {search_results[0]}"
) )
@ -93,7 +93,7 @@ async def main():
print("\n\nExtracted sentences are:\n") print("\n\nExtracted sentences are:\n")
for result in search_results: for result in search_results:
print(f"{result}\n") print(f"{result}\n")
assert search_results[0].dataset_name == "QUANTUM", ( assert search_results[0]["dataset_name"] == "QUANTUM", (
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}" f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
) )
@ -170,7 +170,7 @@ async def main():
for result in search_results: for result in search_results:
print(f"{result}\n") print(f"{result}\n")
assert search_results[0].dataset_name == "QUANTUM", ( assert search_results[0]["dataset_name"] == "QUANTUM", (
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}" f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
) )

View file

@ -45,15 +45,13 @@ async def relational_db_migration():
await migrate_relational_database(graph_engine, schema=schema) await migrate_relational_database(graph_engine, schema=schema)
# 1. Search the graph # 1. Search the graph
search_results: List[SearchResult] = await cognee.search( search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC" query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC"
) # type: ignore )
print("Search results:", search_results) print("Search results:", search_results)
# 2. Assert that the search results contain "AC/DC" # 2. Assert that the search results contain "AC/DC"
assert any("AC/DC" in r.search_result for r in search_results), ( assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!"
"AC/DC not found in search results!"
)
migration_db_provider = migration_engine.engine.dialect.name migration_db_provider = migration_engine.engine.dialect.name
if migration_db_provider == "postgresql": if migration_db_provider == "postgresql":

View file

@ -144,13 +144,16 @@ async def main():
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext), ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
("GRAPH_SUMMARY_COMPLETION", completion_sum), ("GRAPH_SUMMARY_COMPLETION", completion_sum),
]: ]:
for search_result in search_results: assert isinstance(search_results, list), f"{name}: should return a list"
completion = search_result.search_result assert len(search_results) == 1, (
assert isinstance(completion, str), f"{name}: should return a string" f"{name}: expected single-element list, got {len(search_results)}"
assert completion.strip(), f"{name}: string should not be empty" )
assert "netherlands" in completion.lower(), ( text = search_results[0]
f"{name}: expected 'netherlands' in result, got: {completion!r}" assert isinstance(text, str), f"{name}: element should be a string"
) assert text.strip(), f"{name}: string should not be empty"
assert "netherlands" in text.lower(), (
f"{name}: expected 'netherlands' in result, got: {text!r}"
)
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
graph = await graph_engine.get_graph_data() graph = await graph_engine.get_graph_data()

View file

@ -59,8 +59,10 @@ class TestGraphCompletionWithContextExtensionRetriever:
answer = await retriever.get_completion("Who works at Canva?") answer = await retriever.get_completion("Who works at Canva?")
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert answer.strip(), "Answer must contain only non-empty strings" assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graph_completion_extension_context_complex(self): async def test_graph_completion_extension_context_complex(self):
@ -140,8 +142,10 @@ class TestGraphCompletionWithContextExtensionRetriever:
answer = await retriever.get_completion("Who works at Figma?") answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert answer.strip(), "Answer must contain only non-empty strings" assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_graph_completion_extension_context_on_empty_graph(self): async def test_get_graph_completion_extension_context_on_empty_graph(self):
@ -171,5 +175,7 @@ class TestGraphCompletionWithContextExtensionRetriever:
answer = await retriever.get_completion("Who works at Figma?") answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert answer.strip(), "Answer must contain only non-empty strings" assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)

View file

@ -55,8 +55,10 @@ class TestGraphCompletionCoTRetriever:
answer = await retriever.get_completion("Who works at Canva?") answer = await retriever.get_completion("Who works at Canva?")
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert answer.strip(), "Answer must contain only non-empty strings" assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_graph_completion_cot_context_complex(self): async def test_graph_completion_cot_context_complex(self):
@ -133,8 +135,10 @@ class TestGraphCompletionCoTRetriever:
answer = await retriever.get_completion("Who works at Figma?") answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert answer.strip(), "Answer must contain only non-empty strings" assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_graph_completion_cot_context_on_empty_graph(self): async def test_get_graph_completion_cot_context_on_empty_graph(self):
@ -164,5 +168,7 @@ class TestGraphCompletionCoTRetriever:
answer = await retriever.get_completion("Who works at Figma?") answer = await retriever.get_completion("Who works at Figma?")
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}" assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
assert answer.strip(), "Answer must contain only non-empty strings" assert all(isinstance(item, str) and item.strip() for item in answer), (
"Answer must contain only non-empty strings"
)

View file

@ -82,7 +82,7 @@ class TestInsightsRetriever:
context = await retriever.get_context("Mike") context = await retriever.get_context("Mike")
assert context[0].node1.attributes["name"] == "Mike Broski", "Failed to get Mike Broski" assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_insights_context_complex(self): async def test_insights_context_complex(self):
@ -222,9 +222,7 @@ class TestInsightsRetriever:
context = await retriever.get_context("Christina") context = await retriever.get_context("Christina")
assert context[0].node1.attributes["name"] == "Christina Mayer", ( assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer"
"Failed to get Christina Mayer"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_insights_context_on_empty_graph(self): async def test_insights_context_on_empty_graph(self):

2
poetry.lock generated
View file

@ -11728,4 +11728,4 @@ posthog = ["posthog"]
[metadata] [metadata]
lock-version = "2.1" lock-version = "2.1"
python-versions = ">=3.10,<=3.13" python-versions = ">=3.10,<=3.13"
content-hash = "576318d370b89d128a7c3e755fe3c898fef4e359acdd3f05f952ae497751fb04" content-hash = "1e8cdbf6919cea9657d51b7839630dac7a0d8a2815eca0bd811838a282051625"

View file

@ -1,7 +1,7 @@
[project] [project]
name = "cognee" name = "cognee"
version = "0.3.0" version = "0.3.2"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning." description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = [ authors = [
{ name = "Vasilije Markovic" }, { name = "Vasilije Markovic" },

4
uv.lock generated
View file

@ -807,7 +807,7 @@ wheels = [
[[package]] [[package]]
name = "cognee" name = "cognee"
version = "0.2.4" version = "0.3.2"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "aiofiles" }, { name = "aiofiles" },
@ -1029,7 +1029,7 @@ requires-dist = [
{ name = "pylance", specifier = ">=0.22.0,<1.0.0" }, { name = "pylance", specifier = ">=0.22.0,<1.0.0" },
{ name = "pylint", marker = "extra == 'dev'", specifier = ">=3.0.3,<4" }, { name = "pylint", marker = "extra == 'dev'", specifier = ">=3.0.3,<4" },
{ name = "pympler", specifier = ">=1.1,<2.0.0" }, { name = "pympler", specifier = ">=1.1,<2.0.0" },
{ name = "pypdf", specifier = ">=4.1.0,<6.0.0" }, { name = "pypdf", specifier = ">=4.1.0,<7.0.0" },
{ name = "pypika", marker = "extra == 'chromadb'", specifier = "==0.48.8" }, { name = "pypika", marker = "extra == 'chromadb'", specifier = "==0.48.8" },
{ name = "pyside6", marker = "extra == 'gui'", specifier = ">=6.8.3,<7" }, { name = "pyside6", marker = "extra == 'gui'", specifier = ">=6.8.3,<7" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0,<8" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0,<8" },