Merge branch 'dev' into feature/cog-2923-create-ci-test-for-fastembed
This commit is contained in:
commit
8d7738d713
29 changed files with 285 additions and 122 deletions
10
.github/workflows/cli_tests.yml
vendored
10
.github/workflows/cli_tests.yml
vendored
|
|
@ -60,7 +60,7 @@ jobs:
|
|||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Run CLI Unit Tests
|
||||
run: uv run pytest cognee/tests/unit/cli/ -v
|
||||
run: uv run pytest cognee/tests/cli_tests/cli_unit_tests/ -v
|
||||
|
||||
cli-integration-tests:
|
||||
name: CLI Integration Tests
|
||||
|
|
@ -87,7 +87,7 @@ jobs:
|
|||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Run CLI Integration Tests
|
||||
run: uv run pytest cognee/tests/integration/cli/ -v
|
||||
run: uv run pytest cognee/tests/cli_tests/cli_integration_tests/ -v
|
||||
|
||||
cli-functionality-tests:
|
||||
name: CLI Functionality Tests
|
||||
|
|
@ -135,12 +135,12 @@ jobs:
|
|||
run: |
|
||||
# Test invalid command (should fail gracefully)
|
||||
! uv run python -m cognee.cli._cognee invalid_command
|
||||
|
||||
|
||||
# Test missing required arguments (should fail gracefully)
|
||||
! uv run python -m cognee.cli._cognee search
|
||||
|
||||
|
||||
# Test invalid search type (should fail gracefully)
|
||||
! uv run python -m cognee.cli._cognee search "test" --query-type INVALID_TYPE
|
||||
|
||||
|
||||
# Test invalid chunker (should fail gracefully)
|
||||
! uv run python -m cognee.cli._cognee cognify --chunker InvalidChunker
|
||||
|
|
|
|||
25
.github/workflows/examples_tests.yml
vendored
25
.github/workflows/examples_tests.yml
vendored
|
|
@ -106,3 +106,28 @@ jobs:
|
|||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./examples/python/dynamic_steps_example.py
|
||||
|
||||
test-memify:
|
||||
name: Run Memify Example
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Run Memify Tests
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./examples/python/memify_coding_agent_example.py
|
||||
|
|
|
|||
27
.github/workflows/test_suites.yml
vendored
27
.github/workflows/test_suites.yml
vendored
|
|
@ -34,49 +34,49 @@ jobs:
|
|||
|
||||
docker-compose-test:
|
||||
name: Docker Compose Test
|
||||
needs: [basic-tests, e2e-tests, cli-tests]
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/docker_compose.yml
|
||||
secrets: inherit
|
||||
|
||||
docker-ci-test:
|
||||
name: Docker CI test
|
||||
needs: [basic-tests, e2e-tests, cli-tests]
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/backend_docker_build_test.yml
|
||||
secrets: inherit
|
||||
|
||||
graph-db-tests:
|
||||
name: Graph Database Tests
|
||||
needs: [basic-tests, e2e-tests, cli-tests]
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/graph_db_tests.yml
|
||||
secrets: inherit
|
||||
|
||||
temporal-graph-tests:
|
||||
name: Temporal Graph Test
|
||||
needs: [ basic-tests, e2e-tests, cli-tests, graph-db-tests ]
|
||||
needs: [ basic-tests, e2e-tests, graph-db-tests ]
|
||||
uses: ./.github/workflows/temporal_graph_tests.yml
|
||||
secrets: inherit
|
||||
|
||||
search-db-tests:
|
||||
name: Search Test on Different DBs
|
||||
needs: [basic-tests, e2e-tests, cli-tests, graph-db-tests]
|
||||
needs: [basic-tests, e2e-tests, graph-db-tests]
|
||||
uses: ./.github/workflows/search_db_tests.yml
|
||||
secrets: inherit
|
||||
|
||||
relational-db-migration-tests:
|
||||
name: Relational DB Migration Tests
|
||||
needs: [basic-tests, e2e-tests, cli-tests, graph-db-tests]
|
||||
needs: [basic-tests, e2e-tests, graph-db-tests]
|
||||
uses: ./.github/workflows/relational_db_migration_tests.yml
|
||||
secrets: inherit
|
||||
|
||||
notebook-tests:
|
||||
name: Notebook Tests
|
||||
needs: [basic-tests, e2e-tests, cli-tests]
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/notebooks_tests.yml
|
||||
secrets: inherit
|
||||
|
||||
different-operating-systems-tests:
|
||||
name: Operating System and Python Tests
|
||||
needs: [basic-tests, e2e-tests, cli-tests]
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/test_different_operating_systems.yml
|
||||
with:
|
||||
python-versions: '["3.10.x", "3.11.x", "3.12.x"]'
|
||||
|
|
@ -85,20 +85,20 @@ jobs:
|
|||
# Matrix-based vector database tests
|
||||
vector-db-tests:
|
||||
name: Vector DB Tests
|
||||
needs: [basic-tests, e2e-tests, cli-tests]
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/vector_db_tests.yml
|
||||
secrets: inherit
|
||||
|
||||
# Matrix-based example tests
|
||||
example-tests:
|
||||
name: Example Tests
|
||||
needs: [basic-tests, e2e-tests, cli-tests]
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/examples_tests.yml
|
||||
secrets: inherit
|
||||
|
||||
mcp-test:
|
||||
name: MCP Tests
|
||||
needs: [basic-tests, e2e-tests, cli-tests]
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/test_mcp.yml
|
||||
secrets: inherit
|
||||
|
||||
|
|
@ -110,14 +110,14 @@ jobs:
|
|||
|
||||
s3-file-storage-test:
|
||||
name: S3 File Storage Test
|
||||
needs: [basic-tests, e2e-tests, cli-tests]
|
||||
needs: [basic-tests, e2e-tests]
|
||||
uses: ./.github/workflows/test_s3_file_storage.yml
|
||||
secrets: inherit
|
||||
|
||||
# Additional LLM tests
|
||||
llm-tests:
|
||||
name: LLM Test Suite
|
||||
needs: [ basic-tests, e2e-tests, cli-tests ]
|
||||
needs: [ basic-tests, e2e-tests ]
|
||||
uses: ./.github/workflows/test_llms.yml
|
||||
secrets: inherit
|
||||
|
||||
|
|
@ -127,7 +127,6 @@ jobs:
|
|||
needs: [
|
||||
basic-tests,
|
||||
e2e-tests,
|
||||
cli-tests,
|
||||
graph-db-tests,
|
||||
notebook-tests,
|
||||
different-operating-systems-tests,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
from uuid import UUID
|
||||
from typing import Optional
|
||||
from typing import Optional, Union, List, Any
|
||||
from datetime import datetime
|
||||
from pydantic import Field
|
||||
from fastapi import Depends, APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.search.types import SearchType, SearchResult, CombinedSearchResult
|
||||
from cognee.api.DTO import InDTO, OutDTO
|
||||
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@ -73,7 +73,7 @@ def get_search_router() -> APIRouter:
|
|||
except Exception as error:
|
||||
return JSONResponse(status_code=500, content={"error": str(error)})
|
||||
|
||||
@router.post("", response_model=list)
|
||||
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
|
||||
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||
"""
|
||||
Search for nodes in the graph database.
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import webbrowser
|
|||
import zipfile
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Callable, Optional, Tuple
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
|
|
@ -326,38 +326,93 @@ def prompt_user_for_download() -> bool:
|
|||
|
||||
|
||||
def start_ui(
|
||||
pid_callback: Callable[[int], None],
|
||||
host: str = "localhost",
|
||||
port: int = 3001,
|
||||
port: int = 3000,
|
||||
open_browser: bool = True,
|
||||
auto_download: bool = False,
|
||||
start_backend: bool = False,
|
||||
backend_host: str = "localhost",
|
||||
backend_port: int = 8000,
|
||||
) -> Optional[subprocess.Popen]:
|
||||
"""
|
||||
Start the cognee frontend UI server.
|
||||
Start the cognee frontend UI server, optionally with the backend API server.
|
||||
|
||||
This function will:
|
||||
1. Find the cognee-frontend directory (development) or download it (pip install)
|
||||
2. Check if Node.js and npm are available (for development mode)
|
||||
3. Install dependencies if needed (development mode)
|
||||
4. Start the appropriate server
|
||||
5. Optionally open the browser
|
||||
1. Optionally start the cognee backend API server
|
||||
2. Find the cognee-frontend directory (development) or download it (pip install)
|
||||
3. Check if Node.js and npm are available (for development mode)
|
||||
4. Install dependencies if needed (development mode)
|
||||
5. Start the frontend server
|
||||
6. Optionally open the browser
|
||||
|
||||
Args:
|
||||
host: Host to bind the server to (default: localhost)
|
||||
port: Port to run the server on (default: 3001)
|
||||
pid_callback: Callback to notify with PID of each spawned process
|
||||
host: Host to bind the frontend server to (default: localhost)
|
||||
port: Port to run the frontend server on (default: 3000)
|
||||
open_browser: Whether to open the browser automatically (default: True)
|
||||
auto_download: If True, download frontend without prompting (default: False)
|
||||
start_backend: If True, also start the cognee API backend server (default: False)
|
||||
backend_host: Host to bind the backend server to (default: localhost)
|
||||
backend_port: Port to run the backend server on (default: 8000)
|
||||
|
||||
Returns:
|
||||
subprocess.Popen object representing the running server, or None if failed
|
||||
subprocess.Popen object representing the running frontend server, or None if failed
|
||||
Note: If backend is started, it runs in a separate process that will be cleaned up
|
||||
when the frontend process is terminated.
|
||||
|
||||
Example:
|
||||
>>> import cognee
|
||||
>>> # Start just the frontend
|
||||
>>> server = cognee.start_ui()
|
||||
>>> # UI will be available at http://localhost:3001
|
||||
>>> # To stop the server later:
|
||||
>>>
|
||||
>>> # Start both frontend and backend
|
||||
>>> server = cognee.start_ui(start_backend=True)
|
||||
>>> # UI will be available at http://localhost:3000
|
||||
>>> # API will be available at http://localhost:8000
|
||||
>>> # To stop both servers later:
|
||||
>>> server.terminate()
|
||||
"""
|
||||
logger.info("Starting cognee UI...")
|
||||
backend_process = None
|
||||
|
||||
# Start backend server if requested
|
||||
if start_backend:
|
||||
logger.info("Starting cognee backend API server...")
|
||||
try:
|
||||
import sys
|
||||
|
||||
backend_process = subprocess.Popen(
|
||||
[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"uvicorn",
|
||||
"cognee.api.client:app",
|
||||
"--host",
|
||||
backend_host,
|
||||
"--port",
|
||||
str(backend_port),
|
||||
],
|
||||
# Inherit stdout/stderr from parent process to show logs
|
||||
stdout=None,
|
||||
stderr=None,
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
|
||||
pid_callback(backend_process.pid)
|
||||
|
||||
# Give the backend a moment to start
|
||||
time.sleep(2)
|
||||
|
||||
if backend_process.poll() is not None:
|
||||
logger.error("Backend server failed to start - process exited early")
|
||||
return None
|
||||
|
||||
logger.info(f"✓ Backend API started at http://{backend_host}:{backend_port}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start backend server: {str(e)}")
|
||||
return None
|
||||
|
||||
# Find frontend directory
|
||||
frontend_path = find_frontend_path()
|
||||
|
|
@ -406,7 +461,7 @@ def start_ui(
|
|||
logger.info("This may take a moment to compile and start...")
|
||||
|
||||
try:
|
||||
# Use process group to ensure all child processes get terminated together
|
||||
# Create frontend in its own process group for clean termination
|
||||
process = subprocess.Popen(
|
||||
["npm", "run", "dev"],
|
||||
cwd=frontend_path,
|
||||
|
|
@ -414,11 +469,11 @@ def start_ui(
|
|||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
preexec_fn=os.setsid
|
||||
if hasattr(os, "setsid")
|
||||
else None, # Create new process group on Unix
|
||||
preexec_fn=os.setsid if hasattr(os, "setsid") else None,
|
||||
)
|
||||
|
||||
pid_callback(process.pid)
|
||||
|
||||
# Give it a moment to start up
|
||||
time.sleep(3)
|
||||
|
||||
|
|
@ -447,16 +502,32 @@ def start_ui(
|
|||
logger.info(f"✓ Open your browser to: http://{host}:{port}")
|
||||
logger.info("✓ The UI will be available once Next.js finishes compiling")
|
||||
|
||||
# Store backend process reference in the frontend process for cleanup
|
||||
if backend_process:
|
||||
process._cognee_backend_process = backend_process
|
||||
|
||||
return process
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start frontend server: {str(e)}")
|
||||
# Clean up backend process if it was started
|
||||
if backend_process:
|
||||
logger.info("Cleaning up backend process due to frontend failure...")
|
||||
try:
|
||||
backend_process.terminate()
|
||||
backend_process.wait(timeout=5)
|
||||
except (subprocess.TimeoutExpired, OSError, ProcessLookupError):
|
||||
try:
|
||||
backend_process.kill()
|
||||
backend_process.wait()
|
||||
except (OSError, ProcessLookupError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def stop_ui(process: subprocess.Popen) -> bool:
|
||||
"""
|
||||
Stop a running UI server process and all its children.
|
||||
Stop a running UI server process and backend process (if started), along with all their children.
|
||||
|
||||
Args:
|
||||
process: The subprocess.Popen object returned by start_ui()
|
||||
|
|
@ -467,7 +538,29 @@ def stop_ui(process: subprocess.Popen) -> bool:
|
|||
if not process:
|
||||
return False
|
||||
|
||||
success = True
|
||||
|
||||
try:
|
||||
# First, stop the backend process if it exists
|
||||
backend_process = getattr(process, "_cognee_backend_process", None)
|
||||
if backend_process:
|
||||
logger.info("Stopping backend server...")
|
||||
try:
|
||||
backend_process.terminate()
|
||||
try:
|
||||
backend_process.wait(timeout=5)
|
||||
logger.info("Backend server stopped gracefully")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Backend didn't terminate gracefully, forcing kill")
|
||||
backend_process.kill()
|
||||
backend_process.wait()
|
||||
logger.info("Backend server stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping backend server: {str(e)}")
|
||||
success = False
|
||||
|
||||
# Now stop the frontend process
|
||||
logger.info("Stopping frontend server...")
|
||||
# Try to terminate the process group (includes child processes like Next.js)
|
||||
if hasattr(os, "killpg"):
|
||||
try:
|
||||
|
|
@ -484,9 +577,9 @@ def stop_ui(process: subprocess.Popen) -> bool:
|
|||
|
||||
try:
|
||||
process.wait(timeout=10)
|
||||
logger.info("UI server stopped gracefully")
|
||||
logger.info("Frontend server stopped gracefully")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("Process didn't terminate gracefully, forcing kill")
|
||||
logger.warning("Frontend didn't terminate gracefully, forcing kill")
|
||||
|
||||
# Force kill the process group
|
||||
if hasattr(os, "killpg"):
|
||||
|
|
@ -502,11 +595,13 @@ def stop_ui(process: subprocess.Popen) -> bool:
|
|||
|
||||
process.wait()
|
||||
|
||||
logger.info("UI server stopped")
|
||||
return True
|
||||
if success:
|
||||
logger.info("UI servers stopped successfully")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping UI server: {str(e)}")
|
||||
logger.error(f"Error stopping UI servers: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -174,30 +174,23 @@ def main() -> int:
|
|||
|
||||
# Handle UI flag
|
||||
if hasattr(args, "start_ui") and args.start_ui:
|
||||
server_process = None
|
||||
spawned_pids = []
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
"""Handle Ctrl+C and other termination signals"""
|
||||
nonlocal server_process
|
||||
nonlocal spawned_pids
|
||||
fmt.echo("\nShutting down UI server...")
|
||||
if server_process:
|
||||
|
||||
for pid in spawned_pids:
|
||||
try:
|
||||
# Try graceful termination first
|
||||
server_process.terminate()
|
||||
try:
|
||||
server_process.wait(timeout=5)
|
||||
fmt.success("UI server stopped gracefully.")
|
||||
except subprocess.TimeoutExpired:
|
||||
# If graceful termination fails, force kill
|
||||
fmt.echo("Force stopping UI server...")
|
||||
server_process.kill()
|
||||
server_process.wait()
|
||||
fmt.success("UI server stopped.")
|
||||
except Exception as e:
|
||||
fmt.warning(f"Error stopping server: {e}")
|
||||
pgid = os.getpgid(pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
|
||||
except (OSError, ProcessLookupError) as e:
|
||||
fmt.warning(f"Could not terminate process {pid}: {e}")
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
# Set up signal handlers
|
||||
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
||||
signal.signal(signal.SIGTERM, signal_handler) # Termination request
|
||||
|
||||
|
|
@ -205,11 +198,25 @@ def main() -> int:
|
|||
from cognee import start_ui
|
||||
|
||||
fmt.echo("Starting cognee UI...")
|
||||
server_process = start_ui(host="localhost", port=3001, open_browser=True)
|
||||
|
||||
# Callback to capture PIDs of all spawned processes
|
||||
def pid_callback(pid):
|
||||
nonlocal spawned_pids
|
||||
spawned_pids.append(pid)
|
||||
|
||||
server_process = start_ui(
|
||||
host="localhost",
|
||||
port=3000,
|
||||
open_browser=True,
|
||||
start_backend=True,
|
||||
auto_download=True,
|
||||
pid_callback=pid_callback,
|
||||
)
|
||||
|
||||
if server_process:
|
||||
fmt.success("UI server started successfully!")
|
||||
fmt.echo("The interface is available at: http://localhost:3001")
|
||||
fmt.echo("The interface is available at: http://localhost:3000")
|
||||
fmt.echo("The API backend is available at: http://localhost:8000")
|
||||
fmt.note("Press Ctrl+C to stop the server...")
|
||||
|
||||
try:
|
||||
|
|
@ -225,10 +232,12 @@ def main() -> int:
|
|||
return 0
|
||||
else:
|
||||
fmt.error("Failed to start UI server. Check the logs above for details.")
|
||||
signal_handler(signal.SIGTERM, None)
|
||||
return 1
|
||||
|
||||
except Exception as ex:
|
||||
fmt.error(f"Error starting UI: {str(ex)}")
|
||||
signal_handler(signal.SIGTERM, None)
|
||||
if debug.is_debug_enabled():
|
||||
raise ex
|
||||
return 1
|
||||
|
|
|
|||
|
|
@ -128,4 +128,4 @@ class GraphCompletionContextExtensionRetriever(GraphCompletionRetriever):
|
|||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -138,4 +138,4 @@ class GraphCompletionCotRetriever(GraphCompletionRetriever):
|
|||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|||
question=query, answer=completion, context=context_text, triplets=triplets
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
||||
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -96,17 +96,18 @@ class InsightsRetriever(BaseGraphRetriever):
|
|||
unique_node_connections_map[unique_id] = True
|
||||
unique_node_connections.append(node_connection)
|
||||
|
||||
return [
|
||||
Edge(
|
||||
node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
||||
node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
||||
attributes={
|
||||
**connection[1],
|
||||
"relationship_type": connection[1]["relationship_name"],
|
||||
},
|
||||
)
|
||||
for connection in unique_node_connections
|
||||
]
|
||||
return unique_node_connections
|
||||
# return [
|
||||
# Edge(
|
||||
# node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
||||
# node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
||||
# attributes={
|
||||
# **connection[1],
|
||||
# "relationship_type": connection[1]["relationship_name"],
|
||||
# },
|
||||
# )
|
||||
# for connection in unique_node_connections
|
||||
# ]
|
||||
|
||||
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -149,4 +149,4 @@ class TemporalRetriever(GraphCompletionRetriever):
|
|||
system_prompt_path=self.system_prompt_path,
|
||||
)
|
||||
|
||||
return completion
|
||||
return [completion]
|
||||
|
|
|
|||
|
|
@ -132,14 +132,37 @@ async def search(
|
|||
],
|
||||
)
|
||||
else:
|
||||
return [
|
||||
SearchResult(
|
||||
search_result=result,
|
||||
dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
|
||||
dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
|
||||
)
|
||||
for index, (result, _, datasets) in enumerate(search_results)
|
||||
]
|
||||
# This is for maintaining backwards compatibility
|
||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
return_value.append(
|
||||
{
|
||||
"search_result": result,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
}
|
||||
)
|
||||
return return_value
|
||||
else:
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
return_value.append(result)
|
||||
# For maintaining backwards compatibility
|
||||
if len(return_value) == 1 and isinstance(return_value[0], list):
|
||||
return return_value[0]
|
||||
else:
|
||||
return return_value
|
||||
# return [
|
||||
# SearchResult(
|
||||
# search_result=result,
|
||||
# dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
|
||||
# dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
|
||||
# )
|
||||
# for index, (result, _, datasets) in enumerate(search_results)
|
||||
# ]
|
||||
|
||||
|
||||
async def authorized_search(
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ async def main():
|
|||
print("\n\nExtracted sentences are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
assert search_results[0].dataset_name == "NLP", (
|
||||
assert search_results[0]["dataset_name"] == "NLP", (
|
||||
f"Dict must contain dataset name 'NLP': {search_results[0]}"
|
||||
)
|
||||
|
||||
|
|
@ -93,7 +93,7 @@ async def main():
|
|||
print("\n\nExtracted sentences are:\n")
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
assert search_results[0].dataset_name == "QUANTUM", (
|
||||
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
||||
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
||||
)
|
||||
|
||||
|
|
@ -170,7 +170,7 @@ async def main():
|
|||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
assert search_results[0].dataset_name == "QUANTUM", (
|
||||
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
||||
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -45,15 +45,13 @@ async def relational_db_migration():
|
|||
await migrate_relational_database(graph_engine, schema=schema)
|
||||
|
||||
# 1. Search the graph
|
||||
search_results: List[SearchResult] = await cognee.search(
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC"
|
||||
) # type: ignore
|
||||
)
|
||||
print("Search results:", search_results)
|
||||
|
||||
# 2. Assert that the search results contain "AC/DC"
|
||||
assert any("AC/DC" in r.search_result for r in search_results), (
|
||||
"AC/DC not found in search results!"
|
||||
)
|
||||
assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!"
|
||||
|
||||
migration_db_provider = migration_engine.engine.dialect.name
|
||||
if migration_db_provider == "postgresql":
|
||||
|
|
|
|||
|
|
@ -144,13 +144,16 @@ async def main():
|
|||
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
|
||||
("GRAPH_SUMMARY_COMPLETION", completion_sum),
|
||||
]:
|
||||
for search_result in search_results:
|
||||
completion = search_result.search_result
|
||||
assert isinstance(completion, str), f"{name}: should return a string"
|
||||
assert completion.strip(), f"{name}: string should not be empty"
|
||||
assert "netherlands" in completion.lower(), (
|
||||
f"{name}: expected 'netherlands' in result, got: {completion!r}"
|
||||
)
|
||||
assert isinstance(search_results, list), f"{name}: should return a list"
|
||||
assert len(search_results) == 1, (
|
||||
f"{name}: expected single-element list, got {len(search_results)}"
|
||||
)
|
||||
text = search_results[0]
|
||||
assert isinstance(text, str), f"{name}: element should be a string"
|
||||
assert text.strip(), f"{name}: string should not be empty"
|
||||
assert "netherlands" in text.lower(), (
|
||||
f"{name}: expected 'netherlands' in result, got: {text!r}"
|
||||
)
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
graph = await graph_engine.get_graph_data()
|
||||
|
|
|
|||
|
|
@ -59,8 +59,10 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Canva?")
|
||||
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_extension_context_complex(self):
|
||||
|
|
@ -140,8 +142,10 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_completion_extension_context_on_empty_graph(self):
|
||||
|
|
@ -171,5 +175,7 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -55,8 +55,10 @@ class TestGraphCompletionCoTRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Canva?")
|
||||
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_completion_cot_context_complex(self):
|
||||
|
|
@ -133,8 +135,10 @@ class TestGraphCompletionCoTRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_graph_completion_cot_context_on_empty_graph(self):
|
||||
|
|
@ -164,5 +168,7 @@ class TestGraphCompletionCoTRetriever:
|
|||
|
||||
answer = await retriever.get_completion("Who works at Figma?")
|
||||
|
||||
assert isinstance(answer, str), f"Expected string, got {type(answer).__name__}"
|
||||
assert answer.strip(), "Answer must contain only non-empty strings"
|
||||
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
||||
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
||||
"Answer must contain only non-empty strings"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ class TestInsightsRetriever:
|
|||
|
||||
context = await retriever.get_context("Mike")
|
||||
|
||||
assert context[0].node1.attributes["name"] == "Mike Broski", "Failed to get Mike Broski"
|
||||
assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insights_context_complex(self):
|
||||
|
|
@ -222,9 +222,7 @@ class TestInsightsRetriever:
|
|||
|
||||
context = await retriever.get_context("Christina")
|
||||
|
||||
assert context[0].node1.attributes["name"] == "Christina Mayer", (
|
||||
"Failed to get Christina Mayer"
|
||||
)
|
||||
assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insights_context_on_empty_graph(self):
|
||||
|
|
|
|||
2
poetry.lock
generated
2
poetry.lock
generated
|
|
@ -11728,4 +11728,4 @@ posthog = ["posthog"]
|
|||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<=3.13"
|
||||
content-hash = "576318d370b89d128a7c3e755fe3c898fef4e359acdd3f05f952ae497751fb04"
|
||||
content-hash = "1e8cdbf6919cea9657d51b7839630dac7a0d8a2815eca0bd811838a282051625"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "cognee"
|
||||
|
||||
version = "0.3.0"
|
||||
version = "0.3.2"
|
||||
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
|
||||
authors = [
|
||||
{ name = "Vasilije Markovic" },
|
||||
|
|
|
|||
4
uv.lock
generated
4
uv.lock
generated
|
|
@ -807,7 +807,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "cognee"
|
||||
version = "0.2.4"
|
||||
version = "0.3.2"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
|
|
@ -1029,7 +1029,7 @@ requires-dist = [
|
|||
{ name = "pylance", specifier = ">=0.22.0,<1.0.0" },
|
||||
{ name = "pylint", marker = "extra == 'dev'", specifier = ">=3.0.3,<4" },
|
||||
{ name = "pympler", specifier = ">=1.1,<2.0.0" },
|
||||
{ name = "pypdf", specifier = ">=4.1.0,<6.0.0" },
|
||||
{ name = "pypdf", specifier = ">=4.1.0,<7.0.0" },
|
||||
{ name = "pypika", marker = "extra == 'chromadb'", specifier = "==0.48.8" },
|
||||
{ name = "pyside6", marker = "extra == 'gui'", specifier = ">=6.8.3,<7" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0,<8" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue