Merge branch 'dev' into feature/cog-3698-enable-batch-queries-in-all-graph-completion-retrievers
This commit is contained in:
commit
ab7b5d5445
8 changed files with 97 additions and 37 deletions
10
CLAUDE.md
10
CLAUDE.md
|
|
@ -427,10 +427,12 @@ git checkout -b feature/your-feature-name
|
||||||
|
|
||||||
## Code Style
|
## Code Style
|
||||||
|
|
||||||
- Ruff for linting and formatting (configured in `pyproject.toml`)
|
- **Formatter**: Ruff (configured in `pyproject.toml`)
|
||||||
- Line length: 100 characters
|
- **Line length**: 100 characters
|
||||||
- Pre-commit hooks run ruff automatically
|
- **String quotes**: Use double quotes `"` not single quotes `'` (enforced by ruff-format)
|
||||||
- Type hints encouraged (mypy checks enabled)
|
- **Pre-commit hooks**: Run ruff linting and formatting automatically
|
||||||
|
- **Type hints**: Encouraged (mypy checks enabled)
|
||||||
|
- **Important**: Always run `pre-commit run --all-files` before committing to catch formatting issues
|
||||||
|
|
||||||
## Testing Strategy
|
## Testing Strategy
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -252,7 +252,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
||||||
chunk_size: int = None,
|
chunk_size: int = None,
|
||||||
config: Config = None,
|
config: Config = None,
|
||||||
custom_prompt: Optional[str] = None,
|
custom_prompt: Optional[str] = None,
|
||||||
chunks_per_batch: int = 100,
|
chunks_per_batch: int = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list[Task]:
|
) -> list[Task]:
|
||||||
if config is None:
|
if config is None:
|
||||||
|
|
@ -272,12 +272,14 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
||||||
"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
|
"ontology_config": {"ontology_resolver": get_default_ontology_resolver()}
|
||||||
}
|
}
|
||||||
|
|
||||||
if chunks_per_batch is None:
|
|
||||||
chunks_per_batch = 100
|
|
||||||
|
|
||||||
cognify_config = get_cognify_config()
|
cognify_config = get_cognify_config()
|
||||||
embed_triplets = cognify_config.triplet_embedding
|
embed_triplets = cognify_config.triplet_embedding
|
||||||
|
|
||||||
|
if chunks_per_batch is None:
|
||||||
|
chunks_per_batch = (
|
||||||
|
cognify_config.chunks_per_batch if cognify_config.chunks_per_batch is not None else 100
|
||||||
|
)
|
||||||
|
|
||||||
default_tasks = [
|
default_tasks = [
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
Task(
|
Task(
|
||||||
|
|
@ -308,7 +310,7 @@ async def get_default_tasks( # TODO: Find out a better way to do this (Boris's
|
||||||
|
|
||||||
|
|
||||||
async def get_temporal_tasks(
|
async def get_temporal_tasks(
|
||||||
user: User = None, chunker=TextChunker, chunk_size: int = None, chunks_per_batch: int = 10
|
user: User = None, chunker=TextChunker, chunk_size: int = None, chunks_per_batch: int = None
|
||||||
) -> list[Task]:
|
) -> list[Task]:
|
||||||
"""
|
"""
|
||||||
Builds and returns a list of temporal processing tasks to be executed in sequence.
|
Builds and returns a list of temporal processing tasks to be executed in sequence.
|
||||||
|
|
@ -330,7 +332,10 @@ async def get_temporal_tasks(
|
||||||
list[Task]: A list of Task objects representing the temporal processing pipeline.
|
list[Task]: A list of Task objects representing the temporal processing pipeline.
|
||||||
"""
|
"""
|
||||||
if chunks_per_batch is None:
|
if chunks_per_batch is None:
|
||||||
chunks_per_batch = 10
|
from cognee.modules.cognify.config import get_cognify_config
|
||||||
|
|
||||||
|
configured = get_cognify_config().chunks_per_batch
|
||||||
|
chunks_per_batch = configured if configured is not None else 10
|
||||||
|
|
||||||
temporal_tasks = [
|
temporal_tasks = [
|
||||||
Task(classify_documents),
|
Task(classify_documents),
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,11 @@ class CognifyPayloadDTO(InDTO):
|
||||||
examples=[[]],
|
examples=[[]],
|
||||||
description="Reference to one or more previously uploaded ontologies",
|
description="Reference to one or more previously uploaded ontologies",
|
||||||
)
|
)
|
||||||
|
chunks_per_batch: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Number of chunks to process per task batch in Cognify (overrides default).",
|
||||||
|
examples=[10, 20, 50, 100],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_cognify_router() -> APIRouter:
|
def get_cognify_router() -> APIRouter:
|
||||||
|
|
@ -146,6 +151,7 @@ def get_cognify_router() -> APIRouter:
|
||||||
config=config_to_use,
|
config=config_to_use,
|
||||||
run_in_background=payload.run_in_background,
|
run_in_background=payload.run_in_background,
|
||||||
custom_prompt=payload.custom_prompt,
|
custom_prompt=payload.custom_prompt,
|
||||||
|
chunks_per_batch=payload.chunks_per_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If any cognify run errored return JSONResponse with proper error status code
|
# If any cognify run errored return JSONResponse with proper error status code
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,11 @@ After successful cognify processing, use `cognee search` to query the knowledge
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--verbose", "-v", action="store_true", help="Show detailed progress information"
|
"--verbose", "-v", action="store_true", help="Show detailed progress information"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--chunks-per-batch",
|
||||||
|
type=int,
|
||||||
|
help="Number of chunks to process per task batch (try 50 for large single documents).",
|
||||||
|
)
|
||||||
|
|
||||||
def execute(self, args: argparse.Namespace) -> None:
|
def execute(self, args: argparse.Namespace) -> None:
|
||||||
try:
|
try:
|
||||||
|
|
@ -111,6 +116,7 @@ After successful cognify processing, use `cognee search` to query the knowledge
|
||||||
chunk_size=args.chunk_size,
|
chunk_size=args.chunk_size,
|
||||||
ontology_file_path=args.ontology_file,
|
ontology_file_path=args.ontology_file,
|
||||||
run_in_background=args.background,
|
run_in_background=args.background,
|
||||||
|
chunks_per_batch=getattr(args, "chunks_per_batch", None),
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
import os
|
import os
|
||||||
|
import aiohttp
|
||||||
import asyncio
|
import asyncio
|
||||||
import requests
|
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
|
from aiohttp import BasicAuth
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph import get_graph_config
|
from cognee.infrastructure.databases.graph import get_graph_config
|
||||||
from cognee.modules.users.models import User, DatasetDatabase
|
from cognee.modules.users.models import User, DatasetDatabase
|
||||||
|
|
@ -23,7 +25,6 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
|
||||||
Quality of life improvements:
|
Quality of life improvements:
|
||||||
- Allow configuration of different Neo4j Aura plans and regions.
|
- Allow configuration of different Neo4j Aura plans and regions.
|
||||||
- Requests should be made async, currently a blocking requests library is used.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -49,6 +50,7 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
graph_db_name = f"{dataset_id}"
|
graph_db_name = f"{dataset_id}"
|
||||||
|
|
||||||
# Client credentials and encryption
|
# Client credentials and encryption
|
||||||
|
# Note: Should not be used as class variables so that they are not persisted in memory longer than needed
|
||||||
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
|
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
|
||||||
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
|
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
|
||||||
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
|
tenant_id = os.environ.get("NEO4J_TENANT_ID", None)
|
||||||
|
|
@ -63,22 +65,13 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
"NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
|
"NEO4J_CLIENT_ID, NEO4J_CLIENT_SECRET, and NEO4J_TENANT_ID environment variables must be set to use Neo4j Aura DatasetDatabase Handling."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make the request with HTTP Basic Auth
|
resp_token = await cls._get_aura_token(client_id, client_secret)
|
||||||
def get_aura_token(client_id: str, client_secret: str) -> dict:
|
|
||||||
url = "https://api.neo4j.io/oauth/token"
|
|
||||||
data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded
|
|
||||||
|
|
||||||
resp = requests.post(url, data=data, auth=(client_id, client_secret))
|
|
||||||
resp.raise_for_status() # raises if the request failed
|
|
||||||
return resp.json()
|
|
||||||
|
|
||||||
resp = get_aura_token(client_id, client_secret)
|
|
||||||
|
|
||||||
url = "https://api.neo4j.io/v1/instances"
|
url = "https://api.neo4j.io/v1/instances"
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"accept": "application/json",
|
"accept": "application/json",
|
||||||
"Authorization": f"Bearer {resp['access_token']}",
|
"Authorization": f"Bearer {resp_token['access_token']}",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -96,31 +89,38 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
"cloud_provider": "gcp",
|
"cloud_provider": "gcp",
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post(url, headers=headers, json=payload)
|
async def _create_database_instance_request():
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(url, headers=headers, json=payload) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.json()
|
||||||
|
|
||||||
|
resp_create = await _create_database_instance_request()
|
||||||
|
|
||||||
graph_db_name = "neo4j" # Has to be 'neo4j' for Aura
|
graph_db_name = "neo4j" # Has to be 'neo4j' for Aura
|
||||||
graph_db_url = response.json()["data"]["connection_url"]
|
graph_db_url = resp_create["data"]["connection_url"]
|
||||||
graph_db_key = resp["access_token"]
|
graph_db_key = resp_token["access_token"]
|
||||||
graph_db_username = response.json()["data"]["username"]
|
graph_db_username = resp_create["data"]["username"]
|
||||||
graph_db_password = response.json()["data"]["password"]
|
graph_db_password = resp_create["data"]["password"]
|
||||||
|
|
||||||
async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
|
async def _wait_for_neo4j_instance_provisioning(instance_id: str, headers: dict):
|
||||||
# Poll until the instance is running
|
# Poll until the instance is running
|
||||||
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
status_url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
||||||
status = ""
|
status = ""
|
||||||
for attempt in range(30): # Try for up to ~5 minutes
|
for attempt in range(30): # Try for up to ~5 minutes
|
||||||
status_resp = requests.get(
|
async with aiohttp.ClientSession() as session:
|
||||||
status_url, headers=headers
|
async with session.get(status_url, headers=headers) as resp:
|
||||||
) # TODO: Use async requests with httpx
|
resp.raise_for_status()
|
||||||
status = status_resp.json()["data"]["status"]
|
status_resp = await resp.json()
|
||||||
if status.lower() == "running":
|
status = status_resp["data"]["status"]
|
||||||
return
|
if status.lower() == "running":
|
||||||
await asyncio.sleep(10)
|
return
|
||||||
|
await asyncio.sleep(10)
|
||||||
raise TimeoutError(
|
raise TimeoutError(
|
||||||
f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
|
f"Neo4j instance '{graph_db_name}' did not become ready within 5 minutes. Status: {status}"
|
||||||
)
|
)
|
||||||
|
|
||||||
instance_id = response.json()["data"]["id"]
|
instance_id = resp_create["data"]["id"]
|
||||||
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
|
await _wait_for_neo4j_instance_provisioning(instance_id, headers)
|
||||||
|
|
||||||
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
|
encrypted_db_password_bytes = cipher.encrypt(graph_db_password.encode())
|
||||||
|
|
@ -165,4 +165,39 @@ class Neo4jAuraDevDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def delete_dataset(cls, dataset_database: DatasetDatabase):
|
async def delete_dataset(cls, dataset_database: DatasetDatabase):
|
||||||
pass
|
# Get dataset database information and credentials
|
||||||
|
dataset_database = await cls.resolve_dataset_connection_info(dataset_database)
|
||||||
|
|
||||||
|
parsed_url = urlparse(dataset_database.graph_database_url)
|
||||||
|
instance_id = parsed_url.hostname.split(".")[0]
|
||||||
|
|
||||||
|
url = f"https://api.neo4j.io/v1/instances/{instance_id}"
|
||||||
|
|
||||||
|
# Get access token for Neo4j Aura API
|
||||||
|
# Client credentials
|
||||||
|
client_id = os.environ.get("NEO4J_CLIENT_ID", None)
|
||||||
|
client_secret = os.environ.get("NEO4J_CLIENT_SECRET", None)
|
||||||
|
resp = await cls._get_aura_token(client_id, client_secret)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"accept": "application/json",
|
||||||
|
"Authorization": f"Bearer {resp['access_token']}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.delete(url, headers=headers) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.json()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _get_aura_token(cls, client_id: str, client_secret: str) -> dict:
|
||||||
|
url = "https://api.neo4j.io/oauth/token"
|
||||||
|
data = {"grant_type": "client_credentials"} # sent as application/x-www-form-urlencoded
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
url, data=data, auth=BasicAuth(client_id, client_secret)
|
||||||
|
) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.json()
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ class CognifyConfig(BaseSettings):
|
||||||
classification_model: object = DefaultContentPrediction
|
classification_model: object = DefaultContentPrediction
|
||||||
summarization_model: object = SummarizedContent
|
summarization_model: object = SummarizedContent
|
||||||
triplet_embedding: bool = False
|
triplet_embedding: bool = False
|
||||||
|
chunks_per_batch: Optional[int] = None
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
|
|
@ -16,6 +17,7 @@ class CognifyConfig(BaseSettings):
|
||||||
"classification_model": self.classification_model,
|
"classification_model": self.classification_model,
|
||||||
"summarization_model": self.summarization_model,
|
"summarization_model": self.summarization_model,
|
||||||
"triplet_embedding": self.triplet_embedding,
|
"triplet_embedding": self.triplet_embedding,
|
||||||
|
"chunks_per_batch": self.chunks_per_batch,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -238,6 +238,7 @@ class TestCognifyCommand:
|
||||||
ontology_file_path=None,
|
ontology_file_path=None,
|
||||||
chunker=TextChunker,
|
chunker=TextChunker,
|
||||||
run_in_background=False,
|
run_in_background=False,
|
||||||
|
chunks_per_batch=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("cognee.cli.commands.cognify_command.asyncio.run")
|
@patch("cognee.cli.commands.cognify_command.asyncio.run")
|
||||||
|
|
|
||||||
|
|
@ -262,6 +262,7 @@ class TestCognifyCommandEdgeCases:
|
||||||
ontology_file_path=None,
|
ontology_file_path=None,
|
||||||
chunker=TextChunker,
|
chunker=TextChunker,
|
||||||
run_in_background=False,
|
run_in_background=False,
|
||||||
|
chunks_per_batch=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("cognee.cli.commands.cognify_command.asyncio.run", side_effect=_mock_run)
|
@patch("cognee.cli.commands.cognify_command.asyncio.run", side_effect=_mock_run)
|
||||||
|
|
@ -295,6 +296,7 @@ class TestCognifyCommandEdgeCases:
|
||||||
ontology_file_path="/nonexistent/path/ontology.owl",
|
ontology_file_path="/nonexistent/path/ontology.owl",
|
||||||
chunker=TextChunker,
|
chunker=TextChunker,
|
||||||
run_in_background=False,
|
run_in_background=False,
|
||||||
|
chunks_per_batch=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("cognee.cli.commands.cognify_command.asyncio.run")
|
@patch("cognee.cli.commands.cognify_command.asyncio.run")
|
||||||
|
|
@ -373,6 +375,7 @@ class TestCognifyCommandEdgeCases:
|
||||||
ontology_file_path=None,
|
ontology_file_path=None,
|
||||||
chunker=TextChunker,
|
chunker=TextChunker,
|
||||||
run_in_background=False,
|
run_in_background=False,
|
||||||
|
chunks_per_batch=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue