Merge branch 'dev' of https://github.com/topoteretes/cognee

Commit d6da7a999b — 37 changed files with 1198 additions and 52 deletions
@@ -21,6 +21,10 @@ LLM_PROVIDER="openai"
 LLM_ENDPOINT=""
 LLM_API_VERSION=""
 LLM_MAX_TOKENS="16384"
+# Instructor's modes determine how structured data is requested and extracted from LLM responses.
+# You can change this type (i.e. mode) via this env variable.
+# Each LLM has its own default value, e.g. gpt-5 models use "json_schema_mode".
+LLM_INSTRUCTOR_MODE=""

 EMBEDDING_PROVIDER="openai"
 EMBEDDING_MODEL="openai/text-embedding-3-large"
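For orientation: the adapter changes further down resolve this value through instructor's string-valued `Mode` enum, so the env variable holds a mode's string value. A minimal sketch of that lookup (defaults taken from this diff):

```python
import instructor

# instructor.Mode is a string-valued Enum, so a value configured via
# LLM_INSTRUCTOR_MODE can be resolved by its string value:
mode = instructor.Mode("json_schema_mode")
assert mode is instructor.Mode.JSON_SCHEMA

# An empty LLM_INSTRUCTOR_MODE falls back to each adapter's
# default_instructor_mode (per this diff: "anthropic_tools" for Anthropic,
# "json_mode" for Ollama/Gemini/Generic, "mistral_tools" for Mistral).
```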
@@ -23,6 +23,7 @@ from cognee.api.v1.settings.routers import get_settings_router
 from cognee.api.v1.datasets.routers import get_datasets_router
 from cognee.api.v1.cognify.routers import get_code_pipeline_router, get_cognify_router
 from cognee.api.v1.search.routers import get_search_router
+from cognee.api.v1.ontologies.routers.get_ontology_router import get_ontology_router
 from cognee.api.v1.memify.routers import get_memify_router
 from cognee.api.v1.add.routers import get_add_router
 from cognee.api.v1.delete.routers import get_delete_router

@@ -263,6 +264,8 @@ app.include_router(

 app.include_router(get_datasets_router(), prefix="/api/v1/datasets", tags=["datasets"])

+app.include_router(get_ontology_router(), prefix="/api/v1/ontologies", tags=["ontologies"])
+
 app.include_router(get_settings_router(), prefix="/api/v1/settings", tags=["settings"])

 app.include_router(get_visualize_router(), prefix="/api/v1/visualize", tags=["visualize"])
@@ -41,6 +41,9 @@ class CognifyPayloadDTO(InDTO):
     custom_prompt: Optional[str] = Field(
         default="", description="Custom prompt for entity extraction and graph generation"
     )
+    ontology_key: Optional[List[str]] = Field(
+        default=None, description="Reference to one or more previously uploaded ontologies"
+    )


 def get_cognify_router() -> APIRouter:
@@ -68,6 +71,7 @@ def get_cognify_router() -> APIRouter:
     - **dataset_ids** (Optional[List[UUID]]): List of existing dataset UUIDs to process. UUIDs allow processing of datasets not owned by the user (if permitted).
     - **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking).
     - **custom_prompt** (Optional[str]): Custom prompt for entity extraction and graph generation. If provided, this prompt will be used instead of the default prompts for knowledge graph extraction.
+    - **ontology_key** (Optional[List[str]]): Reference to one or more previously uploaded ontology files to use for knowledge graph construction.

     ## Response
     - **Blocking execution**: Complete pipeline run information with entity counts, processing duration, and success/failure status
@@ -82,7 +86,8 @@ def get_cognify_router() -> APIRouter:
     {
         "datasets": ["research_papers", "documentation"],
         "run_in_background": false,
-        "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections."
+        "custom_prompt": "Extract entities focusing on technical concepts and their relationships. Identify key technologies, methodologies, and their interconnections.",
+        "ontology_key": ["medical_ontology_v1"]
     }
     ```
@@ -108,13 +113,35 @@ def get_cognify_router() -> APIRouter:
         )

         from cognee.api.v1.cognify import cognify as cognee_cognify
+        from cognee.api.v1.ontologies.ontologies import OntologyService

         try:
             datasets = payload.dataset_ids if payload.dataset_ids else payload.datasets
+            config_to_use = None
+
+            if payload.ontology_key:
+                ontology_service = OntologyService()
+                ontology_contents = ontology_service.get_ontology_contents(
+                    payload.ontology_key, user
+                )
+
+                from cognee.modules.ontology.ontology_config import Config
+                from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import (
+                    RDFLibOntologyResolver,
+                )
+                from io import StringIO
+
+                ontology_streams = [StringIO(content) for content in ontology_contents]
+                config_to_use: Config = {
+                    "ontology_config": {
+                        "ontology_resolver": RDFLibOntologyResolver(ontology_file=ontology_streams)
+                    }
+                }

             cognify_run = await cognee_cognify(
                 datasets,
                 user,
+                config=config_to_use,
                 run_in_background=payload.run_in_background,
                 custom_prompt=payload.custom_prompt,
            )
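In practice, a cognify request can now attach previously uploaded ontologies by key. A hypothetical client call, mirroring the docstring example above (base URL, dataset, and key names are illustrative):

```python
import requests

# Assumes the ontology was first uploaded via POST /api/v1/ontologies
# (see the new ontologies router below).
payload = {
    "datasets": ["research_papers"],
    "run_in_background": False,
    "custom_prompt": "Extract entities focusing on technical concepts.",
    "ontology_key": ["medical_ontology_v1"],
}

response = requests.post("http://127.0.0.1:8000/api/v1/cognify", json=payload)
response.raise_for_status()
print(response.json())
```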
cognee/api/v1/ontologies/__init__.py (new file, +4)
@@ -0,0 +1,4 @@
from .ontologies import OntologyService
from .routers.get_ontology_router import get_ontology_router

__all__ = ["OntologyService", "get_ontology_router"]
cognee/api/v1/ontologies/ontologies.py (new file, +183)
@@ -0,0 +1,183 @@
import os
import json
import tempfile
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional, List
from dataclasses import dataclass


@dataclass
class OntologyMetadata:
    ontology_key: str
    filename: str
    size_bytes: int
    uploaded_at: str
    description: Optional[str] = None


class OntologyService:
    def __init__(self):
        pass

    @property
    def base_dir(self) -> Path:
        return Path(tempfile.gettempdir()) / "ontologies"

    def _get_user_dir(self, user_id: str) -> Path:
        user_dir = self.base_dir / str(user_id)
        user_dir.mkdir(parents=True, exist_ok=True)
        return user_dir

    def _get_metadata_path(self, user_dir: Path) -> Path:
        return user_dir / "metadata.json"

    def _load_metadata(self, user_dir: Path) -> dict:
        metadata_path = self._get_metadata_path(user_dir)
        if metadata_path.exists():
            with open(metadata_path, "r") as f:
                return json.load(f)
        return {}

    def _save_metadata(self, user_dir: Path, metadata: dict):
        metadata_path = self._get_metadata_path(user_dir)
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

    async def upload_ontology(
        self, ontology_key: str, file, user, description: Optional[str] = None
    ) -> OntologyMetadata:
        if not file.filename.lower().endswith(".owl"):
            raise ValueError("File must be in .owl format")

        user_dir = self._get_user_dir(str(user.id))
        metadata = self._load_metadata(user_dir)

        if ontology_key in metadata:
            raise ValueError(f"Ontology key '{ontology_key}' already exists")

        content = await file.read()
        if len(content) > 10 * 1024 * 1024:
            raise ValueError("File size exceeds 10MB limit")

        file_path = user_dir / f"{ontology_key}.owl"
        with open(file_path, "wb") as f:
            f.write(content)

        ontology_metadata = {
            "filename": file.filename,
            "size_bytes": len(content),
            "uploaded_at": datetime.now(timezone.utc).isoformat(),
            "description": description,
        }
        metadata[ontology_key] = ontology_metadata
        self._save_metadata(user_dir, metadata)

        return OntologyMetadata(
            ontology_key=ontology_key,
            filename=file.filename,
            size_bytes=len(content),
            uploaded_at=ontology_metadata["uploaded_at"],
            description=description,
        )

    async def upload_ontologies(
        self, ontology_key: List[str], files: List, user, descriptions: Optional[List[str]] = None
    ) -> List[OntologyMetadata]:
        """
        Upload ontology files with their respective keys.

        Args:
            ontology_key: List of unique keys, one per ontology
            files: List of UploadFile objects (same length as keys)
            user: Authenticated user
            descriptions: Optional list of descriptions, one per file

        Returns:
            List of OntologyMetadata objects for the uploaded files

        Raises:
            ValueError: If keys are duplicated, a file format is invalid, or array lengths don't match
        """
        if len(ontology_key) != len(files):
            raise ValueError("Number of keys must match number of files")

        if len(set(ontology_key)) != len(ontology_key):
            raise ValueError("Duplicate ontology keys not allowed")

        if descriptions and len(descriptions) != len(files):
            raise ValueError("Number of descriptions must match number of files")

        results = []
        user_dir = self._get_user_dir(str(user.id))
        metadata = self._load_metadata(user_dir)

        for i, (key, file) in enumerate(zip(ontology_key, files)):
            if key in metadata:
                raise ValueError(f"Ontology key '{key}' already exists")

            if not file.filename.lower().endswith(".owl"):
                raise ValueError(f"File '{file.filename}' must be in .owl format")

            content = await file.read()
            if len(content) > 10 * 1024 * 1024:
                raise ValueError(f"File '{file.filename}' exceeds 10MB limit")

            file_path = user_dir / f"{key}.owl"
            with open(file_path, "wb") as f:
                f.write(content)

            ontology_metadata = {
                "filename": file.filename,
                "size_bytes": len(content),
                "uploaded_at": datetime.now(timezone.utc).isoformat(),
                "description": descriptions[i] if descriptions else None,
            }
            metadata[key] = ontology_metadata

            results.append(
                OntologyMetadata(
                    ontology_key=key,
                    filename=file.filename,
                    size_bytes=len(content),
                    uploaded_at=ontology_metadata["uploaded_at"],
                    description=descriptions[i] if descriptions else None,
                )
            )

        self._save_metadata(user_dir, metadata)
        return results

    def get_ontology_contents(self, ontology_key: List[str], user) -> List[str]:
        """
        Retrieve ontology content for one or more keys.

        Args:
            ontology_key: List of ontology keys to retrieve (can contain a single item)
            user: Authenticated user

        Returns:
            List of ontology content strings

        Raises:
            ValueError: If any ontology key is not found
        """
        user_dir = self._get_user_dir(str(user.id))
        metadata = self._load_metadata(user_dir)

        contents = []
        for key in ontology_key:
            if key not in metadata:
                raise ValueError(f"Ontology key '{key}' not found")

            file_path = user_dir / f"{key}.owl"
            if not file_path.exists():
                raise ValueError(f"Ontology file for key '{key}' not found")

            with open(file_path, "r", encoding="utf-8") as f:
                contents.append(f.read())
        return contents

    def list_ontologies(self, user) -> dict:
        user_dir = self._get_user_dir(str(user.id))
        return self._load_metadata(user_dir)
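Taken together, the service stores each user's `.owl` files under `<tempdir>/ontologies/<user_id>/` with a `metadata.json` index beside them. A minimal usage sketch outside FastAPI, with a stand-in upload object and a hypothetical user id:

```python
import asyncio
from types import SimpleNamespace

from cognee.api.v1.ontologies import OntologyService


class FakeUpload:
    """Stand-in for FastAPI's UploadFile (filename attribute + async read)."""

    def __init__(self, filename: str, content: bytes):
        self.filename = filename
        self._content = content

    async def read(self) -> bytes:
        return self._content


async def main():
    service = OntologyService()
    user = SimpleNamespace(id="00000000-0000-0000-0000-000000000001")  # hypothetical

    upload = FakeUpload("vehicles.owl", b"<rdf:RDF></rdf:RDF>")
    await service.upload_ontology("vehicles", upload, user, description="demo")

    print(service.list_ontologies(user))                       # {"vehicles": {...}}
    print(service.get_ontology_contents(["vehicles"], user))   # ["<rdf:RDF></rdf:RDF>"]


asyncio.run(main())
```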
cognee/api/v1/ontologies/routers/__init__.py (new file, +0)

cognee/api/v1/ontologies/routers/get_ontology_router.py (new file, +107)
@@ -0,0 +1,107 @@
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException
from fastapi.responses import JSONResponse
from typing import Optional, List

from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
from ..ontologies import OntologyService


def get_ontology_router() -> APIRouter:
    router = APIRouter()
    ontology_service = OntologyService()

    @router.post("", response_model=dict)
    async def upload_ontology(
        ontology_key: str = Form(...),
        ontology_file: List[UploadFile] = File(...),
        descriptions: Optional[str] = Form(None),
        user: User = Depends(get_authenticated_user),
    ):
        """
        Upload ontology files with their respective keys for later use in cognify operations.

        Supports both single and multiple file uploads:
        - Single file: ontology_key=["key"], ontology_file=[file]
        - Multiple files: ontology_key=["key1", "key2"], ontology_file=[file1, file2]

        ## Request Parameters
        - **ontology_key** (str): JSON array string of user-defined identifiers for the ontologies
        - **ontology_file** (List[UploadFile]): OWL format ontology files
        - **descriptions** (Optional[str]): JSON array string of optional descriptions

        ## Response
        Returns metadata about the uploaded ontologies, including keys, filenames, sizes, and upload timestamps.

        ## Error Codes
        - **400 Bad Request**: Invalid file format, duplicate keys, array length mismatch, or file size exceeded
        - **500 Internal Server Error**: File system or processing errors
        """
        send_telemetry(
            "Ontology Upload API Endpoint Invoked",
            user.id,
            additional_properties={
                "endpoint": "POST /api/v1/ontologies",
                "cognee_version": cognee_version,
            },
        )

        try:
            import json

            ontology_keys = json.loads(ontology_key)
            description_list = json.loads(descriptions) if descriptions else None

            if not isinstance(ontology_keys, list):
                raise ValueError("ontology_key must be a JSON array")

            results = await ontology_service.upload_ontologies(
                ontology_keys, ontology_file, user, description_list
            )

            return {
                "uploaded_ontologies": [
                    {
                        "ontology_key": result.ontology_key,
                        "filename": result.filename,
                        "size_bytes": result.size_bytes,
                        "uploaded_at": result.uploaded_at,
                        "description": result.description,
                    }
                    for result in results
                ]
            }
        except (json.JSONDecodeError, ValueError) as e:
            return JSONResponse(status_code=400, content={"error": str(e)})
        except Exception as e:
            return JSONResponse(status_code=500, content={"error": str(e)})

    @router.get("", response_model=dict)
    async def list_ontologies(user: User = Depends(get_authenticated_user)):
        """
        List all uploaded ontologies for the authenticated user.

        ## Response
        Returns a dictionary mapping ontology keys to their metadata, including filename, size, and upload timestamp.

        ## Error Codes
        - **500 Internal Server Error**: File system or processing errors
        """
        send_telemetry(
            "Ontology List API Endpoint Invoked",
            user.id,
            additional_properties={
                "endpoint": "GET /api/v1/ontologies",
                "cognee_version": cognee_version,
            },
        )

        try:
            metadata = ontology_service.list_ontologies(user)
            return metadata
        except Exception as e:
            return JSONResponse(status_code=500, content={"error": str(e)})

    return router
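Because `ontology_key` and `descriptions` arrive as form fields, clients must JSON-encode the arrays. A sketch with the requests library (URL and file names hypothetical, matching the unit tests at the end of this diff):

```python
import json
import requests

files = [
    ("ontology_file", ("vehicles.owl", open("vehicles.owl", "rb"), "application/xml")),
]
data = {
    "ontology_key": json.dumps(["vehicles"]),       # JSON array string, per the docstring
    "descriptions": json.dumps(["Base vehicles"]),  # optional, same length as files
}

response = requests.post("http://127.0.0.1:8000/api/v1/ontologies", files=files, data=data)
print(response.json()["uploaded_ontologies"][0]["ontology_key"])  # "vehicles"
```

One wrinkle worth noting: the router's form field is `descriptions` (plural), so the `description` field posted by the server integration test further down appears to be ignored as an unknown form field.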
@@ -22,7 +22,7 @@ relationships, and creates semantic connections for enhanced search and reasoning

 Processing Pipeline:
 1. **Document Classification**: Identifies document types and structures
-2. **Permission Validation**: Ensures user has processing rights
+2. **Permission Validation**: Ensures user has processing rights
 3. **Text Chunking**: Breaks content into semantically meaningful segments
 4. **Entity Extraction**: Identifies key concepts, people, places, organizations
 5. **Relationship Detection**: Discovers connections between entities
@@ -97,6 +97,13 @@ After successful cognify processing, use `cognee search` to query the knowledge
             chunker_class = LangchainChunker
         except ImportError:
             fmt.warning("LangchainChunker not available, using TextChunker")
+    elif args.chunker == "CsvChunker":
+        try:
+            from cognee.modules.chunking.CsvChunker import CsvChunker
+
+            chunker_class = CsvChunker
+        except ImportError:
+            fmt.warning("CsvChunker not available, using TextChunker")

     result = await cognee.cognify(
         datasets=datasets,
@@ -26,7 +26,7 @@ SEARCH_TYPE_CHOICES = [
 ]

 # Chunker choices
-CHUNKER_CHOICES = ["TextChunker", "LangchainChunker"]
+CHUNKER_CHOICES = ["TextChunker", "LangchainChunker", "CsvChunker"]

 # Output format choices
 OUTPUT_FORMAT_CHOICES = ["json", "pretty", "simple"]
@@ -55,6 +55,10 @@ def guess_file_type(file: BinaryIO, name: Optional[str] = None) -> filetype.Type
         file_type = Type("text/plain", "txt")
         return file_type

+    if ext in [".csv"]:
+        file_type = Type("text/csv", "csv")
+        return file_type
+
     file_type = filetype.guess(file)

     # If the file type could not be determined, consider it a plain text file, as plain text files don't have magic number encoding
@@ -38,6 +38,7 @@ class LLMConfig(BaseSettings):
     """

     structured_output_framework: str = "instructor"
+    llm_instructor_mode: str = ""
     llm_provider: str = "openai"
     llm_model: str = "openai/gpt-5-mini"
     llm_endpoint: str = ""

@@ -181,6 +182,7 @@ class LLMConfig(BaseSettings):
         instance.
         """
         return {
+            "llm_instructor_mode": self.llm_instructor_mode.lower(),
             "provider": self.llm_provider,
             "model": self.llm_model,
             "endpoint": self.llm_endpoint,
@@ -28,13 +28,16 @@ class AnthropicAdapter(LLMInterface):

     name = "Anthropic"
     model: str
+    default_instructor_mode = "anthropic_tools"

-    def __init__(self, max_completion_tokens: int, model: str = None):
+    def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
         import anthropic

+        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
         self.aclient = instructor.patch(
             create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
-            mode=instructor.Mode.ANTHROPIC_TOOLS,
+            mode=instructor.Mode(self.instructor_mode),
         )

         self.model = model
@@ -41,6 +41,7 @@ class GeminiAdapter(LLMInterface):
     name: str
     model: str
     api_key: str
+    default_instructor_mode = "json_mode"

     def __init__(
         self,

@@ -49,6 +50,7 @@ class GeminiAdapter(LLMInterface):
         model: str,
         api_version: str,
         max_completion_tokens: int,
+        instructor_mode: str = None,
         fallback_model: str = None,
         fallback_api_key: str = None,
         fallback_endpoint: str = None,

@@ -63,7 +65,11 @@ class GeminiAdapter(LLMInterface):
         self.fallback_api_key = fallback_api_key
         self.fallback_endpoint = fallback_endpoint

-        self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
+        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
+        self.aclient = instructor.from_litellm(
+            litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
+        )

     @retry(
         stop=stop_after_delay(128),
@@ -41,6 +41,7 @@ class GenericAPIAdapter(LLMInterface):
     name: str
     model: str
     api_key: str
+    default_instructor_mode = "json_mode"

     def __init__(
         self,

@@ -49,6 +50,7 @@ class GenericAPIAdapter(LLMInterface):
         model: str,
         name: str,
         max_completion_tokens: int,
+        instructor_mode: str = None,
         fallback_model: str = None,
         fallback_api_key: str = None,
         fallback_endpoint: str = None,

@@ -63,7 +65,11 @@ class GenericAPIAdapter(LLMInterface):
         self.fallback_api_key = fallback_api_key
         self.fallback_endpoint = fallback_endpoint

-        self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
+        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
+        self.aclient = instructor.from_litellm(
+            litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
+        )

     @retry(
         stop=stop_after_delay(128),
@@ -81,6 +81,7 @@ def get_llm_client(raise_api_key_error: bool = True):
             model=llm_config.llm_model,
             transcription_model=llm_config.transcription_model,
             max_completion_tokens=max_completion_tokens,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
             streaming=llm_config.llm_streaming,
             fallback_api_key=llm_config.fallback_api_key,
             fallback_endpoint=llm_config.fallback_endpoint,

@@ -101,6 +102,7 @@ def get_llm_client(raise_api_key_error: bool = True):
             llm_config.llm_model,
             "Ollama",
             max_completion_tokens=max_completion_tokens,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
         )

     elif provider == LLMProvider.ANTHROPIC:

@@ -109,7 +111,9 @@ def get_llm_client(raise_api_key_error: bool = True):
             )

         return AnthropicAdapter(
-            max_completion_tokens=max_completion_tokens, model=llm_config.llm_model
+            max_completion_tokens=max_completion_tokens,
+            model=llm_config.llm_model,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
         )

     elif provider == LLMProvider.CUSTOM:

@@ -126,6 +130,7 @@ def get_llm_client(raise_api_key_error: bool = True):
             llm_config.llm_model,
             "Custom",
             max_completion_tokens=max_completion_tokens,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
             fallback_api_key=llm_config.fallback_api_key,
             fallback_endpoint=llm_config.fallback_endpoint,
             fallback_model=llm_config.fallback_model,

@@ -145,6 +150,7 @@ def get_llm_client(raise_api_key_error: bool = True):
             max_completion_tokens=max_completion_tokens,
             endpoint=llm_config.llm_endpoint,
             api_version=llm_config.llm_api_version,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
         )

     elif provider == LLMProvider.MISTRAL:

@@ -160,6 +166,7 @@ def get_llm_client(raise_api_key_error: bool = True):
             model=llm_config.llm_model,
             max_completion_tokens=max_completion_tokens,
             endpoint=llm_config.llm_endpoint,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
         )

     else:
@@ -37,16 +37,26 @@ class MistralAdapter(LLMInterface):
     model: str
     api_key: str
     max_completion_tokens: int
+    default_instructor_mode = "mistral_tools"

-    def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
+    def __init__(
+        self,
+        api_key: str,
+        model: str,
+        max_completion_tokens: int,
+        endpoint: str = None,
+        instructor_mode: str = None,
+    ):
         from mistralai import Mistral

         self.model = model
         self.max_completion_tokens = max_completion_tokens
+
+        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode

         self.aclient = instructor.from_litellm(
             litellm.acompletion,
-            mode=instructor.Mode.MISTRAL_TOOLS,
+            mode=instructor.Mode(self.instructor_mode),
             api_key=get_llm_config().llm_api_key,
         )
@@ -42,8 +42,16 @@ class OllamaAPIAdapter(LLMInterface):
     - aclient
     """

+    default_instructor_mode = "json_mode"
+
     def __init__(
-        self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int
+        self,
+        endpoint: str,
+        api_key: str,
+        model: str,
+        name: str,
+        max_completion_tokens: int,
+        instructor_mode: str = None,
     ):
         self.name = name
         self.model = model

@@ -51,8 +59,11 @@ class OllamaAPIAdapter(LLMInterface):
         self.endpoint = endpoint
         self.max_completion_tokens = max_completion_tokens

+        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
         self.aclient = instructor.from_openai(
-            OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
+            OpenAI(base_url=self.endpoint, api_key=self.api_key),
+            mode=instructor.Mode(self.instructor_mode),
         )

     @retry(
@@ -56,6 +56,7 @@ class OpenAIAdapter(LLMInterface):
     model: str
     api_key: str
     api_version: str
+    default_instructor_mode = "json_schema_mode"

     MAX_RETRIES = 5


@@ -69,19 +70,21 @@ class OpenAIAdapter(LLMInterface):
         model: str,
         transcription_model: str,
         max_completion_tokens: int,
+        instructor_mode: str = None,
         streaming: bool = False,
         fallback_model: str = None,
         fallback_api_key: str = None,
         fallback_endpoint: str = None,
     ):
+        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
         # TODO: With gpt-5 series models OpenAI expects JSON_SCHEMA as the mode for structured outputs.
         # Make sure all new gpt models will work with this mode as well.
         if "gpt-5" in model:
             self.aclient = instructor.from_litellm(
-                litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
+                litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
             )
             self.client = instructor.from_litellm(
-                litellm.completion, mode=instructor.Mode.JSON_SCHEMA
+                litellm.completion, mode=instructor.Mode(self.instructor_mode)
             )
         else:
             self.aclient = instructor.from_litellm(litellm.acompletion)
@@ -31,6 +31,7 @@ class LoaderEngine:
             "pypdf_loader",
             "image_loader",
             "audio_loader",
+            "csv_loader",
             "unstructured_loader",
             "advanced_pdf_loader",
         ]
@@ -3,5 +3,6 @@
 from .text_loader import TextLoader
 from .audio_loader import AudioLoader
 from .image_loader import ImageLoader
+from .csv_loader import CsvLoader

-__all__ = ["TextLoader", "AudioLoader", "ImageLoader"]
+__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"]
cognee/infrastructure/loaders/core/csv_loader.py (new file, +93)
@@ -0,0 +1,93 @@
import os
from typing import List
import csv
from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata


class CsvLoader(LoaderInterface):
    """
    Core CSV file loader that handles basic CSV file formats.
    """

    @property
    def supported_extensions(self) -> List[str]:
        """Supported CSV file extensions."""
        return [
            "csv",
        ]

    @property
    def supported_mime_types(self) -> List[str]:
        """Supported MIME types for CSV content."""
        return [
            "text/csv",
        ]

    @property
    def loader_name(self) -> str:
        """Unique identifier for this loader."""
        return "csv_loader"

    def can_handle(self, extension: str, mime_type: str) -> bool:
        """
        Check if this loader can handle the given file.

        Args:
            extension: File extension
            mime_type: Optional MIME type

        Returns:
            True if the file can be handled, False otherwise
        """
        if extension in self.supported_extensions and mime_type in self.supported_mime_types:
            return True

        return False

    async def load(self, file_path: str, encoding: str = "utf-8", **kwargs):
        """
        Load and process the CSV file.

        Args:
            file_path: Path to the file to load
            encoding: Text encoding to use (default: utf-8)
            **kwargs: Additional configuration (unused)

        Returns:
            Full path of the stored text rendering of the file

        Raises:
            FileNotFoundError: If the file doesn't exist
            UnicodeDecodeError: If the file cannot be decoded with the specified encoding
            OSError: If the file cannot be read
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        with open(file_path, "rb") as f:
            file_metadata = await get_file_metadata(f)
            # Name the ingested file of the current loader based on the original file's content hash
            storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"

        row_texts = []
        row_index = 1

        with open(file_path, "r", encoding=encoding, newline="") as file:
            reader = csv.DictReader(file)
            for row in reader:
                pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
                row_text = ", ".join(pairs)
                row_texts.append(f"Row {row_index}:\n{row_text}\n")
                row_index += 1

        content = "\n".join(row_texts)

        storage_config = get_storage_config()
        data_root_directory = storage_config["data_root_directory"]
        storage = get_file_storage(data_root_directory)

        full_file_path = await storage.store(storage_file_name, content)

        return full_file_path
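A rough sketch of what the loader does: each CSV row becomes a `Row N:` block of `column: value` pairs, and the joined text is stored under a content-hash-derived name. Assuming cognee's storage configuration has been initialized, usage looks like:

```python
import asyncio

from cognee.infrastructure.loaders.core.csv_loader import CsvLoader

loader = CsvLoader()
assert loader.can_handle("csv", "text/csv")

# Returns the path of the stored "text_<content_hash>.txt" rendering;
# the test data file below is used as input here.
stored_path = asyncio.run(loader.load("cognee/tests/test_data/example_with_header.csv"))
print(stored_path)
```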
@@ -16,7 +16,7 @@ class TextLoader(LoaderInterface):
     @property
     def supported_extensions(self) -> List[str]:
         """Supported text file extensions."""
-        return ["txt", "md", "csv", "json", "xml", "yaml", "yml", "log"]
+        return ["txt", "md", "json", "xml", "yaml", "yml", "log"]

     @property
     def supported_mime_types(self) -> List[str]:

@@ -24,7 +24,6 @@ class TextLoader(LoaderInterface):
         return [
             "text/plain",
             "text/markdown",
-            "text/csv",
             "application/json",
             "text/xml",
             "application/xml",
@@ -227,12 +227,3 @@ class AdvancedPdfLoader(LoaderInterface):
         if value is None:
             return ""
         return str(value).replace("\xa0", " ").strip()
-
-
-if __name__ == "__main__":
-    loader = AdvancedPdfLoader()
-    asyncio.run(
-        loader.load(
-            "/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf"
-        )
-    )
@@ -1,5 +1,5 @@
 from cognee.infrastructure.loaders.external import PyPdfLoader
-from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader
+from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader

 # Registry for loader implementations
 supported_loaders = {

@@ -7,6 +7,7 @@ supported_loaders = {
     TextLoader.loader_name: TextLoader,
     ImageLoader.loader_name: ImageLoader,
     AudioLoader.loader_name: AudioLoader,
+    CsvLoader.loader_name: CsvLoader,
 }

 # Try adding optional loaders
cognee/modules/chunking/CsvChunker.py (new file, +35)
@@ -0,0 +1,35 @@
from cognee.shared.logging_utils import get_logger


from cognee.tasks.chunks import chunk_by_row
from cognee.modules.chunking.Chunker import Chunker
from .models.DocumentChunk import DocumentChunk

logger = get_logger()


class CsvChunker(Chunker):
    async def read(self):
        async for content_text in self.get_text():
            if content_text is None:
                continue

            for chunk_data in chunk_by_row(content_text, self.max_chunk_size):
                if chunk_data["chunk_size"] <= self.max_chunk_size:
                    yield DocumentChunk(
                        id=chunk_data["chunk_id"],
                        text=chunk_data["text"],
                        chunk_size=chunk_data["chunk_size"],
                        is_part_of=self.document,
                        chunk_index=self.chunk_index,
                        cut_type=chunk_data["cut_type"],
                        contains=[],
                        metadata={
                            "index_fields": ["text"],
                        },
                    )
                    self.chunk_index += 1
                else:
                    raise ValueError(
                        f"Chunk size is larger than the maximum chunk size {self.max_chunk_size}"
                    )
cognee/modules/data/processing/document_types/CsvDocument.py (new file, +33)
@@ -0,0 +1,33 @@
import io
import csv
from typing import Type

from cognee.modules.chunking.Chunker import Chunker
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from .Document import Document


class CsvDocument(Document):
    type: str = "csv"
    mime_type: str = "text/csv"

    async def read(self, chunker_cls: Type[Chunker], max_chunk_size: int):
        async def get_text():
            async with open_data_file(
                self.raw_data_location, mode="r", encoding="utf-8", newline=""
            ) as file:
                content = file.read()
                file_like_obj = io.StringIO(content)
                reader = csv.DictReader(file_like_obj)

                for row in reader:
                    pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
                    row_text = ", ".join(pairs)
                    if not row_text.strip():
                        break
                    yield row_text

        chunker = chunker_cls(self, max_chunk_size=max_chunk_size, get_text=get_text)

        async for chunk in chunker.read():
            yield chunk
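Putting the two pieces together (this mirrors the integration test further down), reading a CSV document row by row looks roughly like the following sketch; it assumes an embedding engine is configured, since token counting goes through chunk_by_row:

```python
import asyncio
import uuid

from cognee.modules.chunking.CsvChunker import CsvChunker
from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument


async def main():
    document = CsvDocument(
        id=uuid.uuid4(),
        name="example_with_header.csv",
        raw_data_location="cognee/tests/test_data/example_with_header.csv",
        external_metadata="",
        mime_type="text/csv",
    )
    # Each yielded DocumentChunk groups "column: value" pairs from a row,
    # up to max_chunk_size tokens (see chunk_by_row below).
    async for chunk in document.read(chunker_cls=CsvChunker, max_chunk_size=128):
        print(chunk.chunk_index, chunk.cut_type, chunk.text)


asyncio.run(main())
```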
@@ -4,3 +4,4 @@ from .TextDocument import TextDocument
 from .ImageDocument import ImageDocument
 from .AudioDocument import AudioDocument
 from .UnstructuredDocument import UnstructuredDocument
+from .CsvDocument import CsvDocument
@@ -2,6 +2,8 @@ import io
 import sys
 import traceback

+import cognee
+

 def wrap_in_async_handler(user_code: str) -> str:
     return (

@@ -34,6 +36,7 @@ def run_in_local_sandbox(code, environment=None, loop=None):

     environment["print"] = customPrintFunction
     environment["running_loop"] = loop
+    environment["cognee"] = cognee

     try:
         exec(code, environment)
@@ -2,7 +2,7 @@ import os
 import difflib
 from cognee.shared.logging_utils import get_logger
 from collections import deque
-from typing import List, Tuple, Dict, Optional, Any, Union
+from typing import List, Tuple, Dict, Optional, Any, Union, IO
 from rdflib import Graph, URIRef, RDF, RDFS, OWL

 from cognee.modules.ontology.exceptions import (
@@ -26,44 +26,76 @@ class RDFLibOntologyResolver(BaseOntologyResolver):

     def __init__(
         self,
-        ontology_file: Optional[Union[str, List[str]]] = None,
+        ontology_file: Optional[Union[str, List[str], IO, List[IO]]] = None,
         matching_strategy: Optional[MatchingStrategy] = None,
     ) -> None:
         super().__init__(matching_strategy)
         self.ontology_file = ontology_file
         try:
-            files_to_load = []
             self.graph = None
             if ontology_file is not None:
-                if isinstance(ontology_file, str):
+                files_to_load = []
+                file_objects = []
+
+                if hasattr(ontology_file, "read"):
+                    file_objects = [ontology_file]
+                elif isinstance(ontology_file, str):
                     files_to_load = [ontology_file]
                 elif isinstance(ontology_file, list):
-                    files_to_load = ontology_file
+                    if all(hasattr(item, "read") for item in ontology_file):
+                        file_objects = ontology_file
+                    else:
+                        files_to_load = ontology_file
                 else:
                     raise ValueError(
-                        f"ontology_file must be a string, list of strings, or None. Got: {type(ontology_file)}"
+                        f"ontology_file must be a string, list of strings, file-like object, list of file-like objects, or None. Got: {type(ontology_file)}"
                     )

-            if files_to_load:
-                self.graph = Graph()
-                loaded_files = []
-                for file_path in files_to_load:
-                    if os.path.exists(file_path):
-                        self.graph.parse(file_path)
-                        loaded_files.append(file_path)
-                        logger.info("Ontology loaded successfully from file: %s", file_path)
-                    else:
-                        logger.warning(
-                            "Ontology file '%s' not found. Skipping this file.",
-                            file_path,
-                        )
-
-                if not loaded_files:
-                    logger.info(
-                        "No valid ontology files found. No owl ontology will be attached to the graph."
-                    )
-                    self.graph = None
-                else:
-                    logger.info("Total ontology files loaded: %d", len(loaded_files))
+                if file_objects:
+                    self.graph = Graph()
+                    loaded_objects = []
+                    for file_obj in file_objects:
+                        try:
+                            content = file_obj.read()
+                            self.graph.parse(data=content, format="xml")
+                            loaded_objects.append(file_obj)
+                            logger.info("Ontology loaded successfully from file object")
+                        except Exception as e:
+                            logger.warning("Failed to parse ontology file object: %s", str(e))
+
+                    if not loaded_objects:
+                        logger.info(
+                            "No valid ontology file objects found. No owl ontology will be attached to the graph."
+                        )
+                        self.graph = None
+                    else:
+                        logger.info("Total ontology file objects loaded: %d", len(loaded_objects))
+                elif files_to_load:
+                    self.graph = Graph()
+                    loaded_files = []
+                    for file_path in files_to_load:
+                        if os.path.exists(file_path):
+                            self.graph.parse(file_path)
+                            loaded_files.append(file_path)
+                            logger.info("Ontology loaded successfully from file: %s", file_path)
+                        else:
+                            logger.warning(
+                                "Ontology file '%s' not found. Skipping this file.",
+                                file_path,
+                            )
+
+                    if not loaded_files:
+                        logger.info(
+                            "No valid ontology files found. No owl ontology will be attached to the graph."
+                        )
+                        self.graph = None
+                    else:
+                        logger.info("Total ontology files loaded: %d", len(loaded_files))
             else:
                 logger.info(
                     "No ontology file provided. No owl ontology will be attached to the graph."
                 )
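The practical effect is that in-memory streams now work anywhere a file path did, which is what the cognify router relies on. A minimal sketch:

```python
from io import StringIO

from cognee.modules.ontology.rdf_xml.RDFLibOntologyResolver import RDFLibOntologyResolver

owl_xml = """<?xml version="1.0"?>
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
         xmlns:owl="http://www.w3.org/2002/07/owl#">
  <owl:Class rdf:ID="Vehicle"/>
</rdf:RDF>"""

# File-like objects are detected via hasattr(obj, "read") and parsed as RDF/XML.
resolver = RDFLibOntologyResolver(ontology_file=[StringIO(owl_xml)])
```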
@@ -1,4 +1,5 @@
 from .chunk_by_word import chunk_by_word
 from .chunk_by_sentence import chunk_by_sentence
 from .chunk_by_paragraph import chunk_by_paragraph
+from .chunk_by_row import chunk_by_row
 from .remove_disconnected_chunks import remove_disconnected_chunks
cognee/tasks/chunks/chunk_by_row.py (new file, +94)
@@ -0,0 +1,94 @@
from typing import Any, Dict, Iterator
from uuid import NAMESPACE_OID, uuid5

from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine


def _get_pair_size(pair_text: str) -> int:
    """
    Calculate the size of a given text in terms of tokens.

    If an embedding engine's tokenizer is available, count the tokens for the provided text.
    If the tokenizer is not available, fall back to a fixed estimate of 3 tokens per pair.

    Parameters:
    -----------

    - pair_text (str): The key:value pair text for which the token size is to be calculated.

    Returns:
    --------

    - int: The number of tokens representing the text, depending on the tokenizer's output.
    """
    embedding_engine = get_embedding_engine()
    if embedding_engine.tokenizer:
        return embedding_engine.tokenizer.count_tokens(pair_text)
    else:
        return 3


def chunk_by_row(
    data: str,
    max_chunk_size,
) -> Iterator[Dict[str, Any]]:
    """
    Chunk the input text by row while enabling exact text reconstruction.

    This function divides the given text data into smaller chunks on a line-by-line basis,
    ensuring that the size of each chunk is less than or equal to the specified maximum
    chunk size. It guarantees that when the generated chunks are concatenated, they
    reproduce the original text accurately. The tokenization process is handled by
    adapters compatible with the vector engine's embedding model.

    Parameters:
    -----------

    - data (str): The input text to be chunked.
    - max_chunk_size: The maximum allowed size for each chunk, in terms of tokens or
      words.
    """
    current_chunk_list = []
    chunk_index = 0
    current_chunk_size = 0

    lines = data.split("\n\n")
    for line in lines:
        pairs_text = line.split(", ")

        for pair_text in pairs_text:
            pair_size = _get_pair_size(pair_text)
            if current_chunk_size > 0 and (current_chunk_size + pair_size > max_chunk_size):
                # Yield the current cut chunk
                current_chunk = ", ".join(current_chunk_list)
                chunk_dict = {
                    "text": current_chunk,
                    "chunk_size": current_chunk_size,
                    "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
                    "chunk_index": chunk_index,
                    "cut_type": "row_cut",
                }

                yield chunk_dict

                # Start a new chunk with the current pair text
                current_chunk_list = []
                current_chunk_size = 0
                chunk_index += 1

            current_chunk_list.append(pair_text)
            current_chunk_size += pair_size

    # Yield the chunk that closes the row
    current_chunk = ", ".join(current_chunk_list)
    if current_chunk:
        chunk_dict = {
            "text": current_chunk,
            "chunk_size": current_chunk_size,
            "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
            "chunk_index": chunk_index,
            "cut_type": "row_end",
        }

        yield chunk_dict
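A small worked example of the chunking behaviour. Token counts depend on the configured embedding engine's tokenizer (the tests below mock it), or the fixed 3-token fallback; the input format matches what CsvLoader/CsvDocument produce:

```python
from cognee.tasks.chunks import chunk_by_row

# One "row" of comma-joined "column: value" pairs. With the 3-token
# fallback, a max_chunk_size of 10 fits three pairs per chunk, so the
# first chunk ends with cut_type "row_cut" and the last with "row_end".
data = "id: 1, name: Eric, age: 30, city: Beijing, country: China"

for chunk in chunk_by_row(data, max_chunk_size=10):
    print(chunk["chunk_index"], chunk["cut_type"], repr(chunk["text"]))
```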
@@ -7,6 +7,7 @@ from cognee.modules.data.processing.document_types import (
     ImageDocument,
     TextDocument,
     UnstructuredDocument,
+    CsvDocument,
 )
 from cognee.modules.engine.models.node_set import NodeSet
 from cognee.modules.engine.utils.generate_node_id import generate_node_id

@@ -15,6 +16,7 @@ from cognee.tasks.documents.exceptions import WrongDataDocumentInputError
 EXTENSION_TO_DOCUMENT_CLASS = {
     "pdf": PdfDocument,  # Text documents
     "txt": TextDocument,
+    "csv": CsvDocument,
     "docx": UnstructuredDocument,
     "doc": UnstructuredDocument,
     "odt": UnstructuredDocument,
cognee/tests/integration/documents/CsvDocument_test.py (new file, +70)
@@ -0,0 +1,70 @@
import os
import sys
import uuid
import pytest
import pathlib
from unittest.mock import patch

from cognee.modules.chunking.CsvChunker import CsvChunker
from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument
from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
from cognee.tests.integration.documents.async_gen_zip import async_gen_zip

chunk_by_row_module = sys.modules.get("cognee.tasks.chunks.chunk_by_row")


GROUND_TRUTH = {
    "chunk_size_10": [
        {"token_count": 9, "len_text": 26, "cut_type": "row_cut", "chunk_index": 0},
        {"token_count": 6, "len_text": 29, "cut_type": "row_end", "chunk_index": 1},
        {"token_count": 9, "len_text": 25, "cut_type": "row_cut", "chunk_index": 2},
        {"token_count": 6, "len_text": 30, "cut_type": "row_end", "chunk_index": 3},
    ],
    "chunk_size_128": [
        {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 0},
        {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 1},
    ],
}


@pytest.mark.parametrize(
    "input_file,chunk_size",
    [("example_with_header.csv", 10), ("example_with_header.csv", 128)],
)
@patch.object(chunk_by_row_module, "get_embedding_engine", side_effect=mock_get_embedding_engine)
@pytest.mark.asyncio
async def test_CsvDocument(mock_engine, input_file, chunk_size):
    # Define file paths of test data
    csv_file_path = os.path.join(
        pathlib.Path(__file__).parent.parent.parent,
        "test_data",
        input_file,
    )

    # Define test documents
    csv_document = CsvDocument(
        id=uuid.uuid4(),
        name="example_with_header.csv",
        raw_data_location=csv_file_path,
        external_metadata="",
        mime_type="text/csv",
    )

    # TEST CSV
    ground_truth_key = f"chunk_size_{chunk_size}"
    async for ground_truth, row_data in async_gen_zip(
        GROUND_TRUTH[ground_truth_key],
        csv_document.read(chunker_cls=CsvChunker, max_chunk_size=chunk_size),
    ):
        assert ground_truth["token_count"] == row_data.chunk_size, (
            f'{ground_truth["token_count"] = } != {row_data.chunk_size = }'
        )
        assert ground_truth["len_text"] == len(row_data.text), (
            f'{ground_truth["len_text"] = } != {len(row_data.text) = }'
        )
        assert ground_truth["cut_type"] == row_data.cut_type, (
            f'{ground_truth["cut_type"] = } != {row_data.cut_type = }'
        )
        assert ground_truth["chunk_index"] == row_data.chunk_index, (
            f'{ground_truth["chunk_index"] = } != {row_data.chunk_index = }'
        )
@@ -7,6 +7,7 @@ import requests
 from pathlib import Path
 import sys
 import uuid
+import json


 class TestCogneeServerStart(unittest.TestCase):

@@ -90,12 +91,71 @@ class TestCogneeServerStart(unittest.TestCase):
             )
         }

-        payload = {"datasets": [dataset_name]}
+        ontology_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
+        payload = {"datasets": [dataset_name], "ontology_key": [ontology_key]}

         add_response = requests.post(url, headers=headers, data=form_data, files=file, timeout=50)
         if add_response.status_code not in [200, 201]:
             add_response.raise_for_status()

+        ontology_content = b"""<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+         xmlns:owl="http://www.w3.org/2002/07/owl#"
+         xmlns:rdfs="http://www.w3.org/2000/01/rdf-schema#"
+         xmlns="http://example.org/ontology#"
+         xml:base="http://example.org/ontology">
+
+            <owl:Ontology rdf:about="http://example.org/ontology"/>
+
+            <!-- Classes -->
+            <owl:Class rdf:ID="Problem"/>
+            <owl:Class rdf:ID="HardwareProblem"/>
+            <owl:Class rdf:ID="SoftwareProblem"/>
+            <owl:Class rdf:ID="Concept"/>
+            <owl:Class rdf:ID="Object"/>
+            <owl:Class rdf:ID="Joke"/>
+            <owl:Class rdf:ID="Image"/>
+            <owl:Class rdf:ID="Person"/>
+
+            <rdf:Description rdf:about="#HardwareProblem">
+                <rdfs:subClassOf rdf:resource="#Problem"/>
+                <rdfs:comment>A failure caused by physical components.</rdfs:comment>
+            </rdf:Description>
+
+            <rdf:Description rdf:about="#SoftwareProblem">
+                <rdfs:subClassOf rdf:resource="#Problem"/>
+                <rdfs:comment>An error caused by software logic or configuration.</rdfs:comment>
+            </rdf:Description>
+
+            <rdf:Description rdf:about="#Person">
+                <rdfs:comment>A human being or individual.</rdfs:comment>
+            </rdf:Description>
+
+            <!-- Individuals -->
+            <Person rdf:ID="programmers">
+                <rdfs:label>Programmers</rdfs:label>
+            </Person>
+
+            <Object rdf:ID="light_bulb">
+                <rdfs:label>Light Bulb</rdfs:label>
+            </Object>
+
+            <HardwareProblem rdf:ID="hardware_problem">
+                <rdfs:label>Hardware Problem</rdfs:label>
+            </HardwareProblem>
+
+        </rdf:RDF>"""
+
+        ontology_response = requests.post(
+            "http://127.0.0.1:8000/api/v1/ontologies",
+            headers=headers,
+            files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
+            data={
+                "ontology_key": json.dumps([ontology_key]),
+                "description": json.dumps(["Test ontology"]),
+            },
+        )
+        self.assertEqual(ontology_response.status_code, 200)
+
         # Cognify request
         url = "http://127.0.0.1:8000/api/v1/cognify"
         headers = {

@@ -107,6 +167,29 @@ class TestCogneeServerStart(unittest.TestCase):
         if cognify_response.status_code not in [200, 201]:
             cognify_response.raise_for_status()

+        datasets_response = requests.get("http://127.0.0.1:8000/api/v1/datasets", headers=headers)
+
+        datasets = datasets_response.json()
+        dataset_id = None
+        for dataset in datasets:
+            if dataset["name"] == dataset_name:
+                dataset_id = dataset["id"]
+                break
+
+        graph_response = requests.get(
+            f"http://127.0.0.1:8000/api/v1/datasets/{dataset_id}/graph", headers=headers
+        )
+        self.assertEqual(graph_response.status_code, 200)
+
+        graph_data = graph_response.json()
+        ontology_nodes = [
+            node for node in graph_data.get("nodes") if node.get("properties").get("ontology_valid")
+        ]
+
+        self.assertGreater(
+            len(ontology_nodes), 0, "No ontology nodes found - ontology was not integrated"
+        )
+
+        # TODO: Add a test to verify the cognify pipeline is complete before testing search
+
         # Search request
cognee/tests/test_data/example_with_header.csv (new file, +3)
@@ -0,0 +1,3 @@
id,name,age,city,country
1,Eric,30,Beijing,China
2,Joe,35,Berlin,Germany
264
cognee/tests/unit/api/test_ontology_endpoint.py
Normal file
264
cognee/tests/unit/api/test_ontology_endpoint.py
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
import pytest
|
||||
import uuid
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch, Mock, AsyncMock
|
||||
from types import SimpleNamespace
|
||||
import importlib
|
||||
from cognee.api.client import app
|
||||
|
||||
gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
user = Mock()
|
||||
user.id = "test-user-123"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_default_user():
|
||||
"""Mock default user for testing."""
|
||||
return SimpleNamespace(
|
||||
id=uuid.uuid4(), email="default@example.com", is_active=True, tenant_id=uuid.uuid4()
|
||||
)
|
||||
|
||||
|
||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||
def test_upload_ontology_success(mock_get_default_user, client, mock_default_user):
|
||||
"""Test successful ontology upload"""
|
||||
import json
|
||||
|
||||
mock_get_default_user.return_value = mock_default_user
|
||||
ontology_content = (
|
||||
b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
|
||||
)
|
||||
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ontologies",
|
||||
files=[("ontology_file", ("test.owl", ontology_content, "application/xml"))],
|
||||
data={"ontology_key": json.dumps([unique_key]), "description": json.dumps(["Test"])},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key
|
||||
assert "uploaded_at" in data["uploaded_ontologies"][0]
|
||||
|
||||
|
||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||
def test_upload_ontology_invalid_file(mock_get_default_user, client, mock_default_user):
|
||||
"""Test 400 response for non-.owl files"""
|
||||
mock_get_default_user.return_value = mock_default_user
|
||||
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
|
||||
response = client.post(
|
||||
"/api/v1/ontologies",
|
||||
files={"ontology_file": ("test.txt", b"not xml")},
|
||||
data={"ontology_key": unique_key},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||
def test_upload_ontology_missing_data(mock_get_default_user, client, mock_default_user):
|
||||
"""Test 400 response for missing file or key"""
|
||||
import json
|
||||
|
||||
mock_get_default_user.return_value = mock_default_user
|
||||
# Missing file
|
||||
response = client.post("/api/v1/ontologies", data={"ontology_key": json.dumps(["test"])})
|
||||
assert response.status_code == 400
|
||||
|
||||
# Missing key
|
||||
response = client.post(
|
||||
"/api/v1/ontologies", files=[("ontology_file", ("test.owl", b"xml", "application/xml"))]
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||
def test_upload_ontology_unauthorized(mock_get_default_user, client, mock_default_user):
|
||||
"""Test behavior when default user is provided (no explicit authentication)"""
|
||||
import json
|
||||
|
||||
unique_key = f"test_ontology_{uuid.uuid4().hex[:8]}"
|
||||
mock_get_default_user.return_value = mock_default_user
|
||||
response = client.post(
|
||||
"/api/v1/ontologies",
|
||||
files=[("ontology_file", ("test.owl", b"<rdf></rdf>", "application/xml"))],
|
||||
data={"ontology_key": json.dumps([unique_key])},
|
||||
)
|
||||
|
||||
# The current system provides a default user when no explicit authentication is given
|
||||
# This test verifies the system works with conditional authentication
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["uploaded_ontologies"][0]["ontology_key"] == unique_key
|
||||
assert "uploaded_at" in data["uploaded_ontologies"][0]
|
||||
|
||||
|
||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||
def test_upload_multiple_ontologies(mock_get_default_user, client, mock_default_user):
|
||||
"""Test uploading multiple ontology files in single request"""
|
||||
import io
|
||||
|
||||
# Create mock files
|
||||
file1_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
|
||||
file2_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"
|
||||
|
||||
files = [
|
||||
("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
|
||||
("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
|
||||
]
|
||||
data = {
|
||||
"ontology_key": '["vehicles", "manufacturers"]',
|
||||
"descriptions": '["Base vehicles", "Car manufacturers"]',
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/ontologies", files=files, data=data)
|
||||
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
assert "uploaded_ontologies" in result
|
||||
assert len(result["uploaded_ontologies"]) == 2
|
||||
assert result["uploaded_ontologies"][0]["ontology_key"] == "vehicles"
|
||||
assert result["uploaded_ontologies"][1]["ontology_key"] == "manufacturers"
|
||||
|
||||
|
||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_upload_endpoint_accepts_arrays(mock_get_default_user, client, mock_default_user):
    """Test that the upload endpoint accepts array parameters"""
    import io
    import json

    mock_get_default_user.return_value = mock_default_user

    file_content = b"<rdf:RDF xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'></rdf:RDF>"

    files = [("ontology_file", ("single.owl", io.BytesIO(file_content), "application/xml"))]
    data = {
        "ontology_key": json.dumps(["single_key"]),
        "descriptions": json.dumps(["Single ontology"]),
    }

    response = client.post("/api/v1/ontologies", files=files, data=data)

    assert response.status_code == 200
    result = response.json()
    assert result["uploaded_ontologies"][0]["ontology_key"] == "single_key"


@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_cognify_with_multiple_ontologies(mock_get_default_user, client, mock_default_user):
    """Test that the cognify endpoint accepts multiple ontology keys"""
    mock_get_default_user.return_value = mock_default_user

    payload = {
        "datasets": ["test_dataset"],
        "ontology_key": ["ontology1", "ontology2"],  # Array instead of a single string
        "run_in_background": False,
    }

    response = client.post("/api/v1/cognify", json=payload)

    # Should not fail due to the ontology_key type
    assert response.status_code in [200, 400, 409]  # May fail for other reasons, not type


@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_complete_multifile_workflow(mock_get_default_user, client, mock_default_user):
    """Test complete workflow: upload multiple ontologies → cognify with multiple keys"""
    import io
    import json

    mock_get_default_user.return_value = mock_default_user

    # Step 1: Upload multiple ontologies
    file1_content = b"""<?xml version="1.0"?>
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
         xmlns:owl="http://www.w3.org/2002/07/owl#">
    <owl:Class rdf:ID="Vehicle"/>
</rdf:RDF>"""

    file2_content = b"""<?xml version="1.0"?>
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
         xmlns:owl="http://www.w3.org/2002/07/owl#">
    <owl:Class rdf:ID="Manufacturer"/>
</rdf:RDF>"""

    files = [
        ("ontology_file", ("vehicles.owl", io.BytesIO(file1_content), "application/xml")),
        ("ontology_file", ("manufacturers.owl", io.BytesIO(file2_content), "application/xml")),
    ]
    data = {
        "ontology_key": json.dumps(["vehicles", "manufacturers"]),
        "descriptions": json.dumps(["Vehicle ontology", "Manufacturer ontology"]),
    }

    upload_response = client.post("/api/v1/ontologies", files=files, data=data)
    assert upload_response.status_code == 200

    # Step 2: Verify the ontologies are listed
    list_response = client.get("/api/v1/ontologies")
    assert list_response.status_code == 200
    ontologies = list_response.json()
    assert "vehicles" in ontologies
    assert "manufacturers" in ontologies

    # Step 3: Cognify with multiple ontologies
    cognify_payload = {
        "datasets": ["test_dataset"],
        "ontology_key": ["vehicles", "manufacturers"],
        "run_in_background": False,
    }

    cognify_response = client.post("/api/v1/cognify", json=cognify_payload)
    # Should not fail due to ontology handling (it may still fail for dataset reasons)
    assert cognify_response.status_code != 400  # Not a validation error


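# For reference, the same two-step workflow can be driven against a running
# server. The sketch below assumes the `requests` library, a local base URL,
# and ontology files on disk; the endpoints and form fields come from the
# tests in this file.
def run_multifile_workflow_example(base_url="http://localhost:8000"):
    import requests

    with open("vehicles.owl", "rb") as f1, open("manufacturers.owl", "rb") as f2:
        files = [
            ("ontology_file", ("vehicles.owl", f1, "application/xml")),
            ("ontology_file", ("manufacturers.owl", f2, "application/xml")),
        ]
        data = {"ontology_key": '["vehicles", "manufacturers"]'}
        requests.post(f"{base_url}/api/v1/ontologies", files=files, data=data).raise_for_status()

    payload = {"datasets": ["test_dataset"], "ontology_key": ["vehicles", "manufacturers"]}
    requests.post(f"{base_url}/api/v1/cognify", json=payload).raise_for_status()

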
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_multifile_error_handling(mock_get_default_user, client, mock_default_user):
    """Test error handling for invalid multifile uploads"""
    import io
    import json

    mock_get_default_user.return_value = mock_default_user

    # Test mismatched array lengths
    file_content = b"<rdf:RDF></rdf:RDF>"
    files = [("ontology_file", ("test.owl", io.BytesIO(file_content), "application/xml"))]
    data = {
        "ontology_key": json.dumps(["key1", "key2"]),  # 2 keys, 1 file
        "descriptions": json.dumps(["desc1"]),
    }

    response = client.post("/api/v1/ontologies", files=files, data=data)
    assert response.status_code == 400
    assert "Number of keys must match number of files" in response.json()["error"]

    # Test duplicate keys
    files = [
        ("ontology_file", ("test1.owl", io.BytesIO(file_content), "application/xml")),
        ("ontology_file", ("test2.owl", io.BytesIO(file_content), "application/xml")),
    ]
    data = {
        "ontology_key": json.dumps(["duplicate", "duplicate"]),
        "descriptions": json.dumps(["desc1", "desc2"]),
    }

    response = client.post("/api/v1/ontologies", files=files, data=data)
    assert response.status_code == 400
    assert "Duplicate ontology keys not allowed" in response.json()["error"]


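# The two 400 cases above imply up-front validation in the upload route. A
# hedged sketch of that check follows; the helper name and the missing-input
# message are assumptions, while the mismatch and duplicate messages are the
# exact strings the tests assert on.
def validate_ontology_upload_sketch(keys, files):
    """Return an error message for an invalid upload, or None if it is valid."""
    if not keys or not files:
        return "Both ontology_file and ontology_key are required"  # assumed wording
    if len(keys) != len(files):
        return "Number of keys must match number of files"
    if len(set(keys)) != len(keys):
        return "Duplicate ontology keys not allowed"
    return None

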
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
def test_cognify_missing_ontology_key(mock_get_default_user, client, mock_default_user):
    """Test cognify with a non-existent ontology key"""
    mock_get_default_user.return_value = mock_default_user

    payload = {
        "datasets": ["test_dataset"],
        "ontology_key": ["nonexistent_key"],
        "run_in_background": False,
    }

    response = client.post("/api/v1/cognify", json=payload)
    assert response.status_code == 409
    assert "Ontology key 'nonexistent_key' not found" in response.json()["error"]


52  cognee/tests/unit/processing/chunks/chunk_by_row_test.py  Normal file
@@ -0,0 +1,52 @@
from itertools import product

import numpy as np
import pytest

from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
from cognee.tasks.chunks import chunk_by_row

INPUT_TEXTS = "name: John, age: 30, city: New York, country: USA"
max_chunk_size_vals = [8, 32]


@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_chunk_by_row_isomorphism(input_text, max_chunk_size):
    chunks = chunk_by_row(input_text, max_chunk_size)
    reconstructed_text = ", ".join([chunk["text"] for chunk in chunks])
    assert reconstructed_text == input_text, (
        f"texts are not identical: {len(input_text) = }, {len(reconstructed_text) = }"
    )


@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_row_chunk_length(input_text, max_chunk_size):
    chunks = list(chunk_by_row(data=input_text, max_chunk_size=max_chunk_size))
    embedding_engine = get_embedding_engine()

    chunk_lengths = np.array(
        [embedding_engine.tokenizer.count_tokens(chunk["text"]) for chunk in chunks]
    )

    larger_chunks = chunk_lengths[chunk_lengths > max_chunk_size]
    assert np.all(chunk_lengths <= max_chunk_size), (
        f"{max_chunk_size = }: {larger_chunks} are too large"
    )


@pytest.mark.parametrize(
    "input_text,max_chunk_size",
    list(product([INPUT_TEXTS], max_chunk_size_vals)),
)
def test_chunk_by_row_chunk_numbering(input_text, max_chunk_size):
    chunks = chunk_by_row(data=input_text, max_chunk_size=max_chunk_size)
    chunk_indices = np.array([chunk["chunk_index"] for chunk in chunks])
    assert np.all(chunk_indices == np.arange(len(chunk_indices))), (
        f"{chunk_indices = } are not monotonically increasing"
    )
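
Taken together, these tests pin down a small contract for chunk_by_row: joining the chunk texts with ", " must reproduce the input, every chunk must fit the token budget, and chunk_index must count up from zero. A minimal sketch satisfying that contract is given below; it assumes a simple ", "-separated row format and a whitespace token counter, whereas the real implementation in cognee.tasks.chunks uses the embedding engine's tokenizer and may differ in detail.

def chunk_by_row_sketch(data, max_chunk_size, count_tokens=lambda s: len(s.split())):
    """Yield {"text", "chunk_index"} dicts whose ", "-join reproduces `data`."""
    chunk_index = 0
    rows, size = [], 0
    for row in data.split(", "):
        row_size = count_tokens(row)
        # Emit the current chunk before it would exceed the budget. A single row
        # larger than max_chunk_size is still emitted on its own.
        if rows and size + row_size > max_chunk_size:
            yield {"text": ", ".join(rows), "chunk_index": chunk_index}
            chunk_index += 1
            rows, size = [], 0
        rows.append(row)
        size += row_size
    if rows:
        yield {"text": ", ".join(rows), "chunk_index": chunk_index}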