Compare commits

1 commit

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| vasilije | 191a9ed0ee | first idea | 2025-08-11 14:47:27 +02:00 |
21 changed files with 6081 additions and 0 deletions

View file

@@ -0,0 +1,581 @@
# 🎫 Epic: Refactor Ontology System for General Pipeline Usage
**Epic ID:** CGNEE-2024-ONT-001
**Priority:** High
**Story Points:** 25
**Type:** Epic
**Labels:** `refactoring`, `architecture`, `ontology`, `pipeline-integration`
## 📋 Overview
Refactor the current monolithic `OntologyResolver` to follow Cognee's architectural patterns and be general, extensible, and usable across all pipelines. The current system is tightly coupled to specific tasks and only supports RDF/OWL formats. We need a modular system that follows Cognee's established patterns for configuration, methods organization, and module structure.
## 🎯 Business Value
- **Consistency**: Follow established Cognee patterns and conventions
- **Flexibility**: Support multiple ontology formats and domains
- **Reusability**: One ontology system usable across all pipelines
- **Maintainability**: Modular architecture following Cognee's separation of concerns
- **Developer Experience**: Familiar patterns and simple configuration
## 🏗️ Architecture Analysis
Based on examination of existing Cognee patterns:
### **Configuration Pattern** (Following `cognee/base_config.py`, `cognee/modules/cognify/config.py`)
```python
# Follow BaseSettings pattern with @lru_cache
class OntologyConfig(BaseSettings):
    default_format: OntologyFormat = OntologyFormat.JSON
    enable_semantic_search: bool = False
    registry_type: str = "memory"

    model_config = SettingsConfigDict(env_file=".env", extra="allow")


@lru_cache
def get_ontology_config():
    return OntologyConfig()
```
### **Methods Organization** (Following `cognee/modules/data/methods/`, `cognee/modules/users/methods/`)
```python
# Organize methods in separate files:
cognee/modules/ontology/methods/
├── __init__.py # Export all methods
├── create_ontology.py # async def create_ontology(...)
├── get_ontology.py # async def get_ontology(...)
├── load_ontology.py # async def load_ontology(...)
├── register_ontology.py # async def register_ontology(...)
└── delete_ontology.py # async def delete_ontology(...)
```
### **Models Structure** (Following `cognee/modules/data/models/`, `cognee/modules/users/models/`)
```python
# Create models with proper inheritance:
cognee/modules/ontology/models/
├── __init__.py # Export all models
├── OntologyGraph.py # class OntologyGraph(BaseModel)
├── OntologyNode.py # class OntologyNode(BaseModel)
├── OntologyContext.py # class OntologyContext(BaseModel)
└── DataPointMapping.py # class DataPointMapping(BaseModel)
```
### **Module Organization** (Following existing module patterns)
```python
cognee/modules/ontology/
├── __init__.py # Public API with convenience functions
├── config.py # OntologyConfig with @lru_cache
├── models/ # Pydantic models
├── methods/ # Async method functions
├── providers/ # Format-specific providers
├── adapters/ # Query and search operations
├── operations/ # Core business logic
└── utils/ # Utility functions
```
## 📦 Epic Breakdown (Cognee-Style Implementation)
### Story 1: Configuration System & Models
**Story Points:** 3
**Assignee:** Backend Developer
#### Acceptance Criteria
- [ ] Create `OntologyConfig` following `BaseSettings` pattern with `@lru_cache`
- [ ] Create Pydantic models in `models/` directory with proper `__init__.py` exports
- [ ] Follow naming conventions: `OntologyNode`, `OntologyEdge`, etc.
- [ ] Use proper type hints and docstring patterns from existing code
- [ ] Environment variable support following existing config patterns
#### Files to Create
```
cognee/modules/ontology/config.py
cognee/modules/ontology/models/__init__.py
cognee/modules/ontology/models/OntologyGraph.py
cognee/modules/ontology/models/OntologyNode.py
cognee/modules/ontology/models/OntologyContext.py
cognee/modules/ontology/models/DataPointMapping.py
```
#### Implementation Notes
```python
# config.py - Follow existing pattern
@lru_cache
def get_ontology_config():
    return OntologyConfig()


# models/OntologyNode.py - Follow DataPoint pattern
class OntologyNode(BaseModel):
    id: str = Field(..., description="Unique identifier")
    name: str
    type: str
    # ... follow existing model patterns
```
---
### Story 2: Methods Organization
**Story Points:** 2
**Assignee:** Backend Developer
**Dependencies:** Story 1
#### Acceptance Criteria
- [ ] Create `methods/` directory following existing pattern
- [ ] Implement async functions following `create_dataset`, `get_user` patterns
- [ ] Use proper error handling patterns from existing methods
- [ ] Follow parameter naming and return type conventions
- [ ] Export all methods in `methods/__init__.py`
#### Files to Create
```
cognee/modules/ontology/methods/__init__.py
cognee/modules/ontology/methods/create_ontology.py
cognee/modules/ontology/methods/get_ontology.py
cognee/modules/ontology/methods/load_ontology.py
cognee/modules/ontology/methods/register_ontology.py
cognee/modules/ontology/methods/delete_ontology.py
```
#### Implementation Notes
```python
# methods/create_ontology.py - Follow create_dataset pattern
async def create_ontology(
    ontology_data: Dict[str, Any],
    user: User,
    scope: OntologyScope = OntologyScope.USER,
) -> OntologyGraph:
    # Follow existing error handling and validation patterns
    ...
```
---
### Story 3: Provider System (Following Existing Patterns)
**Story Points:** 4
**Assignee:** Backend Developer
**Dependencies:** Story 2
#### Acceptance Criteria
- [ ] Create providers following the adapter pattern seen in retrieval systems
- [ ] Use abstract base classes like `CogneeAbstractGraph`
- [ ] Follow error handling patterns from existing providers
- [ ] Support graceful degradation (like RDF provider with optional rdflib)
- [ ] Use proper logging patterns with `get_logger()`
#### Files to Create
```
cognee/modules/ontology/providers/__init__.py
cognee/modules/ontology/providers/base.py
cognee/modules/ontology/providers/rdf_provider.py
cognee/modules/ontology/providers/json_provider.py
cognee/modules/ontology/providers/csv_provider.py
```
#### Implementation Notes
```python
# providers/base.py - Follow CogneeAbstractGraph pattern
class BaseOntologyProvider(ABC):
    @abstractmethod
    async def load_ontology(self, source: str) -> OntologyGraph:
        pass


# providers/rdf_provider.py - Follow graceful fallback pattern
class RDFOntologyProvider(BaseOntologyProvider):
    def __init__(self):
        try:
            import rdflib

            self.available = True
        except ImportError:
            logger.warning("rdflib not available")
            self.available = False
```
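A hypothetical sketch of what `providers/json_provider.py` could look like under this pattern; the on-disk JSON layout, the model constructors, and the import paths are assumptions, not part of this epic:
```python
import json

from cognee.modules.ontology.models import OntologyGraph, OntologyNode  # path assumed
from .base import BaseOntologyProvider


class JSONOntologyProvider(BaseOntologyProvider):
    """Illustrative JSON provider; field names follow the Story 1 models."""

    async def load_ontology(self, source: str) -> OntologyGraph:
        with open(source, "r", encoding="utf-8") as file:
            raw = json.load(file)
        # Assumes a flat {"id": ..., "name": ..., "nodes": [...], "edges": [...]} layout
        return OntologyGraph(
            id=raw["id"],
            name=raw.get("name", raw["id"]),
            nodes=[OntologyNode(**node) for node in raw.get("nodes", [])],
            edges=raw.get("edges", []),
        )
```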
---
### Story 4: Operations Layer (Core Business Logic)
**Story Points:** 4
**Assignee:** Senior Backend Developer
**Dependencies:** Story 3
#### Acceptance Criteria
- [ ] Create `operations/` directory following pipeline operations pattern
- [ ] Implement core business logic following `cognee_pipeline` pattern
- [ ] Use dependency injection patterns seen in existing operations
- [ ] Follow async/await patterns consistently
- [ ] Implement proper error handling and logging
#### Files to Create
```
cognee/modules/ontology/operations/__init__.py
cognee/modules/ontology/operations/ontology_manager.py
cognee/modules/ontology/operations/datapoint_resolver.py
cognee/modules/ontology/operations/graph_binder.py
cognee/modules/ontology/operations/registry.py
```
#### Implementation Notes
```python
# operations/ontology_manager.py - Follow cognee_pipeline pattern
async def manage_ontology_processing(
    context: OntologyContext,
    providers: Dict[str, BaseOntologyProvider],
    config: OntologyConfig = None,
) -> List[OntologyGraph]:
    # Follow existing operation patterns
    ...
```
---
### Story 5: Pipeline Integration (Following Task Pattern)
**Story Points:** 3
**Assignee:** Backend Developer
**Dependencies:** Story 4
#### Acceptance Criteria
- [ ] Create integration following `Task` class pattern
- [ ] Support injection into existing pipeline operations
- [ ] Follow parameter passing patterns from `cognee_pipeline`
- [ ] Maintain backward compatibility with existing tasks
- [ ] Use context variables pattern for configuration
#### Files to Create
```
cognee/modules/ontology/operations/pipeline_integration.py
cognee/modules/ontology/operations/task_enhancer.py
```
#### Implementation Notes
```python
# operations/pipeline_integration.py - Follow cognee_pipeline pattern
async def inject_ontology_context(
    tasks: list[Task],
    ontology_context: OntologyContext,
    config: OntologyConfig = None,
) -> list[Task]:
    # Follow existing task enhancement patterns
    ...
```
---
### Story 6: Utility Functions
**Story Points:** 2
**Assignee:** Backend Developer
**Dependencies:** Story 5
#### Acceptance Criteria
- [ ] Create `utils/` directory for utility functions
- [ ] Follow utility patterns from existing modules
- [ ] Implement helper functions for common operations
- [ ] Use proper type hints and error handling
#### Files to Create
```
cognee/modules/ontology/utils/__init__.py
cognee/modules/ontology/utils/ontology_helpers.py
cognee/modules/ontology/utils/mapping_helpers.py
cognee/modules/ontology/utils/validation.py
```
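One plausible shape for `utils/validation.py`, sketched here against the Story 1 models (the helper name and exact model fields are assumptions):
```python
from typing import List


def find_dangling_edge_ids(ontology) -> List[str]:
    """Return IDs of edges whose source or target is not a known node.

    Works against any object exposing `nodes` (with `.id`) and `edges`
    (with `.id`, `.source_id`, `.target_id`), as the Story 1 models do.
    """
    node_ids = {node.id for node in ontology.nodes}
    return [
        edge.id
        for edge in ontology.edges
        if edge.source_id not in node_ids or edge.target_id not in node_ids
    ]
```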
---
### Story 7: Public API and Module Initialization
**Story Points:** 2
**Assignee:** Backend Developer
**Dependencies:** Story 6
#### Acceptance Criteria
- [ ] Create comprehensive `__init__.py` following existing patterns
- [ ] Export key classes and functions following module conventions
- [ ] Create convenience functions following `get_base_config()` pattern
- [ ] Maintain backward compatibility with `OntologyResolver`
- [ ] Use proper `__all__` exports
#### Files to Update/Create
```
cognee/modules/ontology/__init__.py
```
#### Implementation Notes
```python
# __init__.py - Follow existing module export patterns
from .config import get_ontology_config
from .models import OntologyGraph, OntologyNode, OntologyContext
from .methods import create_ontology, load_ontology, get_ontology


# Convenience functions following get_base_config pattern
@lru_cache
def get_ontology_manager():
    return OntologyManager()


__all__ = [
    "OntologyGraph",
    "OntologyNode",
    "get_ontology_config",
    "create_ontology",
    # ... follow existing export patterns
]
```
---
### Story 8: Enhanced Task Implementation
**Story Points:** 2
**Assignee:** Backend Developer
**Dependencies:** Story 7
#### Acceptance Criteria
- [ ] Update existing graph extraction task to be ontology-aware
- [ ] Follow existing task parameter patterns
- [ ] Maintain backward compatibility
- [ ] Use proper error handling and fallback mechanisms
- [ ] Follow existing task documentation patterns
#### Files to Create/Update
```
cognee/tasks/graph/extract_graph_from_data_ontology_aware.py
```
---
### Story 9: Documentation and Examples
**Story Points:** 1
**Assignee:** Technical Writer / Backend Developer
**Dependencies:** Story 8
#### Acceptance Criteria
- [ ] Create migration guide following existing documentation patterns
- [ ] Provide examples following existing code patterns
- [ ] Document configuration options following existing config docs
- [ ] Create troubleshooting guide
#### Files to Create
```
cognee/modules/ontology/MIGRATION_GUIDE.md
cognee/modules/ontology/examples/
```
---
### Story 10: Testing Following Cognee Patterns
**Story Points:** 2
**Assignee:** QA Engineer / Backend Developer
**Dependencies:** Story 9
#### Acceptance Criteria
- [ ] Create tests following existing test structure in `cognee/tests/`
- [ ] Use existing test patterns and fixtures
- [ ] Test integration with existing pipeline system
- [ ] Verify backward compatibility
- [ ] Performance testing following existing benchmarks
#### Files to Create
```
cognee/tests/unit/ontology/
cognee/tests/integration/ontology/
```
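For example, a unit test of the string-similarity helper might look like this (the module path and the adapter's direct instantiability are assumptions):
```python
# cognee/tests/unit/ontology/test_adapters.py (hypothetical)
from cognee.modules.ontology.adapters import DefaultOntologyAdapter  # path assumed


def test_calculate_similarity_is_bounded_and_handles_empty_input():
    adapter = DefaultOntologyAdapter()

    # difflib-based similarity should stay within [0, 1]
    score = adapter._calculate_similarity("diabetes", "diabetes mellitus")
    assert 0.0 <= score <= 1.0

    # Empty strings short-circuit to zero
    assert adapter._calculate_similarity("", "anything") == 0.0
```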
---
## 🔧 Technical Requirements (Cognee-Aligned)
### Follow Existing Patterns
- **Configuration**: Use `BaseSettings` with `@lru_cache` pattern
- **Models**: Pydantic models with proper inheritance and validation
- **Methods**: Async functions with consistent error handling
- **Operations**: Business logic separation following existing operations
- **Exports**: Proper `__init__.py` files with `__all__` exports
- **Logging**: Use `get_logger()` pattern consistently
- **Error Handling**: Follow existing exception patterns
### Integration Requirements
- Work seamlessly with existing `Task` system
- Support existing `cognee_pipeline` operations
- Integrate with current configuration management
- Support existing database patterns
- Maintain compatibility with current DataPoint model
### Code Style Requirements
- Follow existing naming conventions (PascalCase for classes, snake_case for functions)
- Use type hints consistently like existing code
- Follow docstring patterns from existing modules
- Use existing import organization patterns
- Follow async/await patterns consistently
## 🚀 Implementation Guidelines
### File Organization (Following Cognee Patterns)
```
cognee/modules/ontology/
├── __init__.py # Public API, convenience functions
├── config.py # OntologyConfig with @lru_cache
├── models/ # Pydantic models
│ ├── __init__.py # Export all models
│ ├── OntologyGraph.py
│ ├── OntologyNode.py
│ └── ...
├── methods/ # Business methods (CRUD operations)
│ ├── __init__.py # Export all methods
│ ├── create_ontology.py
│ ├── get_ontology.py
│ └── ...
├── providers/ # Format-specific providers
├── operations/ # Core business logic
├── utils/ # Utility functions
└── examples/ # Usage examples
```
### Code Patterns to Follow
#### Configuration Pattern
```python
# cognee/modules/ontology/config.py
from functools import lru_cache

from pydantic_settings import BaseSettings, SettingsConfigDict


class OntologyConfig(BaseSettings):
    default_format: str = "json"
    enable_semantic_search: bool = False

    model_config = SettingsConfigDict(env_file=".env", extra="allow")

    def to_dict(self) -> dict:
        return {"default_format": self.default_format, ...}


@lru_cache
def get_ontology_config():
    return OntologyConfig()
```
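Because `get_ontology_config()` is cached with `@lru_cache`, environment overrides must be in place before the first call. A quick check, assuming pydantic-settings' default field-name-to-env-var mapping with no prefix:
```python
import os

os.environ["ENABLE_SEMANTIC_SEARCH"] = "true"  # read on first instantiation

from cognee.modules.ontology.config import get_ontology_config

config = get_ontology_config()  # cached from here on
assert config.enable_semantic_search is True
```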
#### Method Pattern
```python
# cognee/modules/ontology/methods/create_ontology.py
from typing import Any, Dict

from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.models import OntologyGraph

logger = get_logger("ontology.create")


async def create_ontology(
    ontology_data: Dict[str, Any],
    user: User,
    scope: OntologyScope = OntologyScope.USER,
) -> OntologyGraph:
    """Create ontology following existing method patterns."""
    try:
        # Implementation following existing patterns
        logger.info(f"Creating ontology for user: {user.id}")
        # ...
    except Exception as e:
        logger.error(f"Failed to create ontology: {e}")
        raise
```
#### Model Pattern
```python
# cognee/modules/ontology/models/OntologyNode.py
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any


class OntologyNode(BaseModel):
    """Ontology node following Cognee model patterns."""

    id: str = Field(..., description="Unique identifier")
    name: str
    type: str
    description: Optional[str] = None
    properties: Dict[str, Any] = Field(default_factory=dict)

    class Config:
        """Follow existing model config patterns."""

        arbitrary_types_allowed = True
```
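Instantiation then validates the required fields and applies the defaults:
```python
node = OntologyNode(id="disease_001", name="Diabetes", type="Disease")
assert node.description is None
assert node.properties == {}  # default_factory supplies an empty dict
```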
### Integration with Existing Systems
#### Task Integration
```python
# Follow existing Task parameter patterns
async def extract_graph_from_data_ontology_aware(
    data_chunks: List[DocumentChunk],
    graph_model: Type[BaseModel] = KnowledgeGraph,
    ontology_config: OntologyConfig = None,
    **kwargs,
) -> List[DocumentChunk]:
    """Enhanced task following existing task patterns."""
    config = ontology_config or get_ontology_config()
    # Follow existing task implementation patterns
```
#### Pipeline Integration
```python
# Follow cognee_pipeline pattern for integration
async def cognee_pipeline_with_ontology(
    tasks: list[Task],
    ontology_context: OntologyContext = None,
    **kwargs,
):
    """Enhanced pipeline following existing pipeline patterns."""
    # Inject ontology context following existing parameter injection
    enhanced_tasks = []
    for task in tasks:
        if ontology_context:
            # Enhance task following existing enhancement patterns
            enhanced_task = enhance_task_with_ontology(task, ontology_context)
            enhanced_tasks.append(enhanced_task)
        else:
            enhanced_tasks.append(task)

    # Use existing pipeline execution
    return await cognee_pipeline(enhanced_tasks, **kwargs)
```
## 📊 Success Metrics (Aligned with Cognee Standards)
### Code Quality Metrics
- [ ] Follow all existing linting and code style rules
- [ ] Pass all existing code quality checks
- [ ] Maintain or improve test coverage percentage
- [ ] Follow existing documentation standards
### Integration Metrics
- [ ] Zero breaking changes to existing API
- [ ] All existing tests continue to pass
- [ ] Performance meets or exceeds existing benchmarks
- [ ] Memory usage within existing parameters
### Pattern Compliance
- [ ] Configuration follows `BaseSettings` + `@lru_cache` pattern
- [ ] Models follow existing Pydantic patterns
- [ ] Methods follow existing async function patterns
- [ ] Exports follow existing `__init__.py` patterns
- [ ] Error handling follows existing exception patterns
## 🔗 Related Files to Study
### Configuration Patterns
- `cognee/base_config.py`
- `cognee/modules/cognify/config.py`
- `cognee/infrastructure/llm/config.py`
### Model Patterns
- `cognee/modules/data/models/`
- `cognee/modules/users/models/`
- `cognee/infrastructure/engine/models/DataPoint.py`
### Method Patterns
- `cognee/modules/data/methods/`
- `cognee/modules/users/methods/`
### Operation Patterns
- `cognee/modules/pipelines/operations/`
- `cognee/modules/search/methods/`
### Module Organization
- `cognee/modules/pipelines/__init__.py`
- `cognee/modules/data/__init__.py`
- `cognee/modules/users/__init__.py`
---
**Estimated Total Effort:** 25 Story Points (~4-5 Sprints)
**Target Completion:** End of Q2 2024
**Review Required:** Architecture Review, Code Standards Review, Integration Review
**Key Success Factor:** Strict adherence to existing Cognee patterns and conventions to ensure seamless integration and maintainability.

View file

@@ -0,0 +1,367 @@
# Ontology System Migration Guide
This guide explains how to migrate from the old ontology system to the new refactored architecture.
## Overview of Changes
The ontology system has been completely refactored to provide:
1. **Better Separation of Concerns**: Clear interfaces for different components
2. **DataPoint Integration**: Automatic mapping between ontologies and DataPoint instances
3. **Custom Graph Binding**: Configurable binding strategies for different graph types
4. **Pipeline Integration**: Seamless injection into pipeline tasks
5. **Domain Configuration**: Pre-configured setups for common domains
6. **Extensibility**: Plugin system for custom behavior
## Architecture Changes
### Old System
```
OntologyResolver (monolithic)
├── RDF/OWL parsing
├── Basic node/edge extraction
└── Simple graph operations
```
### New System
```
OntologyManager (orchestrator)
├── OntologyRegistry (storage/lookup)
├── OntologyProviders (format-specific loading)
├── OntologyAdapters (query/search operations)
├── DataPointResolver (ontology ↔ DataPoint mapping)
└── GraphBinder (ontology → graph structure binding)
```
## Migration Steps
### 1. Replace OntologyResolver Usage
**Old Code:**
```python
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
# Old way
ontology_resolver = OntologyResolver(ontology_file="medical.owl")
nodes, edges, root = ontology_resolver.get_subgraph("Disease", "classes")
```
**New Code:**
```python
from cognee.modules.ontology import (
    create_ontology_system,
    OntologyContext,
    configure_domain,
    DataPointMapping,
    GraphBindingConfig,
)

# New way
ontology_manager = await create_ontology_system()

# Configure domain if needed
mappings = [
    DataPointMapping(
        ontology_node_type="Disease",
        datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
        field_mappings={"name": "name", "description": "description"},
    )
]
binding = GraphBindingConfig(
    node_type_mapping={"Disease": "medical_condition"}
)
configure_domain("medical", mappings, binding)

# Use with context
context = OntologyContext(domain="medical")
ontologies = await ontology_manager.get_applicable_ontologies(context)
```
### 2. Update Pipeline Task Integration
**Old Code:**
```python
from cognee.api.v1.cognify.cognify import get_default_tasks
# Old way - hardcoded ontology adapter
tasks = await get_default_tasks(
    ontology_file_path="path/to/ontology.owl"
)
```
**New Code:**
```python
from cognee.modules.ontology import create_pipeline_injector
from cognee.modules.pipelines.tasks.task import Task
# New way - configurable ontology injection
ontology_manager = await create_ontology_system()
injector = create_pipeline_injector(ontology_manager, "my_pipeline", "medical")
# Inject into tasks
original_task = Task(extract_graph_from_data)
context = OntologyContext(domain="medical", pipeline_name="my_pipeline")
enhanced_task = await injector.inject_into_task(original_task, context)
```
### 3. Update DataPoint Creation
**Old Code:**
```python
# Old way - manual DataPoint creation
from cognee.infrastructure.engine.models.DataPoint import DataPoint
datapoint = DataPoint(
    id="disease_001",
    type="Disease",
    # Manual field mapping...
)
```
**New Code:**
```python
# New way - automatic resolution from ontology
ontology_nodes = [...] # Nodes from ontology
context = OntologyContext(domain="medical")
# Automatically resolve to DataPoints
datapoints = await ontology_manager.resolve_to_datapoints(ontology_nodes, context)
```
### 4. Update Graph Binding
**Old Code:**
```python
# Old way - hardcoded graph node creation
graph_nodes = []
for ontology_node in nodes:
    graph_node = (
        ontology_node.id,
        {
            "name": ontology_node.name,
            "type": ontology_node.type,
            # Manual property mapping...
        },
    )
    graph_nodes.append(graph_node)
```
**New Code:**
```python
# New way - configurable binding
ontology = ...  # loaded ontology
context = OntologyContext(domain="medical")
# Automatically bind to graph structure
graph_nodes, graph_edges = await ontology_manager.bind_to_graph(ontology, context)
```
## Domain-Specific Configurations
### Medical Domain
**Old Code:**
```python
# Old way - manual setup for each pipeline
ontology_resolver = OntologyResolver("medical_ontology.owl")
# ... manual entity extraction
# ... manual graph binding
```
**New Code:**
```python
# New way - use pre-configured medical domain
from cognee.modules.ontology import create_medical_pipeline_config
ontology_manager = await create_ontology_system()
injector = create_pipeline_injector(ontology_manager, "medical_pipeline", "medical")
# Automatically configured for:
# - Disease, Symptom, Treatment entities
# - Medical-specific DataPoint mappings
# - Clinical graph relationships
```
### Legal Domain
```python
# Pre-configured for legal documents
injector = create_pipeline_injector(ontology_manager, "legal_pipeline", "legal")
# Automatically handles:
# - Law, Case, Court entities
# - Legal citation relationships
# - Jurisdiction-aware processing
```
### Code Analysis Domain
```python
# Pre-configured for code analysis
injector = create_pipeline_injector(ontology_manager, "code_pipeline", "code")
# Automatically handles:
# - Function, Class, Module entities
# - Code dependency relationships
# - Language-specific processing
```
## Custom Resolvers and Binding
### Custom DataPoint Resolver
```python
# Define custom resolver for special entity types
async def custom_medical_resolver(ontology_node, mapping_config, context=None):
    # Custom logic for creating DataPoints
    datapoint = DataPoint(
        id=ontology_node.id,
        type="medical_entity",
        # Custom field mapping and validation
    )
    return datapoint


# Register the resolver
ontology_manager.register_custom_resolver("medical_resolver", custom_medical_resolver)

# Use in mapping configuration
mapping = DataPointMapping(
    ontology_node_type="SpecialDisease",
    datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
    custom_resolver="medical_resolver",
)
```
### Custom Graph Binding
```python
# Define custom binding strategy
async def custom_graph_binding(ontology, binding_config, context=None):
    # Custom logic for graph structure creation
    graph_nodes = []
    graph_edges = []
    for node in ontology.nodes:
        # Custom node transformation
        transformed_node = transform_node(node)
        graph_nodes.append(transformed_node)
    return graph_nodes, graph_edges


# Register the strategy
ontology_manager.register_binding_strategy("custom_binding", custom_graph_binding)

# Use in binding configuration
binding = GraphBindingConfig(
    custom_binding_strategy="custom_binding"
)
```
## Configuration Files
### Create Configuration File
```python
from cognee.modules.ontology import create_example_config_file
# Generate example configuration
create_example_config_file("ontology_config.json")
```
### Load Configuration
```python
from cognee.modules.ontology import load_ontology_config
# Load from file
load_ontology_config("ontology_config.json")
# Or create system with config
ontology_manager = await create_ontology_system(config_file="ontology_config.json")
```
## Backward Compatibility
The old `OntologyResolver` is still available for backward compatibility:
```python
from cognee.modules.ontology import OntologyResolver
# Old interface still works
resolver = OntologyResolver(ontology_file="medical.owl")
```
However, it's recommended to migrate to the new system for better flexibility and features.
## Performance Improvements
The new system provides several performance benefits:
1. **Lazy Loading**: Ontologies are loaded only when needed
2. **Caching**: Registry caches frequently accessed ontologies (see the sketch after this list)
3. **Parallel Processing**: Multiple ontologies can be processed simultaneously
4. **Semantic Search**: Optional semantic similarity for better entity matching
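As a sketch of how the caching point could be layered over any registry (illustrative only; the shipped `OntologyRegistry` may implement this differently):
```python
from collections import OrderedDict


class LRUCachingRegistry:
    """Wrap a backing registry with a small LRU cache of loaded ontologies."""

    def __init__(self, backing_registry, max_size: int = 100):
        self._backing = backing_registry
        self._cache: "OrderedDict[str, object]" = OrderedDict()
        self._max_size = max_size

    async def get_ontology(self, ontology_id: str):
        if ontology_id in self._cache:
            self._cache.move_to_end(ontology_id)  # mark as recently used
            return self._cache[ontology_id]
        ontology = await self._backing.get_ontology(ontology_id)
        self._cache[ontology_id] = ontology
        if len(self._cache) > self._max_size:
            self._cache.popitem(last=False)  # evict least recently used
        return ontology
```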
## Testing Your Migration
Use the example usage script to test your migration:
```python
from cognee.modules.ontology.example_usage import main
# Run comprehensive examples
await main()
```
## Common Migration Issues
### 1. Import Errors
**Problem**: `ImportError: cannot import name 'OntologyResolver'`
**Solution**: Update imports to use new interfaces:
```python
# Old
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver
# New
from cognee.modules.ontology import create_ontology_system, OntologyContext
```
### 2. Task Parameter Changes
**Problem**: Tasks expecting `ontology_adapter` parameter
**Solution**: Use ontology injection:
```python
# Old
Task(extract_graph_from_data, ontology_adapter=resolver)
# New
injector = create_pipeline_injector(ontology_manager, "pipeline", "domain")
enhanced_task = await injector.inject_into_task(original_task, context)
```
### 3. DataPoint Creation Changes
**Problem**: Manual DataPoint creation with ontology data
**Solution**: Use automatic resolution:
```python
# Old
datapoint = DataPoint(id=node.id, type=node.type, ...)
# New
datapoints = await ontology_manager.resolve_to_datapoints(nodes, context)
```
## Support
For additional support with migration:
1. Check the example usage file: `cognee/modules/ontology/example_usage.py`
2. Review the interface documentation in `cognee/modules/ontology/interfaces.py`
3. Use the pre-configured domain setups for common use cases
4. Test with the provided configuration examples
The new system is designed to be more powerful while maintaining ease of use. Most migrations can be completed by updating imports and using the convenience functions provided.

View file

@@ -0,0 +1,196 @@
"""
Ontology module for Cognee.
Provides ontology management capabilities including loading, processing,
and integration with pipelines following Cognee's architectural patterns.
"""
# Configuration (following Cognee pattern)
from .config import get_ontology_config

# Core data models
from .interfaces import (
    OntologyNode,
    OntologyEdge,
    OntologyGraph,
    OntologyContext,
    DataPointMapping,
    GraphBindingConfig,
    OntologyFormat,
    OntologyScope,
)

# Core implementations
from .manager import OntologyManager, create_ontology_manager
from .registry import OntologyRegistry, DatabaseOntologyRegistry
from .providers import (
    RDFOntologyProvider,
    JSONOntologyProvider,
    CSVOntologyProvider,
)

# Implementations referenced by the convenience functions below
# (module paths assumed from this commit's file layout):
from .adapters import (
    DefaultOntologyAdapter,
    GraphOntologyAdapter,
    SemanticOntologyAdapter,
)
from .datapoint_resolver import DefaultDataPointResolver
from .graph_binder import DefaultGraphBinder
from .pipeline_integration import (
    OntologyInjector,
    PipelineOntologyConfigurator,
    create_medical_pipeline_config,
    create_legal_pipeline_config,
    create_code_pipeline_config,
)

# Methods (following Cognee pattern)
from .methods import (
    create_ontology,
    get_ontology,
    load_ontology,
    register_ontology,
    delete_ontology,
)

# Legacy compatibility
from .rdf_xml.OntologyResolver import OntologyResolver
# Convenience functions for quick setup
async def create_ontology_system(
    config_file: str = None,
    use_database_registry: bool = False,
    enable_semantic_search: bool = False,
) -> OntologyManager:
    """
    Create a fully configured ontology system.

    Args:
        config_file: Optional configuration file to load
        use_database_registry: Whether to use database-backed registry
        enable_semantic_search: Whether to enable semantic search capabilities

    Returns:
        Configured OntologyManager instance
    """
    # Create registry
    if use_database_registry:
        registry = DatabaseOntologyRegistry()
    else:
        registry = OntologyRegistry()

    # Create providers
    providers = {
        "json_provider": JSONOntologyProvider(),
        "csv_provider": CSVOntologyProvider(),
    }

    # Add RDF provider if available
    rdf_provider = RDFOntologyProvider()
    if rdf_provider.available:
        providers["rdf_provider"] = rdf_provider

    # Create adapters
    adapters = {
        "default_adapter": DefaultOntologyAdapter(),
        "graph_adapter": GraphOntologyAdapter(),
    }

    # Add semantic adapter if requested and available
    if enable_semantic_search:
        semantic_adapter = SemanticOntologyAdapter()
        if semantic_adapter.embeddings_available:
            adapters["semantic_adapter"] = semantic_adapter

    # Create resolver and binder
    datapoint_resolver = DefaultDataPointResolver()
    graph_binder = DefaultGraphBinder()

    # Create manager
    manager = await create_ontology_manager(
        registry=registry,
        providers=providers,
        adapters=adapters,
        datapoint_resolver=datapoint_resolver,
        graph_binder=graph_binder,
    )

    # Load configuration if provided
    if config_file:
        config = get_ontology_config()
        config.load_from_file(config_file)

        # Apply configurations to manager
        for domain, domain_config in config.domain_configs.items():
            manager.configure_datapoint_mapping(
                domain, domain_config["datapoint_mappings"]
            )
            manager.configure_graph_binding(
                domain, domain_config["graph_binding"]
            )

    return manager
def create_pipeline_injector(
    ontology_manager: OntologyManager,
    pipeline_name: str,
    domain: str = None,
) -> OntologyInjector:
    """
    Create an ontology injector for a specific pipeline.

    Args:
        ontology_manager: The ontology manager instance
        pipeline_name: Name of the pipeline
        domain: Domain for the pipeline (optional)

    Returns:
        Configured OntologyInjector
    """
    configurator = PipelineOntologyConfigurator(ontology_manager)

    # Use pre-configured domain setup if available
    if domain in ["medical", "legal", "code"]:
        if domain == "medical":
            config = create_medical_pipeline_config()
        elif domain == "legal":
            config = create_legal_pipeline_config()
        elif domain == "code":
            config = create_code_pipeline_config()

        # configure_pipeline runs only when a known domain supplied `config`
        configurator.configure_pipeline(
            pipeline_name=pipeline_name,
            domain=config["domain"],
            datapoint_mappings=config["datapoint_mappings"],
            graph_binding=config["graph_binding"],
            task_specific_configs=config["task_configs"],
        )

    return configurator.create_ontology_injector(pipeline_name)
# Export following Cognee pattern
__all__ = [
    # Configuration
    "get_ontology_config",
    # Core classes
    "OntologyManager",
    "OntologyRegistry",
    # Data models
    "OntologyNode",
    "OntologyEdge",
    "OntologyGraph",
    "OntologyContext",
    "DataPointMapping",
    "GraphBindingConfig",
    "OntologyFormat",
    "OntologyScope",
    # Providers
    "RDFOntologyProvider",
    "JSONOntologyProvider",
    "CSVOntologyProvider",
    # Methods
    "create_ontology",
    "get_ontology",
    "load_ontology",
    "register_ontology",
    "delete_ontology",
    # Convenience functions
    "create_ontology_system",
    "create_pipeline_injector",
    # Legacy compatibility
    "OntologyResolver",
]
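
# Example usage (illustrative sketch; the domain names and context fields
# mirror those shown in the migration guide, not a guaranteed API):
#
#     manager = await create_ontology_system(enable_semantic_search=False)
#     injector = create_pipeline_injector(manager, "demo_pipeline", "medical")
#     context = OntologyContext(domain="medical", pipeline_name="demo_pipeline")
#     ontologies = await manager.get_applicable_ontologies(context)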

View file

@@ -0,0 +1,397 @@
"""Ontology adapter implementations."""
import difflib
from typing import List, Tuple, Optional
from collections import deque

from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import (
    IOntologyAdapter,
    OntologyGraph,
    OntologyNode,
    OntologyEdge,
)

logger = get_logger("OntologyAdapters")
class DefaultOntologyAdapter(IOntologyAdapter):
    """Default implementation of ontology adapter."""

    async def find_matching_nodes(
        self,
        query_text: str,
        ontology: OntologyGraph,
        similarity_threshold: float = 0.8,
    ) -> List[OntologyNode]:
        """Find nodes matching query text using simple string similarity."""
        matching_nodes = []
        query_lower = query_text.lower()

        for node in ontology.nodes:
            # Check name similarity
            name_similarity = self._calculate_similarity(query_lower, node.name.lower())

            # Check description similarity
            desc_similarity = 0.0
            if node.description:
                desc_similarity = self._calculate_similarity(query_lower, node.description.lower())

            # Check properties similarity
            props_similarity = 0.0
            for prop_value in node.properties.values():
                if isinstance(prop_value, str):
                    prop_sim = self._calculate_similarity(query_lower, prop_value.lower())
                    props_similarity = max(props_similarity, prop_sim)

            # Take maximum similarity
            max_similarity = max(name_similarity, desc_similarity, props_similarity)
            if max_similarity >= similarity_threshold:
                # Add similarity score to node properties for ranking
                node_copy = OntologyNode(**node.dict())
                node_copy.properties["_similarity_score"] = max_similarity
                matching_nodes.append(node_copy)

        # Sort by similarity score
        matching_nodes.sort(key=lambda n: n.properties.get("_similarity_score", 0), reverse=True)
        logger.debug(f"Found {len(matching_nodes)} nodes matching '{query_text}'")
        return matching_nodes
    async def get_node_relationships(
        self,
        node_id: str,
        ontology: OntologyGraph,
        max_depth: int = 2,
    ) -> List[OntologyEdge]:
        """Get relationships for a specific node."""
        relationships = []
        visited = set()
        queue = deque([(node_id, 0)])  # (node_id, depth)

        while queue:
            current_id, depth = queue.popleft()
            if current_id in visited or depth > max_depth:
                continue
            visited.add(current_id)

            # Find edges where this node is source or target
            for edge in ontology.edges:
                if edge.source_id == current_id:
                    relationships.append(edge)
                    if depth < max_depth:
                        queue.append((edge.target_id, depth + 1))
                elif edge.target_id == current_id:
                    relationships.append(edge)
                    if depth < max_depth:
                        queue.append((edge.source_id, depth + 1))

        # Remove duplicates
        unique_relationships = []
        seen_edges = set()
        for rel in relationships:
            edge_key = (rel.source_id, rel.target_id, rel.relationship_type)
            if edge_key not in seen_edges:
                seen_edges.add(edge_key)
                unique_relationships.append(rel)

        logger.debug(f"Found {len(unique_relationships)} relationships for node {node_id}")
        return unique_relationships
    async def expand_subgraph(
        self,
        node_ids: List[str],
        ontology: OntologyGraph,
        directed: bool = True,
    ) -> Tuple[List[OntologyNode], List[OntologyEdge]]:
        """Expand subgraph around given nodes."""
        subgraph_nodes = []
        subgraph_edges = []

        # Get all nodes in the node_ids list
        node_map = {node.id: node for node in ontology.nodes}
        included_node_ids = set(node_ids)

        # Add initial nodes
        for node_id in node_ids:
            if node_id in node_map:
                subgraph_nodes.append(node_map[node_id])

        # Find connected edges and nodes
        for edge in ontology.edges:
            include_edge = False
            if directed:
                # Include edge if source is in our set
                if edge.source_id in included_node_ids:
                    include_edge = True
                    # Add target node if not already included
                    if edge.target_id not in included_node_ids and edge.target_id in node_map:
                        subgraph_nodes.append(node_map[edge.target_id])
                        included_node_ids.add(edge.target_id)
            else:
                # Include edge if either source or target is in our set
                if edge.source_id in included_node_ids or edge.target_id in included_node_ids:
                    include_edge = True
                    # Add both nodes if not already included
                    for node_id in [edge.source_id, edge.target_id]:
                        if node_id not in included_node_ids and node_id in node_map:
                            subgraph_nodes.append(node_map[node_id])
                            included_node_ids.add(node_id)

            if include_edge:
                subgraph_edges.append(edge)

        logger.debug(f"Expanded subgraph with {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges")
        return subgraph_nodes, subgraph_edges
    async def merge_ontologies(
        self,
        ontologies: List[OntologyGraph],
    ) -> OntologyGraph:
        """Merge multiple ontologies."""
        if not ontologies:
            raise ValueError("No ontologies to merge")
        if len(ontologies) == 1:
            return ontologies[0]

        # Create merged ontology
        merged_nodes = []
        merged_edges = []
        merged_metadata = {}

        # Keep track of node IDs and edge keys to avoid duplicates
        seen_node_ids = set()
        seen_edge_keys = set()

        # Merge nodes
        for ontology in ontologies:
            for node in ontology.nodes:
                if node.id not in seen_node_ids:
                    merged_nodes.append(node)
                    seen_node_ids.add(node.id)
                else:
                    # Handle duplicate nodes by merging properties
                    existing_node = next(n for n in merged_nodes if n.id == node.id)
                    existing_node.properties.update(node.properties)

        # Merge edges
        for ontology in ontologies:
            for edge in ontology.edges:
                edge_key = (edge.source_id, edge.target_id, edge.relationship_type)
                if edge_key not in seen_edge_keys:
                    merged_edges.append(edge)
                    seen_edge_keys.add(edge_key)

        # Merge metadata
        for ontology in ontologies:
            merged_metadata.update(ontology.metadata)
        merged_metadata["merged_from"] = [ont.id for ont in ontologies]

        from datetime import datetime

        merged_metadata["merge_timestamp"] = datetime.now().isoformat()

        merged_ontology = OntologyGraph(
            id=f"merged_{'_'.join([ont.id for ont in ontologies[:3]])}",
            name=f"Merged Ontology ({len(ontologies)} sources)",
            description=f"Merged from: {', '.join([ont.name for ont in ontologies])}",
            format=ontologies[0].format,  # Use format of first ontology
            scope=ontologies[0].scope,  # Use scope of first ontology
            nodes=merged_nodes,
            edges=merged_edges,
            metadata=merged_metadata,
        )

        logger.info(
            f"Merged {len(ontologies)} ontologies into one with {len(merged_nodes)} nodes and {len(merged_edges)} edges"
        )
        return merged_ontology

    def _calculate_similarity(self, text1: str, text2: str) -> float:
        """Calculate similarity between two text strings."""
        if not text1 or not text2:
            return 0.0
        # Use difflib for sequence similarity
        return difflib.SequenceMatcher(None, text1, text2).ratio()
class SemanticOntologyAdapter(DefaultOntologyAdapter):
    """Semantic-aware ontology adapter using embeddings (if available)."""

    def __init__(self):
        super().__init__()
        self.embeddings_available = False
        try:
            # Try to import embedding functionality
            from cognee.infrastructure.llm import get_embedding_engine

            self.get_embedding_engine = get_embedding_engine
            self.embeddings_available = True
        except ImportError:
            logger.warning("Embedding engine not available, falling back to string similarity")

    async def find_matching_nodes(
        self,
        query_text: str,
        ontology: OntologyGraph,
        similarity_threshold: float = 0.8,
    ) -> List[OntologyNode]:
        """Find nodes using semantic similarity if embeddings are available."""
        if not self.embeddings_available:
            return await super().find_matching_nodes(query_text, ontology, similarity_threshold)

        try:
            # Get embedding for query
            embedding_engine = await self.get_embedding_engine()
            query_embedding = await embedding_engine.embed_text(query_text)

            matching_nodes = []
            for node in ontology.nodes:
                # Create node text for embedding
                node_text = f"{node.name} {node.description or ''}"
                for prop_value in node.properties.values():
                    if isinstance(prop_value, str):
                        node_text += f" {prop_value}"

                # Get node embedding
                node_embedding = await embedding_engine.embed_text(node_text)

                # Calculate cosine similarity
                similarity = self._cosine_similarity(query_embedding, node_embedding)
                if similarity >= similarity_threshold:
                    node_copy = OntologyNode(**node.dict())
                    node_copy.properties["_similarity_score"] = similarity
                    matching_nodes.append(node_copy)

            # Sort by similarity
            matching_nodes.sort(key=lambda n: n.properties.get("_similarity_score", 0), reverse=True)
            logger.debug(f"Found {len(matching_nodes)} nodes using semantic similarity")
            return matching_nodes
        except Exception as e:
            logger.warning(f"Semantic similarity failed, falling back to string matching: {e}")
            return await super().find_matching_nodes(query_text, ontology, similarity_threshold)

    def _cosine_similarity(self, vec1, vec2) -> float:
        """Calculate cosine similarity between two vectors."""
        import numpy as np

        vec1 = np.array(vec1)
        vec2 = np.array(vec2)
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return dot_product / (norm1 * norm2)
class GraphOntologyAdapter(DefaultOntologyAdapter):
    """Adapter specialized for graph-based operations."""

    async def get_node_relationships(
        self,
        node_id: str,
        ontology: OntologyGraph,
        max_depth: int = 2,
    ) -> List[OntologyEdge]:
        """Enhanced relationship discovery with graph algorithms."""
        # Build adjacency lists for faster traversal
        outgoing = {}
        incoming = {}
        for edge in ontology.edges:
            if edge.source_id not in outgoing:
                outgoing[edge.source_id] = []
            outgoing[edge.source_id].append(edge)
            if edge.target_id not in incoming:
                incoming[edge.target_id] = []
            incoming[edge.target_id].append(edge)

        # BFS traversal
        relationships = []
        visited = set()
        queue = deque([(node_id, 0)])
        while queue:
            current_id, depth = queue.popleft()
            if current_id in visited or depth > max_depth:
                continue
            visited.add(current_id)

            # Add outgoing edges
            for edge in outgoing.get(current_id, []):
                relationships.append(edge)
                if depth < max_depth:
                    queue.append((edge.target_id, depth + 1))

            # Add incoming edges
            for edge in incoming.get(current_id, []):
                relationships.append(edge)
                if depth < max_depth:
                    queue.append((edge.source_id, depth + 1))

        # Remove duplicates and sort by relevance
        unique_relationships = list({edge.id: edge for edge in relationships}.values())

        # Sort by edge weight if available, then by relationship type
        unique_relationships.sort(
            key=lambda e: (e.weight or 0, e.relationship_type),
            reverse=True,
        )
        return unique_relationships

    async def find_shortest_path(
        self,
        source_id: str,
        target_id: str,
        ontology: OntologyGraph,
    ) -> List[OntologyEdge]:
        """Find shortest path between two nodes (BFS over directed edges)."""
        # Build adjacency list
        graph = {}
        for edge in ontology.edges:
            if edge.source_id not in graph:
                graph[edge.source_id] = []
            graph[edge.source_id].append((edge.target_id, edge))

        # BFS for shortest path
        queue = deque([(source_id, [])])
        visited = set()
        while queue:
            current_id, path = queue.popleft()
            if current_id == target_id:
                return path
            if current_id in visited:
                continue
            visited.add(current_id)
            for neighbor_id, edge in graph.get(current_id, []):
                if neighbor_id not in visited:
                    queue.append((neighbor_id, path + [edge]))

        return []  # No path found

View file

@@ -0,0 +1,438 @@
"""Graph binding implementations for ontologies."""
from typing import Any, Dict, List, Optional, Tuple, Callable
from datetime import datetime, timezone

from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import (
    IGraphBinder,
    OntologyGraph,
    OntologyNode,
    OntologyEdge,
    GraphBindingConfig,
    OntologyContext,
)

logger = get_logger("GraphBinder")
class DefaultGraphBinder(IGraphBinder):
    """Default implementation for binding ontology to graph structures."""

    def __init__(self):
        self.custom_strategies: Dict[str, Callable] = {}

    async def bind_ontology_to_graph(
        self,
        ontology: OntologyGraph,
        binding_config: GraphBindingConfig,
        context: Optional[OntologyContext] = None,
    ) -> Tuple[List[Any], List[Any]]:  # (graph_nodes, graph_edges)
        """Bind ontology to graph structure."""
        # Use custom binding strategy if specified
        if (
            binding_config.custom_binding_strategy
            and binding_config.custom_binding_strategy in self.custom_strategies
        ):
            return await self._apply_custom_strategy(ontology, binding_config, context)

        # Use default binding logic
        return await self._default_binding(ontology, binding_config, context)
    async def transform_node_properties(
        self,
        node: OntologyNode,
        transformations: Dict[str, Callable[[Any], Any]],
    ) -> Dict[str, Any]:
        """Transform node properties according to binding config."""
        transformed_props = {}

        # Start with node's base properties
        base_props = {
            "id": node.id,
            "name": node.name,
            "type": node.type,
            "description": node.description,
            "category": node.category,
        }

        # Add custom properties
        base_props.update(node.properties)

        # Apply transformations
        for prop_name, prop_value in base_props.items():
            if prop_name in transformations:
                try:
                    transformed_props[prop_name] = transformations[prop_name](prop_value)
                except Exception as e:
                    logger.warning(f"Transformation failed for property {prop_name}: {e}")
                    transformed_props[prop_name] = prop_value
            else:
                transformed_props[prop_name] = prop_value

        # Add standard graph properties
        transformed_props.update({
            "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
            "ontology_source": True,
            "ontology_id": node.id,
        })
        return transformed_props
    async def transform_edge_properties(
        self,
        edge: OntologyEdge,
        transformations: Dict[str, Callable[[Any], Any]],
    ) -> Dict[str, Any]:
        """Transform edge properties according to binding config."""
        transformed_props = {}

        # Start with edge's base properties
        base_props = {
            "source_node_id": edge.source_id,
            "target_node_id": edge.target_id,
            "relationship_name": edge.relationship_type,
            "weight": edge.weight,
        }

        # Add custom properties
        base_props.update(edge.properties)

        # Apply transformations
        for prop_name, prop_value in base_props.items():
            if prop_name in transformations:
                try:
                    transformed_props[prop_name] = transformations[prop_name](prop_value)
                except Exception as e:
                    logger.warning(f"Transformation failed for edge property {prop_name}: {e}")
                    transformed_props[prop_name] = prop_value
            else:
                transformed_props[prop_name] = prop_value

        # Add standard graph properties
        transformed_props.update({
            "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
            "ontology_source": True,
            "ontology_edge_id": edge.id,
        })
        return transformed_props
    def register_binding_strategy(
        self,
        strategy_name: str,
        strategy_func: Callable[[OntologyGraph, GraphBindingConfig], Tuple[List[Any], List[Any]]],
    ) -> None:
        """Register a custom binding strategy."""
        self.custom_strategies[strategy_name] = strategy_func
        logger.info(f"Registered custom binding strategy: {strategy_name}")

    async def _apply_custom_strategy(
        self,
        ontology: OntologyGraph,
        binding_config: GraphBindingConfig,
        context: Optional[OntologyContext] = None,
    ) -> Tuple[List[Any], List[Any]]:
        """Apply custom binding strategy."""
        strategy_func = self.custom_strategies[binding_config.custom_binding_strategy]
        try:
            if context:
                return await strategy_func(ontology, binding_config, context)
            else:
                return await strategy_func(ontology, binding_config)
        except Exception as e:
            logger.error(f"Custom binding strategy failed: {e}")
            # Fall back to default binding
            return await self._default_binding(ontology, binding_config, context)

    async def _default_binding(
        self,
        ontology: OntologyGraph,
        binding_config: GraphBindingConfig,
        context: Optional[OntologyContext] = None,
    ) -> Tuple[List[Any], List[Any]]:
        """Default binding logic."""
        graph_nodes = []
        graph_edges = []

        # Process nodes
        for node in ontology.nodes:
            try:
                graph_node = await self._bind_node_to_graph(node, binding_config)
                if graph_node:
                    graph_nodes.append(graph_node)
            except Exception as e:
                logger.error(f"Failed to bind node {node.id}: {e}")

        # Process edges
        for edge in ontology.edges:
            try:
                graph_edge = await self._bind_edge_to_graph(edge, binding_config)
                if graph_edge:
                    graph_edges.append(graph_edge)
            except Exception as e:
                logger.error(f"Failed to bind edge {edge.id}: {e}")

        logger.info(f"Bound {len(graph_nodes)} nodes and {len(graph_edges)} edges to graph")
        return graph_nodes, graph_edges
    async def _bind_node_to_graph(
        self,
        node: OntologyNode,
        binding_config: GraphBindingConfig,
    ) -> Tuple[str, Dict[str, Any]]:
        """Bind a single node to graph format."""
        # Map node type if configured
        graph_node_type = binding_config.node_type_mapping.get(node.type, node.type)

        # Transform properties
        node_properties = await self.transform_node_properties(
            node, binding_config.property_transformations
        )

        # Set the mapped type
        node_properties["type"] = graph_node_type

        # Use the ontology node ID as the graph node ID
        node_id = node.id
        return (node_id, node_properties)

    async def _bind_edge_to_graph(
        self,
        edge: OntologyEdge,
        binding_config: GraphBindingConfig,
    ) -> Tuple[str, str, str, Dict[str, Any]]:
        """Bind a single edge to graph format."""
        # Map edge type if configured
        graph_edge_type = binding_config.edge_type_mapping.get(
            edge.relationship_type, edge.relationship_type
        )

        # Transform properties
        edge_properties = await self.transform_edge_properties(
            edge, binding_config.property_transformations
        )

        # Set the mapped relationship name
        edge_properties["relationship_name"] = graph_edge_type

        return (
            edge.source_id,
            edge.target_id,
            graph_edge_type,
            edge_properties,
        )
class KnowledgeGraphBinder(DefaultGraphBinder):
    """Specialized binder for KnowledgeGraph format."""

    async def _bind_node_to_graph(
        self,
        node: OntologyNode,
        binding_config: GraphBindingConfig,
    ) -> Any:  # cognee.shared.data_models.Node
        """Bind node to KnowledgeGraph Node format."""
        # Transform properties
        node_properties = await self.transform_node_properties(
            node, binding_config.property_transformations
        )

        # Create KnowledgeGraph-compatible node
        from cognee.shared.data_models import Node

        kg_node = Node(
            id=node.id,
            name=node.name,
            type=binding_config.node_type_mapping.get(node.type, node.type),
            description=node.description or "",
        )
        return kg_node

    async def _bind_edge_to_graph(
        self,
        edge: OntologyEdge,
        binding_config: GraphBindingConfig,
    ) -> Any:  # cognee.shared.data_models.Edge
        """Bind edge to KnowledgeGraph Edge format."""
        from cognee.shared.data_models import Edge

        # Map edge type
        relationship_name = binding_config.edge_type_mapping.get(
            edge.relationship_type, edge.relationship_type
        )

        kg_edge = Edge(
            source_node_id=edge.source_id,
            target_node_id=edge.target_id,
            relationship_name=relationship_name,
        )
        return kg_edge
class DataPointGraphBinder(DefaultGraphBinder):
    """Specialized binder for DataPoint-based graphs."""

    async def _bind_node_to_graph(
        self,
        node: OntologyNode,
        binding_config: GraphBindingConfig,
    ) -> Any:  # DataPoint
        """Bind node to DataPoint instance."""
        from cognee.infrastructure.engine.models.DataPoint import DataPoint

        # Transform properties
        node_properties = await self.transform_node_properties(
            node, binding_config.property_transformations
        )

        # Create DataPoint instance
        datapoint = DataPoint(
            id=node.id,
            type=binding_config.node_type_mapping.get(node.type, node.type),
            ontology_valid=True,
            metadata={
                "type": node.type,
                "index_fields": ["name", "type"],
                "ontology_source": True,
            },
        )

        # Add custom attributes
        for prop_name, prop_value in node_properties.items():
            if not hasattr(datapoint, prop_name):
                setattr(datapoint, prop_name, prop_value)

        return datapoint
class DomainSpecificBinder(DefaultGraphBinder):
    """Domain-specific graph binder with specialized transformation logic."""

    def __init__(self, domain: str):
        super().__init__()
        self.domain = domain

    async def transform_node_properties(
        self,
        node: OntologyNode,
        transformations: Dict[str, Callable[[Any], Any]],
    ) -> Dict[str, Any]:
        """Apply domain-specific node transformations."""
        # Apply parent transformations first
        props = await super().transform_node_properties(node, transformations)

        # Apply domain-specific transformations
        if self.domain == "medical":
            props = await self._apply_medical_node_transforms(node, props)
        elif self.domain == "legal":
            props = await self._apply_legal_node_transforms(node, props)
        elif self.domain == "code":
            props = await self._apply_code_node_transforms(node, props)

        return props

    async def transform_edge_properties(
        self,
        edge: OntologyEdge,
        transformations: Dict[str, Callable[[Any], Any]],
    ) -> Dict[str, Any]:
        """Apply domain-specific edge transformations."""
        # Apply parent transformations first
        props = await super().transform_edge_properties(edge, transformations)

        # Apply domain-specific transformations
        if self.domain == "medical":
            props = await self._apply_medical_edge_transforms(edge, props)
        elif self.domain == "legal":
            props = await self._apply_legal_edge_transforms(edge, props)
        elif self.domain == "code":
            props = await self._apply_code_edge_transforms(edge, props)

        return props

    async def _apply_medical_node_transforms(
        self, node: OntologyNode, props: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Apply medical domain node transformations."""
        if node.type == "Disease":
            props["medical_category"] = "pathology"
            props["severity_level"] = node.properties.get("severity", "unknown")
        elif node.type == "Symptom":
            props["medical_category"] = "clinical_sign"
            props["frequency"] = node.properties.get("frequency", "unknown")
        return props

    async def _apply_legal_node_transforms(
        self, node: OntologyNode, props: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Apply legal domain node transformations."""
        if node.type == "Law":
            props["legal_authority"] = node.properties.get("jurisdiction", "unknown")
            props["enforcement_level"] = node.properties.get("level", "federal")
        elif node.type == "Case":
            props["court_level"] = node.properties.get("court", "unknown")
            props["precedent_value"] = node.properties.get("binding", "persuasive")
        return props

    async def _apply_code_node_transforms(
        self, node: OntologyNode, props: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Apply code domain node transformations."""
        if node.type == "Function":
            props["complexity"] = len(node.properties.get("parameters", []))
            props["visibility"] = node.properties.get("access_modifier", "public")
        elif node.type == "Class":
            props["inheritance_depth"] = node.properties.get("depth", 0)
            props["method_count"] = len(node.properties.get("methods", []))
        return props

    async def _apply_medical_edge_transforms(
        self, edge: OntologyEdge, props: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Apply medical domain edge transformations."""
        if edge.relationship_type == "causes":
            props["causality_strength"] = edge.properties.get("strength", "unknown")
        elif edge.relationship_type == "treats":
            props["efficacy"] = edge.properties.get("effectiveness", "unknown")
        return props

    async def _apply_legal_edge_transforms(
        self, edge: OntologyEdge, props: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Apply legal domain edge transformations."""
        if edge.relationship_type == "cites":
            props["citation_type"] = edge.properties.get("type", "supporting")
        elif edge.relationship_type == "overrules":
            props["authority_level"] = edge.properties.get("level", "same")
        return props

    async def _apply_code_edge_transforms(
        self, edge: OntologyEdge, props: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Apply code domain edge transformations."""
        if edge.relationship_type == "calls":
            props["call_frequency"] = edge.properties.get("frequency", 1)
        elif edge.relationship_type == "inherits":
            props["inheritance_type"] = edge.properties.get("type", "extends")
        return props

View file

@@ -0,0 +1,95 @@
"""Ontology configuration following Cognee patterns."""
import os
from typing import Optional
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.shared.logging_utils import get_logger
logger = get_logger("ontology.config")
class OntologyConfig(BaseSettings):
"""
Configuration settings for the ontology system.
Follows Cognee's BaseSettings pattern with environment variable support.
"""
# Default ontology settings
default_format: str = "json"
enable_semantic_search: bool = False
registry_type: str = "memory" # "memory" or "database"
# Provider settings
rdf_provider_enabled: bool = True
json_provider_enabled: bool = True
csv_provider_enabled: bool = True
# Performance settings
cache_ontologies: bool = True
max_cache_size: int = 100
similarity_threshold: float = 0.8
# Domain-specific settings
medical_domain_enabled: bool = True
legal_domain_enabled: bool = True
code_domain_enabled: bool = True
# File paths
ontology_data_directory: str = os.path.join(
os.getenv("COGNEE_DATA_ROOT", ".data_storage"), "ontologies"
)
default_config_file: Optional[str] = None
# Environment variables
ontology_api_key: Optional[str] = os.getenv("ONTOLOGY_API_KEY")
ontology_endpoint: Optional[str] = os.getenv("ONTOLOGY_ENDPOINT")
model_config = SettingsConfigDict(env_file=".env", extra="allow")
def to_dict(self) -> dict:
"""Convert configuration to dictionary following Cognee pattern."""
return {
"default_format": self.default_format,
"enable_semantic_search": self.enable_semantic_search,
"registry_type": self.registry_type,
"rdf_provider_enabled": self.rdf_provider_enabled,
"json_provider_enabled": self.json_provider_enabled,
"csv_provider_enabled": self.csv_provider_enabled,
"cache_ontologies": self.cache_ontologies,
"max_cache_size": self.max_cache_size,
"similarity_threshold": self.similarity_threshold,
"medical_domain_enabled": self.medical_domain_enabled,
"legal_domain_enabled": self.legal_domain_enabled,
"code_domain_enabled": self.code_domain_enabled,
"ontology_data_directory": self.ontology_data_directory,
}
@lru_cache
def get_ontology_config():
"""Get ontology configuration instance following Cognee pattern."""
return OntologyConfig()
# Configuration helpers following existing patterns
def set_ontology_data_directory(directory: str):
"""Set ontology data directory."""
config = get_ontology_config()
config.ontology_data_directory = directory
logger.info(f"Set ontology data directory to: {directory}")
def enable_semantic_search(enabled: bool = True):
"""Enable or disable semantic search."""
config = get_ontology_config()
config.enable_semantic_search = enabled
logger.info(f"Semantic search {'enabled' if enabled else 'disabled'}")
def set_similarity_threshold(threshold: float):
"""Set similarity threshold for entity matching."""
config = get_ontology_config()
config.similarity_threshold = threshold
logger.info(f"Set similarity threshold to: {threshold}")


@@ -0,0 +1,364 @@
"""Configuration system for ontology integration."""
from typing import Dict, List, Optional, Any, Callable
from pathlib import Path
import json
import yaml
from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import (
DataPointMapping,
GraphBindingConfig,
OntologyFormat,
OntologyScope,
)
from cognee.modules.ontology.pipeline_integration import (
create_medical_pipeline_config,
create_legal_pipeline_config,
create_code_pipeline_config,
)
logger = get_logger("OntologyConfiguration")
class OntologyConfiguration:
"""Central configuration management for ontology system."""
def __init__(self):
self.domain_configs: Dict[str, Dict[str, Any]] = {}
self.pipeline_configs: Dict[str, Dict[str, Any]] = {}
self.custom_resolvers: Dict[str, Callable] = {}
self.custom_binding_strategies: Dict[str, Callable] = {}
# Load default configurations
self._load_default_configs()
def _load_default_configs(self):
"""Load default domain configurations."""
self.domain_configs.update({
"medical": create_medical_pipeline_config(),
"legal": create_legal_pipeline_config(),
"code": create_code_pipeline_config(),
})
def register_domain_config(
self,
domain: str,
datapoint_mappings: List[DataPointMapping],
graph_binding: GraphBindingConfig,
task_configs: Optional[Dict[str, Dict[str, Any]]] = None
) -> None:
"""Register configuration for a domain."""
self.domain_configs[domain] = {
"domain": domain,
"datapoint_mappings": datapoint_mappings,
"graph_binding": graph_binding,
"task_configs": task_configs or {},
}
logger.info(f"Registered domain configuration: {domain}")
def register_pipeline_config(
self,
pipeline_name: str,
domain: str,
custom_mappings: Optional[List[DataPointMapping]] = None,
custom_binding: Optional[GraphBindingConfig] = None,
task_configs: Optional[Dict[str, Dict[str, Any]]] = None
) -> None:
"""Register configuration for a specific pipeline."""
# Start with domain config if available
base_config = self.domain_configs.get(domain, {})
pipeline_config = {
"pipeline_name": pipeline_name,
"domain": domain,
"datapoint_mappings": custom_mappings or base_config.get("datapoint_mappings", []),
"graph_binding": custom_binding or base_config.get("graph_binding", GraphBindingConfig()),
"task_configs": {**base_config.get("task_configs", {}), **(task_configs or {})},
}
self.pipeline_configs[pipeline_name] = pipeline_config
logger.info(f"Registered pipeline configuration: {pipeline_name}")
def get_domain_config(self, domain: str) -> Optional[Dict[str, Any]]:
"""Get configuration for a domain."""
return self.domain_configs.get(domain)
def get_pipeline_config(self, pipeline_name: str) -> Optional[Dict[str, Any]]:
"""Get configuration for a pipeline."""
return self.pipeline_configs.get(pipeline_name)
def load_from_file(self, config_file: str) -> None:
"""Load configuration from file."""
config_path = Path(config_file)
if not config_path.exists():
raise FileNotFoundError(f"Configuration file not found: {config_file}")
if config_path.suffix.lower() == '.json':
with open(config_path) as f:
config_data = json.load(f)
elif config_path.suffix.lower() in ['.yml', '.yaml']:
with open(config_path) as f:
config_data = yaml.safe_load(f)
else:
raise ValueError(f"Unsupported configuration file format: {config_path.suffix}")
self._parse_config_data(config_data)
logger.info(f"Loaded configuration from {config_file}")
def save_to_file(self, config_file: str, format: str = "json") -> None:
"""Save configuration to file."""
config_data = {
"domains": self._serialize_domain_configs(),
"pipelines": self._serialize_pipeline_configs(),
}
config_path = Path(config_file)
if format.lower() == "json":
with open(config_path, 'w') as f:
json.dump(config_data, f, indent=2)
elif format.lower() in ["yml", "yaml"]:
with open(config_path, 'w') as f:
yaml.dump(config_data, f, default_flow_style=False)
else:
raise ValueError(f"Unsupported format: {format}")
logger.info(f"Saved configuration to {config_file}")
def register_custom_resolver(
self,
resolver_name: str,
resolver_func: Callable
) -> None:
"""Register a custom DataPoint resolver."""
self.custom_resolvers[resolver_name] = resolver_func
logger.info(f"Registered custom resolver: {resolver_name}")
def register_custom_binding_strategy(
self,
strategy_name: str,
strategy_func: Callable
) -> None:
"""Register a custom graph binding strategy."""
self.custom_binding_strategies[strategy_name] = strategy_func
logger.info(f"Registered custom binding strategy: {strategy_name}")
def create_datapoint_mapping(
self,
ontology_node_type: str,
datapoint_class: str,
field_mappings: Optional[Dict[str, str]] = None,
custom_resolver: Optional[str] = None,
validation_rules: Optional[List[str]] = None
) -> DataPointMapping:
"""Create a DataPointMapping configuration."""
return DataPointMapping(
ontology_node_type=ontology_node_type,
datapoint_class=datapoint_class,
field_mappings=field_mappings or {},
custom_resolver=custom_resolver,
validation_rules=validation_rules or []
)
def create_graph_binding_config(
self,
node_type_mapping: Optional[Dict[str, str]] = None,
edge_type_mapping: Optional[Dict[str, str]] = None,
property_transformations: Optional[Dict[str, Callable]] = None,
custom_binding_strategy: Optional[str] = None
) -> GraphBindingConfig:
"""Create a GraphBindingConfig."""
return GraphBindingConfig(
node_type_mapping=node_type_mapping or {},
edge_type_mapping=edge_type_mapping or {},
property_transformations=property_transformations or {},
custom_binding_strategy=custom_binding_strategy
)
def _parse_config_data(self, config_data: Dict[str, Any]) -> None:
"""Parse configuration data from file."""
# Parse domain configurations
for domain, domain_config in config_data.get("domains", {}).items():
mappings = []
for mapping_data in domain_config.get("datapoint_mappings", []):
mapping = DataPointMapping(**mapping_data)
mappings.append(mapping)
binding_data = domain_config.get("graph_binding", {})
binding = GraphBindingConfig(**binding_data)
task_configs = domain_config.get("task_configs", {})
self.register_domain_config(domain, mappings, binding, task_configs)
# Parse pipeline configurations
for pipeline_name, pipeline_config in config_data.get("pipelines", {}).items():
domain = pipeline_config.get("domain")
custom_mappings = None
if "datapoint_mappings" in pipeline_config:
custom_mappings = []
for mapping_data in pipeline_config["datapoint_mappings"]:
mapping = DataPointMapping(**mapping_data)
custom_mappings.append(mapping)
custom_binding = None
if "graph_binding" in pipeline_config:
binding_data = pipeline_config["graph_binding"]
custom_binding = GraphBindingConfig(**binding_data)
task_configs = pipeline_config.get("task_configs", {})
self.register_pipeline_config(
pipeline_name, domain, custom_mappings, custom_binding, task_configs
)
def _serialize_domain_configs(self) -> Dict[str, Any]:
"""Serialize domain configurations for saving."""
serialized = {}
for domain, config in self.domain_configs.items():
serialized[domain] = {
"domain": config["domain"],
"datapoint_mappings": [
mapping.dict() for mapping in config["datapoint_mappings"]
],
"graph_binding": config["graph_binding"].dict(),
"task_configs": config["task_configs"],
}
return serialized
def _serialize_pipeline_configs(self) -> Dict[str, Any]:
"""Serialize pipeline configurations for saving."""
serialized = {}
for pipeline_name, config in self.pipeline_configs.items():
serialized[pipeline_name] = {
"pipeline_name": config["pipeline_name"],
"domain": config["domain"],
"datapoint_mappings": [
mapping.dict() for mapping in config["datapoint_mappings"]
],
"graph_binding": config["graph_binding"].dict(),
"task_configs": config["task_configs"],
}
return serialized
# Global configuration instance
_global_ontology_config = None
def get_ontology_config() -> OntologyConfiguration:
"""Get the global ontology configuration instance."""
global _global_ontology_config
if _global_ontology_config is None:
_global_ontology_config = OntologyConfiguration()
return _global_ontology_config
def configure_domain(
domain: str,
datapoint_mappings: List[DataPointMapping],
graph_binding: GraphBindingConfig,
task_configs: Optional[Dict[str, Dict[str, Any]]] = None
) -> None:
"""Configure ontology for a domain (convenience function)."""
config = get_ontology_config()
config.register_domain_config(domain, datapoint_mappings, graph_binding, task_configs)
def configure_pipeline(
pipeline_name: str,
domain: str,
custom_mappings: Optional[List[DataPointMapping]] = None,
custom_binding: Optional[GraphBindingConfig] = None,
task_configs: Optional[Dict[str, Dict[str, Any]]] = None
) -> None:
"""Configure ontology for a pipeline (convenience function)."""
config = get_ontology_config()
config.register_pipeline_config(
pipeline_name, domain, custom_mappings, custom_binding, task_configs
)
def load_ontology_config(config_file: str) -> None:
"""Load ontology configuration from file (convenience function)."""
config = get_ontology_config()
config.load_from_file(config_file)
# Example configuration templates
def create_example_config_file(output_file: str) -> None:
"""Create an example configuration file."""
example_config = {
"domains": {
"example_domain": {
"domain": "example_domain",
"datapoint_mappings": [
{
"ontology_node_type": "Entity",
"datapoint_class": "cognee.infrastructure.engine.models.DataPoint.DataPoint",
"field_mappings": {
"name": "name",
"description": "description",
"category": "entity_type"
},
"custom_resolver": None,
"validation_rules": ["required:name"]
}
],
"graph_binding": {
"node_type_mapping": {
"Entity": "domain_entity",
"Concept": "domain_concept"
},
"edge_type_mapping": {
"related_to": "domain_relation",
"part_of": "composition"
},
"property_transformations": {},
"custom_binding_strategy": None
},
"task_configs": {
"extract_graph_from_data": {
"enhance_with_entities": True,
"inject_datapoint_mappings": True,
"inject_graph_binding": True,
"target_entity_types": ["Entity", "Concept"]
}
}
}
},
"pipelines": {
"example_pipeline": {
"pipeline_name": "example_pipeline",
"domain": "example_domain",
"datapoint_mappings": [], # Use domain defaults
"graph_binding": {}, # Use domain defaults
"task_configs": {
"summarize_text": {
"enhance_with_entities": True,
"enable_ontology_validation": True
}
}
}
}
}
with open(output_file, 'w') as f:
json.dump(example_config, f, indent=2)
logger.info(f"Created example configuration file: {output_file}")


@@ -0,0 +1,552 @@
"""
Example usage of the refactored ontology system.
This demonstrates how to:
1. Set up ontology providers and registry
2. Configure domain-specific mappings
3. Integrate with pipelines
4. Use custom resolvers and binding strategies
"""
import asyncio
from pathlib import Path
from typing import List, Any, Dict
from cognee.modules.ontology.interfaces import (
OntologyContext,
OntologyScope,
DataPointMapping,
GraphBindingConfig,
)
from cognee.modules.ontology.manager import create_ontology_manager
from cognee.modules.ontology.registry import OntologyRegistry
from cognee.modules.ontology.providers import JSONOntologyProvider, RDFOntologyProvider
from cognee.modules.ontology.adapters import DefaultOntologyAdapter
from cognee.modules.ontology.resolvers import DefaultDataPointResolver, DomainSpecificResolver
from cognee.modules.ontology.binders import DefaultGraphBinder, DomainSpecificBinder
from cognee.modules.ontology.pipeline_integration import (
PipelineOntologyConfigurator,
OntologyInjector,
)
from cognee.modules.ontology.configuration import (
get_ontology_config,
configure_domain,
configure_pipeline,
)
from cognee.modules.pipelines.tasks.task import Task
from cognee.shared.logging_utils import get_logger
logger = get_logger("OntologyExample")
async def example_basic_setup():
"""Example: Basic ontology system setup."""
logger.info("=== Basic Ontology System Setup ===")
# 1. Create registry and providers
registry = OntologyRegistry()
providers = {
"json_provider": JSONOntologyProvider(),
"rdf_provider": RDFOntologyProvider(),
}
adapters = {
"default_adapter": DefaultOntologyAdapter(),
}
# 2. Create resolvers and binders
datapoint_resolver = DefaultDataPointResolver()
graph_binder = DefaultGraphBinder()
# 3. Create ontology manager
ontology_manager = await create_ontology_manager(
registry=registry,
providers=providers,
adapters=adapters,
datapoint_resolver=datapoint_resolver,
graph_binder=graph_binder,
)
logger.info("Ontology system initialized successfully")
return ontology_manager
async def example_load_ontologies(ontology_manager):
"""Example: Loading different types of ontologies."""
logger.info("=== Loading Ontologies ===")
# 1. Load JSON ontology
json_ontology_data = {
"id": "medical_ontology",
"name": "Medical Knowledge Base",
"description": "Basic medical ontology",
"nodes": [
{
"id": "disease_001",
"name": "Diabetes",
"type": "Disease",
"description": "A group of metabolic disorders",
"category": "medical_condition",
"properties": {"icd_code": "E11", "severity": "chronic"}
},
{
"id": "symptom_001",
"name": "Fatigue",
"type": "Symptom",
"description": "Extreme tiredness",
"category": "clinical_finding",
"properties": {"frequency": "common"}
}
],
"edges": [
{
"id": "rel_001",
"source": "disease_001",
"target": "symptom_001",
"relationship": "causes",
"properties": {"strength": "moderate"}
}
]
}
json_provider = ontology_manager.providers["json_provider"]
medical_ontology = await json_provider.load_ontology(json_ontology_data)
# 2. Register ontology in registry
ontology_id = await ontology_manager.registry.register_ontology(
medical_ontology,
OntologyScope.DOMAIN,
OntologyContext(domain="medical")
)
logger.info(f"Loaded and registered medical ontology: {ontology_id}")
return medical_ontology
async def example_configure_domain_mappings(ontology_manager):
"""Example: Configure domain-specific DataPoint mappings."""
logger.info("=== Configuring Domain Mappings ===")
# 1. Define DataPoint mappings for medical domain
medical_mappings = [
DataPointMapping(
ontology_node_type="Disease",
datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
field_mappings={
"name": "name",
"description": "description",
"icd_code": "medical_code",
"severity": "severity_level",
},
validation_rules=["required:name", "required:medical_code"]
),
DataPointMapping(
ontology_node_type="Symptom",
datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
field_mappings={
"name": "name",
"description": "description",
"frequency": "occurrence_rate",
}
)
]
# 2. Define graph binding configuration
medical_binding = GraphBindingConfig(
node_type_mapping={
"Disease": "medical_condition",
"Symptom": "clinical_finding",
"Treatment": "therapeutic_procedure",
},
edge_type_mapping={
"causes": "causality",
"treats": "therapeutic_relationship",
"associated_with": "clinical_association",
}
)
# 3. Configure the domain
ontology_manager.configure_datapoint_mapping("medical", medical_mappings)
ontology_manager.configure_graph_binding("medical", medical_binding)
logger.info("Configured medical domain mappings and bindings")
async def example_custom_resolver(ontology_manager):
"""Example: Register and use custom DataPoint resolver."""
logger.info("=== Custom DataPoint Resolver ===")
# 1. Define custom resolver function
async def medical_disease_resolver(ontology_node, mapping_config, context=None):
"""Custom resolver for medical disease entities."""
from cognee.infrastructure.engine.models.DataPoint import DataPoint
# Create DataPoint with medical-specific logic
datapoint = DataPoint(
id=ontology_node.id,
type="medical_disease",
ontology_valid=True,
metadata={
"type": "medical_entity",
"index_fields": ["name", "medical_code"],
"domain": "medical",
"ontology_node_id": ontology_node.id,
}
)
# Map ontology properties with domain-specific processing
datapoint.name = ontology_node.name
datapoint.description = ontology_node.description
datapoint.medical_code = ontology_node.properties.get("icd_code", "")
datapoint.severity_level = ontology_node.properties.get("severity", "unknown")
# Add computed properties
datapoint.medical_category = "disease"
datapoint.risk_level = "high" if datapoint.severity_level == "chronic" else "low"
logger.info(f"Custom resolver created DataPoint for disease: {datapoint.name}")
return datapoint
# 2. Register custom resolver
ontology_manager.register_custom_resolver("medical_disease_resolver", medical_disease_resolver)
# 3. Update mapping to use custom resolver
updated_mapping = DataPointMapping(
ontology_node_type="Disease",
datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
field_mappings={}, # Handled by custom resolver
custom_resolver="medical_disease_resolver"
)
ontology_manager.configure_datapoint_mapping("medical", [updated_mapping])
logger.info("Registered custom medical disease resolver")
async def example_custom_binding_strategy(ontology_manager):
"""Example: Register and use custom graph binding strategy."""
logger.info("=== Custom Graph Binding Strategy ===")
# 1. Define custom binding strategy
async def medical_graph_binding_strategy(ontology, binding_config, context=None):
"""Custom binding strategy for medical graphs."""
from datetime import datetime
graph_nodes = []
graph_edges = []
# Process nodes with medical-specific transformations
for node in ontology.nodes:
if node.type == "Disease":
# Create medical condition node
node_props = {
"id": node.id,
"name": node.name,
"type": "medical_condition",
"description": node.description,
"medical_code": node.properties.get("icd_code", ""),
"severity": node.properties.get("severity", "unknown"),
"category": "pathology",
"updated_at": datetime.now().isoformat(),
"ontology_source": True,
}
graph_nodes.append((node.id, node_props))
elif node.type == "Symptom":
# Create clinical finding node
node_props = {
"id": node.id,
"name": node.name,
"type": "clinical_finding",
"description": node.description,
"frequency": node.properties.get("frequency", "unknown"),
"category": "symptom",
"updated_at": datetime.now().isoformat(),
"ontology_source": True,
}
graph_nodes.append((node.id, node_props))
# Process edges with medical-specific relationships
for edge in ontology.edges:
edge_props = {
"source_node_id": edge.source_id,
"target_node_id": edge.target_id,
"relationship_name": "medical_" + edge.relationship_type,
"strength": edge.properties.get("strength", "unknown"),
"updated_at": datetime.now().isoformat(),
"ontology_source": True,
}
graph_edges.append((
edge.source_id,
edge.target_id,
"medical_" + edge.relationship_type,
edge_props
))
logger.info(f"Custom binding strategy created {len(graph_nodes)} nodes and {len(graph_edges)} edges")
return graph_nodes, graph_edges
# 2. Register custom binding strategy
ontology_manager.register_binding_strategy("medical_graph_binding", medical_graph_binding_strategy)
# 3. Update binding config to use custom strategy
updated_binding = GraphBindingConfig(
custom_binding_strategy="medical_graph_binding"
)
ontology_manager.configure_graph_binding("medical", updated_binding)
logger.info("Registered custom medical graph binding strategy")
async def example_pipeline_integration(ontology_manager):
"""Example: Integrate ontology with pipeline tasks."""
logger.info("=== Pipeline Integration ===")
# 1. Create pipeline configurator
pipeline_configurator = PipelineOntologyConfigurator(ontology_manager)
# 2. Configure medical pipeline
medical_mappings = ontology_manager.domain_datapoint_mappings.get("medical", [])
medical_binding = ontology_manager.domain_graph_bindings.get("medical")
task_configs = {
"extract_graph_from_data": {
"enhance_with_entities": True,
"inject_datapoint_mappings": True,
"inject_graph_binding": True,
"target_entity_types": ["Disease", "Symptom", "Treatment"],
},
"summarize_text": {
"enhance_with_entities": True,
"enable_ontology_validation": True,
"validation_threshold": 0.85,
}
}
pipeline_configurator.configure_pipeline(
pipeline_name="medical_cognify_pipeline",
domain="medical",
datapoint_mappings=medical_mappings,
graph_binding=medical_binding,
task_specific_configs=task_configs
)
# 3. Create ontology injector for the pipeline
injector = pipeline_configurator.create_ontology_injector("medical_cognify_pipeline")
# 4. Create sample task and inject ontology
async def sample_extract_task(data_chunks, **kwargs):
"""Sample extraction task."""
logger.info("Executing extract task with ontology context")
ontology_context = kwargs.get("ontology_context")
datapoint_mappings = kwargs.get("datapoint_mappings", [])
logger.info(f"Task received ontology context for domain: {ontology_context.domain}")
logger.info(f"Task has {len(datapoint_mappings)} DataPoint mappings available")
# Simulate task processing with ontology enhancement
enhanced_results = []
for chunk in data_chunks:
# In real implementation, this would use ontology for entity extraction
enhanced_results.append({
"chunk": chunk,
"ontology_enhanced": True,
"domain": ontology_context.domain,
"entities_found": ["Diabetes", "Fatigue"] # Simulated
})
return enhanced_results
sample_task = Task(sample_extract_task)
# 5. Get pipeline context and inject ontology
context = pipeline_configurator.get_pipeline_context(
"medical_cognify_pipeline",
user_id="user123",
dataset_id="medical_dataset_001"
)
enhanced_task = await injector.inject_into_task(sample_task, context)
# 6. Execute enhanced task
sample_data = ["Patient reports fatigue and frequent urination", "Diagnosis: Type 2 Diabetes"]
results = await enhanced_task.run(sample_data)
logger.info(f"Enhanced task completed with {len(results)} results")
return results
async def example_content_enhancement(ontology_manager):
"""Example: Enhance content with ontological information."""
logger.info("=== Content Enhancement ===")
# 1. Create context for medical domain
context = OntologyContext(
domain="medical",
pipeline_name="medical_analysis",
user_id="doctor123"
)
# 2. Sample medical text
medical_text = """
The patient presents with chronic fatigue and excessive thirst.
Blood glucose levels are elevated, indicating possible diabetes mellitus.
Further testing for HbA1c is recommended to confirm diagnosis.
"""
# 3. Enhance content with ontology
enhanced_content = await ontology_manager.enhance_with_ontology(medical_text, context)
logger.info("Enhanced content:")
logger.info(f" Original content: {enhanced_content['original_content'][:50]}...")
logger.info(f" Extracted entities: {len(enhanced_content['extracted_entities'])}")
logger.info(f" Semantic relationships: {len(enhanced_content['semantic_relationships'])}")
for entity in enhanced_content['extracted_entities']:
logger.info(f" Found entity: {entity['name']} (type: {entity['type']})")
for relationship in enhanced_content['semantic_relationships']:
logger.info(f" Relationship: {relationship['source']} -> {relationship['relationship']} -> {relationship['target']}")
return enhanced_content
async def example_datapoint_resolution(ontology_manager):
"""Example: Resolve ontology nodes to DataPoint instances."""
logger.info("=== DataPoint Resolution ===")
# 1. Get medical ontology
context = OntologyContext(domain="medical")
ontologies = await ontology_manager.get_applicable_ontologies(context)
if not ontologies:
logger.warning("No medical ontologies found")
return []
medical_ontology = ontologies[0]
# 2. Filter disease nodes
disease_nodes = [node for node in medical_ontology.nodes if node.type == "Disease"]
if not disease_nodes:
logger.warning("No disease nodes found in ontology")
return []
# 3. Resolve to DataPoint instances
datapoints = await ontology_manager.resolve_to_datapoints(disease_nodes, context)
logger.info(f"Resolved {len(datapoints)} disease nodes to DataPoints:")
for dp in datapoints:
logger.info(f" DataPoint: {dp.type} - {getattr(dp, 'name', 'Unnamed')}")
logger.info(f" ID: {dp.id}")
logger.info(f" Ontology valid: {dp.ontology_valid}")
if hasattr(dp, 'medical_code'):
logger.info(f" Medical code: {dp.medical_code}")
return datapoints
async def example_graph_binding(ontology_manager):
"""Example: Bind ontology to graph structure."""
logger.info("=== Graph Binding ===")
# 1. Get medical ontology
context = OntologyContext(domain="medical")
ontologies = await ontology_manager.get_applicable_ontologies(context)
if not ontologies:
logger.warning("No medical ontologies found")
return [], []
medical_ontology = ontologies[0]
# 2. Bind to graph structure
graph_nodes, graph_edges = await ontology_manager.bind_to_graph(medical_ontology, context)
logger.info(f"Bound ontology to graph structure:")
logger.info(f" Graph nodes: {len(graph_nodes)}")
logger.info(f" Graph edges: {len(graph_edges)}")
# Display sample nodes
for i, node in enumerate(graph_nodes[:3]): # Show first 3
if isinstance(node, tuple):
node_id, node_props = node
logger.info(f" Node {i+1}: {node_id} (type: {node_props.get('type', 'unknown')})")
else:
logger.info(f" Node {i+1}: {getattr(node, 'id', 'unknown')} (type: {getattr(node, 'type', 'unknown')})")
# Display sample edges
for i, edge in enumerate(graph_edges[:3]): # Show first 3
if isinstance(edge, tuple) and len(edge) >= 4:
source, target, rel_type, props = edge[:4]
logger.info(f" Edge {i+1}: {source} -> {rel_type} -> {target}")
else:
logger.info(f" Edge {i+1}: {edge}")
return graph_nodes, graph_edges
async def main():
"""Run all examples."""
logger.info("Starting ontology system examples...")
try:
# 1. Basic setup
ontology_manager = await example_basic_setup()
# 2. Load ontologies
medical_ontology = await example_load_ontologies(ontology_manager)
# 3. Configure domain mappings
await example_configure_domain_mappings(ontology_manager)
# 4. Custom resolver
await example_custom_resolver(ontology_manager)
# 5. Custom binding strategy
await example_custom_binding_strategy(ontology_manager)
# 6. Pipeline integration
pipeline_results = await example_pipeline_integration(ontology_manager)
# 7. Content enhancement
enhanced_content = await example_content_enhancement(ontology_manager)
# 8. DataPoint resolution
datapoints = await example_datapoint_resolution(ontology_manager)
# 9. Graph binding
graph_nodes, graph_edges = await example_graph_binding(ontology_manager)
logger.info("All examples completed successfully!")
# Print summary
logger.info("\n=== Summary ===")
logger.info(f"Loaded ontologies: 1")
logger.info(f"Pipeline results: {len(pipeline_results) if pipeline_results else 0}")
logger.info(f"Enhanced entities: {len(enhanced_content.get('extracted_entities', [])) if enhanced_content else 0}")
logger.info(f"DataPoints created: {len(datapoints) if datapoints else 0}")
logger.info(f"Graph nodes: {len(graph_nodes) if graph_nodes else 0}")
logger.info(f"Graph edges: {len(graph_edges) if graph_edges else 0}")
except Exception as e:
logger.error(f"Example execution failed: {e}", exc_info=True)
if __name__ == "__main__":
asyncio.run(main())


@@ -0,0 +1,351 @@
"""Abstract interfaces for ontology system components."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
from pydantic import BaseModel
from enum import Enum
class OntologyFormat(str, Enum):
"""Supported ontology formats."""
RDF_XML = "rdf_xml"
OWL = "owl"
JSON = "json"
CSV = "csv"
YAML = "yaml"
DATABASE = "database"
LLM_GENERATED = "llm_generated"
class OntologyScope(str, Enum):
"""Ontology scopes for different use cases."""
GLOBAL = "global" # Applies to all pipelines
DOMAIN = "domain" # Applies to specific domain (medical, legal, etc.)
PIPELINE = "pipeline" # Applies to specific pipeline type
USER = "user" # User-specific ontologies
DATASET = "dataset" # Dataset-specific ontologies
class OntologyNode(BaseModel):
"""Standard ontology node representation."""
id: str
name: str
type: str
description: Optional[str] = None
category: Optional[str] = None
properties: Dict[str, Any] = {}
labels: List[str] = []
class OntologyEdge(BaseModel):
"""Standard ontology edge representation."""
id: str
source_id: str
target_id: str
relationship_type: str
properties: Dict[str, Any] = {}
weight: Optional[float] = None
class OntologyGraph(BaseModel):
"""Standard ontology graph representation."""
id: str
name: str
description: Optional[str] = None
format: OntologyFormat
scope: OntologyScope
nodes: List[OntologyNode]
edges: List[OntologyEdge]
metadata: Dict[str, Any] = {}
class OntologyContext(BaseModel):
"""Context for ontology operations."""
user_id: Optional[str] = None
dataset_id: Optional[str] = None
pipeline_name: Optional[str] = None
domain: Optional[str] = None
custom_properties: Dict[str, Any] = {}
class DataPointMapping(BaseModel):
"""Mapping configuration between ontology and DataPoint."""
ontology_node_type: str
datapoint_class: str
field_mappings: Dict[str, str] = {} # ontology_field -> datapoint_field
custom_resolver: Optional[str] = None # Function name for custom resolution
validation_rules: List[str] = []
class GraphBindingConfig(BaseModel):
"""Configuration for how ontology binds to graph structures."""
node_type_mapping: Dict[str, str] = {} # ontology_type -> graph_node_type
edge_type_mapping: Dict[str, str] = {} # ontology_relation -> graph_edge_type
property_transformations: Dict[str, Callable[[Any], Any]] = {}
custom_binding_strategy: Optional[str] = None
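# Construction sketch (illustrative): the models above compose into a minimal
# in-memory ontology.
#
#     disease = OntologyNode(id="n1", name="Diabetes", type="Disease")
#     symptom = OntologyNode(id="n2", name="Fatigue", type="Symptom")
#     causes = OntologyEdge(
#         id="e1", source_id="n1", target_id="n2", relationship_type="causes"
#     )
#     graph = OntologyGraph(
#         id="demo", name="Demo", format=OntologyFormat.JSON,
#         scope=OntologyScope.DOMAIN, nodes=[disease, symptom], edges=[causes],
#     )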
class IOntologyProvider(ABC):
"""Abstract interface for ontology providers."""
@abstractmethod
async def load_ontology(
self,
source: Union[str, Dict[str, Any]],
context: Optional[OntologyContext] = None
) -> OntologyGraph:
"""Load ontology from source."""
pass
@abstractmethod
async def save_ontology(
self,
ontology: OntologyGraph,
destination: str,
context: Optional[OntologyContext] = None
) -> bool:
"""Save ontology to destination."""
pass
@abstractmethod
def supports_format(self, format: OntologyFormat) -> bool:
"""Check if provider supports given format."""
pass
@abstractmethod
async def validate_ontology(self, ontology: OntologyGraph) -> bool:
"""Validate ontology structure."""
pass
class IOntologyAdapter(ABC):
"""Abstract interface for ontology adapters."""
@abstractmethod
async def find_matching_nodes(
self,
query_text: str,
ontology: OntologyGraph,
similarity_threshold: float = 0.8
) -> List[OntologyNode]:
"""Find nodes matching query text."""
pass
@abstractmethod
async def get_node_relationships(
self,
node_id: str,
ontology: OntologyGraph,
max_depth: int = 2
) -> List[OntologyEdge]:
"""Get relationships for a specific node."""
pass
@abstractmethod
async def expand_subgraph(
self,
node_ids: List[str],
ontology: OntologyGraph,
directed: bool = True
) -> Tuple[List[OntologyNode], List[OntologyEdge]]:
"""Expand subgraph around given nodes."""
pass
@abstractmethod
async def merge_ontologies(
self,
ontologies: List[OntologyGraph]
) -> OntologyGraph:
"""Merge multiple ontologies."""
pass
class IOntologyRegistry(ABC):
"""Abstract interface for ontology registry."""
@abstractmethod
async def register_ontology(
self,
ontology: OntologyGraph,
scope: OntologyScope,
context: Optional[OntologyContext] = None
) -> str:
"""Register an ontology."""
pass
@abstractmethod
async def get_ontology(
self,
ontology_id: str,
context: Optional[OntologyContext] = None
) -> Optional[OntologyGraph]:
"""Get ontology by ID."""
pass
@abstractmethod
async def find_ontologies(
self,
scope: Optional[OntologyScope] = None,
domain: Optional[str] = None,
context: Optional[OntologyContext] = None
) -> List[OntologyGraph]:
"""Find ontologies matching criteria."""
pass
@abstractmethod
async def unregister_ontology(
self,
ontology_id: str,
context: Optional[OntologyContext] = None
) -> bool:
"""Unregister an ontology."""
pass
class IDataPointResolver(ABC):
"""Abstract interface for resolving ontology to DataPoint instances."""
@abstractmethod
async def resolve_to_datapoint(
self,
ontology_node: OntologyNode,
mapping_config: DataPointMapping,
context: Optional[OntologyContext] = None
) -> Any: # Should be DataPoint but avoiding circular import
"""Resolve ontology node to DataPoint instance."""
pass
@abstractmethod
async def resolve_from_datapoint(
self,
datapoint: Any, # DataPoint
mapping_config: DataPointMapping,
context: Optional[OntologyContext] = None
) -> OntologyNode:
"""Resolve DataPoint instance to ontology node."""
pass
@abstractmethod
async def validate_mapping(
self,
mapping_config: DataPointMapping
) -> bool:
"""Validate mapping configuration."""
pass
@abstractmethod
def register_custom_resolver(
self,
resolver_name: str,
resolver_func: Callable[[OntologyNode, DataPointMapping], Any]
) -> None:
"""Register a custom resolver function."""
pass
class IGraphBinder(ABC):
"""Abstract interface for binding ontology to graph structures."""
@abstractmethod
async def bind_ontology_to_graph(
self,
ontology: OntologyGraph,
binding_config: GraphBindingConfig,
context: Optional[OntologyContext] = None
) -> Tuple[List[Any], List[Any]]: # (graph_nodes, graph_edges)
"""Bind ontology to graph structure."""
pass
@abstractmethod
async def transform_node_properties(
self,
node: OntologyNode,
transformations: Dict[str, Callable[[Any], Any]]
) -> Dict[str, Any]:
"""Transform node properties according to binding config."""
pass
@abstractmethod
async def transform_edge_properties(
self,
edge: OntologyEdge,
transformations: Dict[str, Callable[[Any], Any]]
) -> Dict[str, Any]:
"""Transform edge properties according to binding config."""
pass
@abstractmethod
def register_binding_strategy(
self,
strategy_name: str,
strategy_func: Callable[[OntologyGraph, GraphBindingConfig], Tuple[List[Any], List[Any]]]
) -> None:
"""Register a custom binding strategy."""
pass
class IOntologyManager(ABC):
"""Abstract interface for ontology manager."""
@abstractmethod
async def get_applicable_ontologies(
self,
context: OntologyContext
) -> List[OntologyGraph]:
"""Get ontologies applicable to given context."""
pass
@abstractmethod
async def enhance_with_ontology(
self,
content: str,
context: OntologyContext
) -> Dict[str, Any]:
"""Enhance content with ontological information."""
pass
@abstractmethod
async def inject_ontology_into_task(
self,
task_name: str,
task_params: Dict[str, Any],
context: OntologyContext
) -> Dict[str, Any]:
"""Inject ontological context into task parameters."""
pass
@abstractmethod
async def resolve_to_datapoints(
self,
ontology_nodes: List[OntologyNode],
context: OntologyContext
) -> List[Any]: # List[DataPoint]
"""Resolve ontology nodes to DataPoint instances."""
pass
@abstractmethod
async def bind_to_graph(
self,
ontology: OntologyGraph,
context: OntologyContext
) -> Tuple[List[Any], List[Any]]: # (graph_nodes, graph_edges)
"""Bind ontology to graph structure using configured binding."""
pass
@abstractmethod
def configure_datapoint_mapping(
self,
domain: str,
mappings: List[DataPointMapping]
) -> None:
"""Configure DataPoint mappings for a domain."""
pass
@abstractmethod
def configure_graph_binding(
self,
domain: str,
binding_config: GraphBindingConfig
) -> None:
"""Configure graph binding for a domain."""
pass


@@ -0,0 +1,363 @@
"""Core ontology manager implementation."""
from typing import Any, Dict, List, Optional, Tuple, Callable
from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import (
IOntologyManager,
IOntologyRegistry,
IOntologyProvider,
IOntologyAdapter,
IDataPointResolver,
IGraphBinder,
OntologyGraph,
OntologyNode,
OntologyContext,
DataPointMapping,
GraphBindingConfig,
OntologyScope,
)
logger = get_logger("OntologyManager")
class OntologyManager(IOntologyManager):
"""Core implementation of ontology management."""
def __init__(
self,
registry: IOntologyRegistry,
providers: Dict[str, IOntologyProvider],
adapters: Dict[str, IOntologyAdapter],
datapoint_resolver: IDataPointResolver,
graph_binder: IGraphBinder,
):
self.registry = registry
self.providers = providers
self.adapters = adapters
self.datapoint_resolver = datapoint_resolver
self.graph_binder = graph_binder
# Domain-specific configurations
self.domain_datapoint_mappings: Dict[str, List[DataPointMapping]] = {}
self.domain_graph_bindings: Dict[str, GraphBindingConfig] = {}
# Custom resolvers and binding strategies
self.custom_resolvers: Dict[str, Callable] = {}
self.custom_binding_strategies: Dict[str, Callable] = {}
async def get_applicable_ontologies(
self,
context: OntologyContext
) -> List[OntologyGraph]:
"""Get ontologies applicable to given context."""
ontologies = []
# Get global ontologies
global_ontologies = await self.registry.find_ontologies(
scope=OntologyScope.GLOBAL,
context=context
)
ontologies.extend(global_ontologies)
# Get domain-specific ontologies
if context.domain:
domain_ontologies = await self.registry.find_ontologies(
scope=OntologyScope.DOMAIN,
domain=context.domain,
context=context
)
ontologies.extend(domain_ontologies)
# Get pipeline-specific ontologies
if context.pipeline_name:
pipeline_ontologies = await self.registry.find_ontologies(
scope=OntologyScope.PIPELINE,
context=context
)
ontologies.extend(pipeline_ontologies)
# Get user-specific ontologies
if context.user_id:
user_ontologies = await self.registry.find_ontologies(
scope=OntologyScope.USER,
context=context
)
ontologies.extend(user_ontologies)
# Get dataset-specific ontologies
if context.dataset_id:
dataset_ontologies = await self.registry.find_ontologies(
scope=OntologyScope.DATASET,
context=context
)
ontologies.extend(dataset_ontologies)
# Remove duplicates and prioritize by scope
unique_ontologies = self._prioritize_ontologies(ontologies)
logger.info(f"Found {len(unique_ontologies)} applicable ontologies for context")
return unique_ontologies
async def enhance_with_ontology(
self,
content: str,
context: OntologyContext
) -> Dict[str, Any]:
"""Enhance content with ontological information."""
applicable_ontologies = await self.get_applicable_ontologies(context)
enhanced_data = {
"original_content": content,
"ontological_annotations": [],
"extracted_entities": [],
"semantic_relationships": [],
}
for ontology in applicable_ontologies:
adapter_name = self._get_adapter_for_ontology(ontology)
if adapter_name not in self.adapters:
logger.warning(f"No adapter found for ontology {ontology.id}")
continue
adapter = self.adapters[adapter_name]
# Find matching nodes in content
matching_nodes = await adapter.find_matching_nodes(
content, ontology, similarity_threshold=0.7
)
for node in matching_nodes:
enhanced_data["extracted_entities"].append({
"node_id": node.id,
"name": node.name,
"type": node.type,
"category": node.category,
"ontology_id": ontology.id,
"confidence": 0.8, # This should come from the adapter
})
# Get relationships for this node
relationships = await adapter.get_node_relationships(
node.id, ontology, max_depth=1
)
for rel in relationships:
enhanced_data["semantic_relationships"].append({
"source": rel.source_id,
"target": rel.target_id,
"relationship": rel.relationship_type,
"ontology_id": ontology.id,
})
return enhanced_data
async def inject_ontology_into_task(
self,
task_name: str,
task_params: Dict[str, Any],
context: OntologyContext
) -> Dict[str, Any]:
"""Inject ontological context into task parameters."""
applicable_ontologies = await self.get_applicable_ontologies(context)
# Merge all applicable ontologies
if len(applicable_ontologies) > 1:
primary_adapter = list(self.adapters.values())[0] # Use first available adapter
merged_ontology = await primary_adapter.merge_ontologies(applicable_ontologies)
ontologies_to_inject = [merged_ontology]
else:
ontologies_to_inject = applicable_ontologies
# Inject ontology-specific parameters
enhanced_params = task_params.copy()
enhanced_params["ontology_context"] = {
"ontologies": [ont.id for ont in ontologies_to_inject],
"domain": context.domain,
"pipeline_name": context.pipeline_name,
}
# Add ontology-aware configurations
if context.domain and context.domain in self.domain_datapoint_mappings:
enhanced_params["datapoint_mappings"] = self.domain_datapoint_mappings[context.domain]
if context.domain and context.domain in self.domain_graph_bindings:
enhanced_params["graph_binding_config"] = self.domain_graph_bindings[context.domain]
return enhanced_params
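    # Shape sketch of the returned params (illustrative, mirroring the code above):
    #
    #     {
    #         ...original task params...,
    #         "ontology_context": {"ontologies": [...], "domain": "...", "pipeline_name": "..."},
    #         "datapoint_mappings": [...],          # only if configured for the domain
    #         "graph_binding_config": GraphBindingConfig(...),  # likewise
    #     }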
async def resolve_to_datapoints(
self,
ontology_nodes: List[OntologyNode],
context: OntologyContext
) -> List[Any]: # List[DataPoint]
"""Resolve ontology nodes to DataPoint instances."""
datapoints = []
# Get domain-specific mappings
mappings = self.domain_datapoint_mappings.get(context.domain, [])
if not mappings:
logger.warning(f"No DataPoint mappings configured for domain: {context.domain}")
return datapoints
for node in ontology_nodes:
# Find appropriate mapping for this node type
mapping = self._find_mapping_for_node(node, mappings)
if not mapping:
logger.debug(f"No mapping found for node type: {node.type}")
continue
try:
datapoint = await self.datapoint_resolver.resolve_to_datapoint(
node, mapping, context
)
if datapoint:
datapoints.append(datapoint)
except Exception as e:
logger.error(f"Failed to resolve node {node.id} to DataPoint: {e}")
logger.info(f"Resolved {len(datapoints)} DataPoints from {len(ontology_nodes)} nodes")
return datapoints
async def bind_to_graph(
self,
ontology: OntologyGraph,
context: OntologyContext
) -> Tuple[List[Any], List[Any]]: # (graph_nodes, graph_edges)
"""Bind ontology to graph structure using configured binding."""
binding_config = self.domain_graph_bindings.get(context.domain)
if not binding_config:
logger.warning(f"No graph binding configured for domain: {context.domain}")
# Use default binding
binding_config = GraphBindingConfig()
return await self.graph_binder.bind_ontology_to_graph(
ontology, binding_config, context
)
def configure_datapoint_mapping(
self,
domain: str,
mappings: List[DataPointMapping]
) -> None:
"""Configure DataPoint mappings for a domain."""
self.domain_datapoint_mappings[domain] = mappings
logger.info(f"Configured {len(mappings)} DataPoint mappings for domain: {domain}")
def configure_graph_binding(
self,
domain: str,
binding_config: GraphBindingConfig
) -> None:
"""Configure graph binding for a domain."""
self.domain_graph_bindings[domain] = binding_config
logger.info(f"Configured graph binding for domain: {domain}")
def register_custom_resolver(
self,
resolver_name: str,
resolver_func: Callable
) -> None:
"""Register a custom DataPoint resolver."""
self.custom_resolvers[resolver_name] = resolver_func
self.datapoint_resolver.register_custom_resolver(resolver_name, resolver_func)
logger.info(f"Registered custom resolver: {resolver_name}")
def register_binding_strategy(
self,
strategy_name: str,
strategy_func: Callable
) -> None:
"""Register a custom graph binding strategy."""
self.custom_binding_strategies[strategy_name] = strategy_func
self.graph_binder.register_binding_strategy(strategy_name, strategy_func)
logger.info(f"Registered custom binding strategy: {strategy_name}")
def _prioritize_ontologies(self, ontologies: List[OntologyGraph]) -> List[OntologyGraph]:
"""Prioritize ontologies by scope (dataset > user > pipeline > domain > global)."""
priority_order = [
OntologyScope.DATASET,
OntologyScope.USER,
OntologyScope.PIPELINE,
OntologyScope.DOMAIN,
OntologyScope.GLOBAL,
]
seen_ids = set()
prioritized = []
for scope in priority_order:
for ontology in ontologies:
if ontology.scope == scope and ontology.id not in seen_ids:
prioritized.append(ontology)
seen_ids.add(ontology.id)
return prioritized
def _get_adapter_for_ontology(self, ontology: OntologyGraph) -> str:
"""Get the appropriate adapter name for an ontology."""
# This could be made more sophisticated with adapter selection logic
format_to_adapter = {
"rdf_xml": "rdf_adapter",
"owl": "rdf_adapter",
"json": "json_adapter",
"llm_generated": "llm_adapter",
}
return format_to_adapter.get(ontology.format.value, "default_adapter")
def _find_mapping_for_node(
self,
node: OntologyNode,
mappings: List[DataPointMapping]
) -> Optional[DataPointMapping]:
"""Find the appropriate DataPoint mapping for a node."""
for mapping in mappings:
if mapping.ontology_node_type == node.type:
return mapping
return None
async def create_ontology_manager(
registry: IOntologyRegistry,
providers: Optional[Dict[str, IOntologyProvider]] = None,
adapters: Optional[Dict[str, IOntologyAdapter]] = None,
datapoint_resolver: Optional[IDataPointResolver] = None,
graph_binder: Optional[IGraphBinder] = None,
) -> OntologyManager:
"""Factory function to create an OntologyManager with default implementations."""
# Import default implementations
from cognee.modules.ontology.registry import OntologyRegistry
from cognee.modules.ontology.providers import RDFOntologyProvider, JSONOntologyProvider
from cognee.modules.ontology.adapters import DefaultOntologyAdapter
from cognee.modules.ontology.resolvers import DefaultDataPointResolver
from cognee.modules.ontology.binders import DefaultGraphBinder
if providers is None:
providers = {
"rdf_provider": RDFOntologyProvider(),
"json_provider": JSONOntologyProvider(),
}
if adapters is None:
adapters = {
"default_adapter": DefaultOntologyAdapter(),
"rdf_adapter": DefaultOntologyAdapter(),
"json_adapter": DefaultOntologyAdapter(),
}
if datapoint_resolver is None:
datapoint_resolver = DefaultDataPointResolver()
if graph_binder is None:
graph_binder = DefaultGraphBinder()
return OntologyManager(
registry=registry,
providers=providers,
adapters=adapters,
datapoint_resolver=datapoint_resolver,
graph_binder=graph_binder,
)
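# Usage sketch (illustrative; OntologyRegistry is the same default registry the
# factory imports internally):
#
#     from cognee.modules.ontology.registry import OntologyRegistry
#
#     manager = await create_ontology_manager(registry=OntologyRegistry())
#     context = OntologyContext(domain="medical")
#     ontologies = await manager.get_applicable_ontologies(context)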


@@ -0,0 +1,27 @@
"""Ontology CRUD operations following Cognee methods pattern."""
# Create operations
from .create_ontology import create_ontology
from .create_ontology_from_file import create_ontology_from_file
# Read operations
from .get_ontology import get_ontology
from .get_ontologies import get_ontologies
from .get_ontology_by_domain import get_ontology_by_domain
from .load_ontology import load_ontology
# Update operations
from .update_ontology import update_ontology
from .register_ontology import register_ontology
# Delete operations
from .delete_ontology import delete_ontology
from .unregister_ontology import unregister_ontology
# Search operations
from .search_ontologies import search_ontologies
from .find_matching_nodes import find_matching_nodes
# Utility operations
from .validate_ontology import validate_ontology
from .merge_ontologies import merge_ontologies


@@ -0,0 +1,193 @@
"""Create ontology method following Cognee patterns."""
from typing import Dict, Any, Optional
from uuid import uuid4
from datetime import datetime
from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import (
OntologyGraph,
OntologyNode,
OntologyEdge,
OntologyScope,
OntologyFormat,
OntologyContext,
)
from cognee.modules.ontology.config import get_ontology_config
from cognee.modules.users.models import User
logger = get_logger("ontology.create")
async def create_ontology(
ontology_data: Dict[str, Any],
user: User,
scope: OntologyScope = OntologyScope.USER,
context: Optional[OntologyContext] = None
) -> OntologyGraph:
"""
Create a new ontology from provided data.
Args:
ontology_data: Dictionary containing ontology structure
user: User creating the ontology
scope: Scope for the ontology (user, domain, global, etc.)
context: Optional context for the ontology
Returns:
Created OntologyGraph instance
Raises:
ValueError: If ontology_data is invalid
RuntimeError: If ontology creation fails
"""
try:
config = get_ontology_config()
# Validate required fields
if "nodes" not in ontology_data:
raise ValueError("Ontology data must contain 'nodes' field")
# Extract basic information
ontology_id = ontology_data.get("id", str(uuid4()))
ontology_name = ontology_data.get("name", f"ontology_{ontology_id}")
description = ontology_data.get("description", "")
format_type = OntologyFormat(ontology_data.get("format", config.default_format))
# Parse nodes
nodes = []
for node_data in ontology_data["nodes"]:
if not isinstance(node_data, dict):
logger.warning(f"Skipping invalid node data: {node_data}")
continue
try:
node = OntologyNode(
id=node_data.get("id", str(uuid4())),
name=node_data.get("name", "unnamed_node"),
type=node_data.get("type", "entity"),
description=node_data.get("description", ""),
category=node_data.get("category", "general"),
properties=node_data.get("properties", {})
)
nodes.append(node)
except Exception as e:
logger.warning(f"Failed to parse node {node_data.get('id', 'unknown')}: {e}")
continue
# Parse edges
edges = []
for edge_data in ontology_data.get("edges", []):
if not isinstance(edge_data, dict):
logger.warning(f"Skipping invalid edge data: {edge_data}")
continue
try:
edge = OntologyEdge(
id=edge_data.get("id", str(uuid4())),
source_id=edge_data["source"],
target_id=edge_data["target"],
relationship_type=edge_data.get("relationship", "related_to"),
properties=edge_data.get("properties", {}),
weight=edge_data.get("weight")
)
edges.append(edge)
except KeyError as e:
logger.warning(f"Edge missing required field {e}: {edge_data}")
continue
except Exception as e:
logger.warning(f"Failed to parse edge: {e}")
continue
# Create metadata
metadata = ontology_data.get("metadata", {})
metadata.update({
"created_by": str(user.id),
"created_at": datetime.now().isoformat(),
"scope": scope.value,
"format": format_type.value,
})
if context:
metadata.update({
"domain": context.domain,
"pipeline_name": context.pipeline_name,
"dataset_id": context.dataset_id,
})
# Create ontology
ontology = OntologyGraph(
id=ontology_id,
name=ontology_name,
description=description,
format=format_type,
scope=scope,
nodes=nodes,
edges=edges,
metadata=metadata
)
logger.info(
f"Created ontology '{ontology_name}' with {len(nodes)} nodes "
f"and {len(edges)} edges for user {user.id}"
)
return ontology
except ValueError:
# Re-raise validation errors
raise
except Exception as e:
error_msg = f"Failed to create ontology: {e}"
logger.error(error_msg)
raise RuntimeError(error_msg) from e
async def create_empty_ontology(
name: str,
user: User,
scope: OntologyScope = OntologyScope.USER,
domain: Optional[str] = None,
description: str = ""
) -> OntologyGraph:
"""
Create an empty ontology with basic structure.
Args:
name: Name for the ontology
user: User creating the ontology
scope: Scope for the ontology
domain: Optional domain for the ontology
description: Optional description
Returns:
Empty OntologyGraph instance
"""
config = get_ontology_config()
ontology_id = str(uuid4())
metadata = {
"created_by": str(user.id),
"created_at": datetime.now().isoformat(),
"scope": scope.value,
"format": config.default_format,
}
if domain:
metadata["domain"] = domain
ontology = OntologyGraph(
id=ontology_id,
name=name,
description=description,
format=OntologyFormat(config.default_format),
scope=scope,
nodes=[],
edges=[],
metadata=metadata
)
logger.info(f"Created empty ontology '{name}' for user {user.id}")
return ontology
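# Input sketch (illustrative; current_user stands in for an authenticated User):
#
#     ontology_data = {
#         "name": "demo_ontology",
#         "nodes": [{"id": "n1", "name": "Entity A", "type": "entity"}],
#         "edges": [{"source": "n1", "target": "n1", "relationship": "self_reference"}],
#     }
#     ontology = await create_ontology(ontology_data, user=current_user)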


@@ -0,0 +1,56 @@
"""Delete ontology method following Cognee patterns."""
from typing import Optional
from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import OntologyContext
from cognee.modules.ontology.registry import OntologyRegistry
from cognee.modules.users.models import User
logger = get_logger("ontology.delete")
async def delete_ontology(
ontology_id: str,
user: User,
context: Optional[OntologyContext] = None
) -> bool:
"""
Delete ontology following Cognee delete patterns.
Args:
ontology_id: ID of the ontology to delete
user: User requesting deletion
context: Optional context for access control
Returns:
True if deletion successful, False otherwise
"""
try:
# This would use dependency injection in real implementation
registry = OntologyRegistry()
# Get ontology first to check existence and permissions
ontology = await registry.get_ontology(ontology_id, context)
if ontology is None:
logger.warning(f"Ontology {ontology_id} not found for deletion")
return False
# TODO: Add access control check
# For now, assume user has access
# Perform deletion
success = await registry.unregister_ontology(ontology_id, context)
if success:
logger.info(f"Deleted ontology {ontology_id} for user {user.id}")
else:
logger.error(f"Failed to delete ontology {ontology_id}")
return success
except Exception as e:
logger.error(f"Error deleting ontology {ontology_id}: {e}")
return False


@@ -0,0 +1,48 @@
"""Get ontology method following Cognee patterns."""
from typing import Optional
from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import OntologyGraph, OntologyContext
from cognee.modules.ontology.registry import OntologyRegistry
from cognee.modules.users.models import User
logger = get_logger("ontology.get")
async def get_ontology(
ontology_id: str,
user: User,
context: Optional[OntologyContext] = None
) -> Optional[OntologyGraph]:
"""
Get ontology by ID following Cognee get patterns.
Args:
ontology_id: ID of the ontology to retrieve
user: User requesting the ontology
context: Optional context for access control
Returns:
OntologyGraph if found and accessible, None otherwise
"""
try:
# This would use dependency injection in real implementation
registry = OntologyRegistry()
ontology = await registry.get_ontology(ontology_id, context)
if ontology is None:
logger.info(f"Ontology {ontology_id} not found")
return None
# TODO: Add access control check based on user and ontology ownership
# For now, assume user has access
logger.info(f"Retrieved ontology {ontology_id} for user {user.id}")
return ontology
except Exception as e:
logger.error(f"Failed to get ontology {ontology_id}: {e}")
return None


@@ -0,0 +1,83 @@
"""Load ontology method following Cognee patterns."""
from typing import Union, Dict, Any, Optional
from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import OntologyGraph, OntologyContext
from cognee.modules.ontology.providers import JSONOntologyProvider, RDFOntologyProvider, CSVOntologyProvider
from cognee.modules.ontology.config import get_ontology_config
from cognee.modules.users.models import User
logger = get_logger("ontology.load")
async def load_ontology(
source: Union[str, Dict[str, Any]],
user: User,
context: Optional[OntologyContext] = None
) -> Optional[OntologyGraph]:
"""
Load ontology from various sources following Cognee patterns.
Args:
source: File path, URL, or data dictionary
user: User loading the ontology
context: Optional context for the ontology
Returns:
Loaded OntologyGraph or None if loading failed
"""
try:
config = get_ontology_config()
# Determine provider based on source
provider = None
if isinstance(source, str):
# File path or URL
if source.endswith(('.owl', '.rdf', '.xml')) and config.rdf_provider_enabled:
provider = RDFOntologyProvider()
                if not provider.available:
                    logger.warning("RDF provider not available, falling back to JSON")
                    provider = JSONOntologyProvider()
elif source.endswith('.json') and config.json_provider_enabled:
provider = JSONOntologyProvider()
elif source.endswith('.csv') and config.csv_provider_enabled:
provider = CSVOntologyProvider()
else:
# Default to JSON provider
provider = JSONOntologyProvider()
else:
# Dictionary data - use JSON provider
provider = JSONOntologyProvider()
if provider is None:
logger.error(f"No suitable provider found for source: {source}")
return None
# Load ontology
ontology = await provider.load_ontology(source, context)
# Validate ontology
if not await provider.validate_ontology(ontology):
logger.error(f"Ontology validation failed for source: {source}")
return None
# Add metadata about loading
ontology.metadata.update({
"loaded_by": str(user.id),
"source": str(source),
"provider": provider.__class__.__name__,
})
logger.info(
f"Loaded ontology '{ontology.name}' with {len(ontology.nodes)} nodes "
f"from source: {source}"
)
return ontology
except Exception as e:
logger.error(f"Failed to load ontology from {source}: {e}")
return None
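# Usage sketch (illustrative; "medical.json" is a placeholder path and
# current_user a placeholder User):
#
#     ontology = await load_ontology("medical.json", user=current_user)
#     if ontology is not None:
#         print(ontology.name, len(ontology.nodes))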


@@ -0,0 +1,73 @@
"""Register ontology method following Cognee patterns."""
from typing import Optional
from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import OntologyGraph, OntologyScope, OntologyContext
from cognee.modules.ontology.registry import OntologyRegistry
from cognee.modules.users.models import User
logger = get_logger("ontology.register")
async def register_ontology(
ontology: OntologyGraph,
user: User,
scope: OntologyScope = OntologyScope.USER,
context: Optional[OntologyContext] = None
) -> str:
"""
Register ontology in the registry following Cognee patterns.
Args:
ontology: OntologyGraph to register
user: User registering the ontology
scope: Scope for the ontology
context: Optional context for registration
Returns:
Ontology ID if registration successful
Raises:
ValueError: If ontology is invalid
RuntimeError: If registration fails
"""
try:
# Validate ontology
if not ontology.nodes:
raise ValueError("Cannot register empty ontology")
# Update metadata
ontology.metadata.update({
"registered_by": str(user.id),
"scope": scope.value,
})
if context:
ontology.metadata.update({
"domain": context.domain,
"pipeline_name": context.pipeline_name,
"dataset_id": context.dataset_id,
})
        # This would use dependency injection in a real implementation
        registry = OntologyRegistry()
# Register in registry
ontology_id = await registry.register_ontology(ontology, scope, context)
logger.info(
f"Registered ontology '{ontology.name}' (ID: {ontology_id}) "
f"with scope {scope.value} for user {user.id}"
)
return ontology_id
except ValueError:
# Re-raise validation errors
raise
except Exception as e:
error_msg = f"Failed to register ontology '{ontology.name}': {e}"
logger.error(error_msg)
raise RuntimeError(error_msg) from e
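# --- Usage sketch (illustrative only) ----------------------------------------
# Registering a small hand-built ontology at USER scope. The OntologyGraph and
# OntologyNode field values are placeholders; unspecified fields are assumed
# to have defaults, as in the JSON provider.
if __name__ == "__main__":
    import asyncio
    from cognee.modules.users.methods import get_default_user
    from cognee.modules.ontology.interfaces import OntologyNode, OntologyFormat

    async def _demo():
        user = await get_default_user()
        ontology = OntologyGraph(
            id="demo_ontology",
            name="Demo Ontology",
            description="Two-node demo",
            format=OntologyFormat.JSON,
            scope=OntologyScope.USER,
            nodes=[
                OntologyNode(id="flu", name="Influenza", type="Disease"),
                OntologyNode(id="fever", name="Fever", type="Symptom"),
            ],
            edges=[],
        )
        ontology_id = await register_ontology(ontology, user, scope=OntologyScope.USER)
        print(f"Registered as {ontology_id}")

    asyncio.run(_demo())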

View file

@ -0,0 +1,496 @@
"""Pipeline integration for ontology system."""
from typing import Any, Dict, List, Optional
from cognee.shared.logging_utils import get_logger
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.ontology.interfaces import (
IOntologyManager,
OntologyContext,
DataPointMapping,
GraphBindingConfig,
)
logger = get_logger("OntologyPipelineIntegration")
class OntologyInjector:
"""Handles injection of ontology context into pipeline tasks."""
def __init__(self, ontology_manager: IOntologyManager):
self.ontology_manager = ontology_manager
self.task_ontology_configs: Dict[str, Dict[str, Any]] = {}
def configure_task_ontology(
self,
task_name: str,
ontology_config: Dict[str, Any]
) -> None:
"""Configure ontology settings for a specific task."""
self.task_ontology_configs[task_name] = ontology_config
logger.info(f"Configured ontology for task: {task_name}")
async def inject_into_task(
self,
task: Task,
context: OntologyContext
) -> Task:
"""Inject ontology context into a task."""
task_name = self._get_task_name(task)
# Check if task has ontology configuration
if task_name not in self.task_ontology_configs:
# No specific configuration, use default behavior
return await self._apply_default_injection(task, context)
# Apply configured ontology injection
config = self.task_ontology_configs[task_name]
return await self._apply_configured_injection(task, context, config)
async def enhance_task_params(
self,
task_params: Dict[str, Any],
context: OntologyContext,
task_name: Optional[str] = None
) -> Dict[str, Any]:
"""Enhance task parameters with ontological context."""
enhanced_params = await self.ontology_manager.inject_ontology_into_task(
task_name or "unknown_task",
task_params,
context
)
return enhanced_params
def _get_task_name(self, task: Task) -> str:
"""Extract task name from Task object."""
if hasattr(task.executable, '__name__'):
return task.executable.__name__
elif hasattr(task.executable, '__class__'):
return task.executable.__class__.__name__
else:
return str(task.executable)
async def _apply_default_injection(
self,
task: Task,
context: OntologyContext
) -> Task:
"""Apply default ontology injection to task."""
# Get applicable ontologies
ontologies = await self.ontology_manager.get_applicable_ontologies(context)
if not ontologies:
logger.debug("No applicable ontologies found for task")
return task
# Enhance task parameters
        enhanced_params = task.default_params.copy()
        # Re-copy the nested kwargs dict so the original task's params are not mutated
        enhanced_params["kwargs"] = dict(enhanced_params.get("kwargs", {}))
        enhanced_params["kwargs"]["ontology_context"] = context
enhanced_params["kwargs"]["available_ontologies"] = [ont.id for ont in ontologies]
# Create new task with enhanced parameters
enhanced_task = Task(
task.executable,
*enhanced_params["args"],
task_config=task.task_config,
**enhanced_params["kwargs"]
)
return enhanced_task
async def _apply_configured_injection(
self,
task: Task,
context: OntologyContext,
config: Dict[str, Any]
) -> Task:
"""Apply configured ontology injection to task."""
        enhanced_params = task.default_params.copy()
        # Re-copy the nested kwargs dict so the original task's params are not mutated
        enhanced_params["kwargs"] = dict(enhanced_params.get("kwargs", {}))
# Apply ontology-specific enhancements based on config
if config.get("enhance_with_entities", False):
enhanced_params = await self._inject_entity_enhancement(
enhanced_params, context, config
)
if config.get("inject_datapoint_mappings", False):
enhanced_params = await self._inject_datapoint_mappings(
enhanced_params, context, config
)
if config.get("inject_graph_binding", False):
enhanced_params = await self._inject_graph_binding(
enhanced_params, context, config
)
if config.get("enable_ontology_validation", False):
enhanced_params = await self._inject_validation_config(
enhanced_params, context, config
)
# Create new task with enhanced parameters
enhanced_task = Task(
task.executable,
*enhanced_params["args"],
task_config=task.task_config,
**enhanced_params["kwargs"]
)
return enhanced_task
async def _inject_entity_enhancement(
self,
params: Dict[str, Any],
context: OntologyContext,
config: Dict[str, Any]
) -> Dict[str, Any]:
"""Inject entity enhancement capabilities."""
params["kwargs"]["ontology_manager"] = self.ontology_manager
params["kwargs"]["ontology_context"] = context
params["kwargs"]["entity_extraction_enabled"] = True
# Add specific entity types to extract if configured
if "target_entity_types" in config:
params["kwargs"]["target_entity_types"] = config["target_entity_types"]
return params
async def _inject_datapoint_mappings(
self,
params: Dict[str, Any],
context: OntologyContext,
config: Dict[str, Any]
) -> Dict[str, Any]:
"""Inject DataPoint mapping configurations."""
# Get domain-specific mappings
if context.domain:
domain_mappings = getattr(
self.ontology_manager, 'domain_datapoint_mappings', {}
).get(context.domain, [])
if domain_mappings:
params["kwargs"]["datapoint_mappings"] = domain_mappings
params["kwargs"]["datapoint_resolver"] = self.ontology_manager.datapoint_resolver
# Add custom mappings from config
if "custom_mappings" in config:
params["kwargs"]["custom_datapoint_mappings"] = config["custom_mappings"]
return params
async def _inject_graph_binding(
self,
params: Dict[str, Any],
context: OntologyContext,
config: Dict[str, Any]
) -> Dict[str, Any]:
"""Inject graph binding configurations."""
# Get domain-specific binding config
if context.domain:
domain_binding = getattr(
self.ontology_manager, 'domain_graph_bindings', {}
).get(context.domain)
if domain_binding:
params["kwargs"]["graph_binding_config"] = domain_binding
params["kwargs"]["graph_binder"] = self.ontology_manager.graph_binder
# Add custom binding from config
if "custom_binding" in config:
params["kwargs"]["custom_graph_binding"] = config["custom_binding"]
return params
async def _inject_validation_config(
self,
params: Dict[str, Any],
context: OntologyContext,
config: Dict[str, Any]
) -> Dict[str, Any]:
"""Inject ontology validation configurations."""
params["kwargs"]["ontology_validation_enabled"] = True
params["kwargs"]["validation_threshold"] = config.get("validation_threshold", 0.8)
params["kwargs"]["strict_validation"] = config.get("strict_validation", False)
return params
class OntologyAwareTaskWrapper:
"""Wrapper to make existing tasks ontology-aware."""
def __init__(
self,
original_task: Task,
ontology_manager: IOntologyManager,
context: OntologyContext
):
self.original_task = original_task
self.ontology_manager = ontology_manager
self.context = context
async def execute_with_ontology(self, *args, **kwargs):
"""Execute task with ontology enhancements."""
# Enhance content if provided
if "content" in kwargs:
enhanced_content = await self.ontology_manager.enhance_with_ontology(
kwargs["content"], self.context
)
kwargs["enhanced_content"] = enhanced_content
# Add ontology context
kwargs["ontology_context"] = self.context
kwargs["ontology_manager"] = self.ontology_manager
# Execute original task
return await self.original_task.run(*args, **kwargs)
class PipelineOntologyConfigurator:
"""Configures ontology integration for entire pipelines."""
def __init__(self, ontology_manager: IOntologyManager):
self.ontology_manager = ontology_manager
self.pipeline_configs: Dict[str, Dict[str, Any]] = {}
def configure_pipeline(
self,
pipeline_name: str,
domain: str,
datapoint_mappings: List[DataPointMapping],
graph_binding: GraphBindingConfig,
task_specific_configs: Optional[Dict[str, Dict[str, Any]]] = None
) -> None:
"""Configure ontology for an entire pipeline."""
# Configure domain mappings
self.ontology_manager.configure_datapoint_mapping(domain, datapoint_mappings)
self.ontology_manager.configure_graph_binding(domain, graph_binding)
# Store pipeline configuration
self.pipeline_configs[pipeline_name] = {
"domain": domain,
"datapoint_mappings": datapoint_mappings,
"graph_binding": graph_binding,
"task_configs": task_specific_configs or {},
}
logger.info(f"Configured ontology for pipeline: {pipeline_name}")
def get_pipeline_context(
self,
pipeline_name: str,
user_id: Optional[str] = None,
dataset_id: Optional[str] = None,
custom_properties: Optional[Dict[str, Any]] = None
) -> OntologyContext:
"""Get ontology context for a pipeline."""
config = self.pipeline_configs.get(pipeline_name, {})
return OntologyContext(
user_id=user_id,
dataset_id=dataset_id,
pipeline_name=pipeline_name,
domain=config.get("domain"),
custom_properties=custom_properties or {}
)
def create_ontology_injector(self, pipeline_name: str) -> OntologyInjector:
"""Create an ontology injector configured for a specific pipeline."""
injector = OntologyInjector(self.ontology_manager)
# Apply pipeline-specific task configurations
if pipeline_name in self.pipeline_configs:
task_configs = self.pipeline_configs[pipeline_name].get("task_configs", {})
for task_name, config in task_configs.items():
injector.configure_task_ontology(task_name, config)
return injector
# Pre-configured pipeline setups for common domains
def create_medical_pipeline_config() -> Dict[str, Any]:
"""Create pre-configured ontology setup for medical pipelines."""
datapoint_mappings = [
DataPointMapping(
ontology_node_type="Disease",
datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
field_mappings={
"name": "name",
"description": "description",
"icd_code": "medical_code",
"severity": "severity_level",
},
validation_rules=["required:name", "type:severity_level:str"]
),
DataPointMapping(
ontology_node_type="Symptom",
datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
field_mappings={
"name": "name",
"description": "description",
"frequency": "occurrence_rate",
}
),
]
graph_binding = GraphBindingConfig(
node_type_mapping={
"Disease": "medical_entity",
"Symptom": "clinical_finding",
"Treatment": "therapeutic_procedure",
},
edge_type_mapping={
"treats": "therapeutic_relationship",
"causes": "causality",
"associated_with": "clinical_association",
}
)
task_configs = {
"extract_graph_from_data": {
"enhance_with_entities": True,
"inject_datapoint_mappings": True,
"inject_graph_binding": True,
"target_entity_types": ["Disease", "Symptom", "Treatment"],
},
"summarize_text": {
"enhance_with_entities": True,
"enable_ontology_validation": True,
"validation_threshold": 0.85,
}
}
return {
"domain": "medical",
"datapoint_mappings": datapoint_mappings,
"graph_binding": graph_binding,
"task_configs": task_configs,
}
def create_legal_pipeline_config() -> Dict[str, Any]:
"""Create pre-configured ontology setup for legal pipelines."""
datapoint_mappings = [
DataPointMapping(
ontology_node_type="Law",
datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
field_mappings={
"name": "name",
"description": "description",
"jurisdiction": "legal_authority",
"citation": "legal_citation",
}
),
DataPointMapping(
ontology_node_type="Case",
datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
field_mappings={
"name": "name",
"description": "description",
"court": "court_level",
"date": "decision_date",
}
),
]
graph_binding = GraphBindingConfig(
node_type_mapping={
"Law": "legal_statute",
"Case": "legal_precedent",
"Court": "judicial_body",
},
edge_type_mapping={
"cites": "legal_citation",
"overrules": "legal_override",
"applies": "legal_application",
}
)
task_configs = {
"extract_graph_from_data": {
"enhance_with_entities": True,
"inject_datapoint_mappings": True,
"inject_graph_binding": True,
"target_entity_types": ["Law", "Case", "Court", "Legal_Concept"],
}
}
return {
"domain": "legal",
"datapoint_mappings": datapoint_mappings,
"graph_binding": graph_binding,
"task_configs": task_configs,
}
def create_code_pipeline_config() -> Dict[str, Any]:
"""Create pre-configured ontology setup for code analysis pipelines."""
datapoint_mappings = [
DataPointMapping(
ontology_node_type="Function",
datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
field_mappings={
"name": "name",
"description": "description",
"parameters": "function_parameters",
"return_type": "return_type",
}
),
DataPointMapping(
ontology_node_type="Class",
datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
field_mappings={
"name": "name",
"description": "description",
"methods": "class_methods",
"inheritance": "parent_classes",
}
),
]
graph_binding = GraphBindingConfig(
node_type_mapping={
"Function": "code_function",
"Class": "code_class",
"Module": "code_module",
"Variable": "code_variable",
},
edge_type_mapping={
"calls": "function_call",
"inherits": "inheritance",
"imports": "module_import",
"defines": "definition",
}
)
task_configs = {
"extract_graph_from_code": {
"enhance_with_entities": True,
"inject_datapoint_mappings": True,
"inject_graph_binding": True,
"target_entity_types": ["Function", "Class", "Module", "Variable"],
}
}
return {
"domain": "code",
"datapoint_mappings": datapoint_mappings,
"graph_binding": graph_binding,
"task_configs": task_configs,
}
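# --- Usage sketch (illustrative only) ----------------------------------------
# Wiring a pre-configured domain setup into a pipeline. The pipeline name is a
# placeholder; `ontology_manager` is whichever IOntologyManager implementation
# the deployment provides.
def configure_medical_pipeline(ontology_manager: IOntologyManager) -> OntologyInjector:
    configurator = PipelineOntologyConfigurator(ontology_manager)
    medical = create_medical_pipeline_config()
    configurator.configure_pipeline(
        pipeline_name="cognify",
        domain=medical["domain"],
        datapoint_mappings=medical["datapoint_mappings"],
        graph_binding=medical["graph_binding"],
        task_specific_configs=medical["task_configs"],
    )
    return configurator.create_ontology_injector("cognify")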

View file

@ -0,0 +1,475 @@
"""Ontology provider implementations."""
import json
import csv
from typing import Dict, Any, Union, Optional
from pathlib import Path
from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import (
IOntologyProvider,
OntologyGraph,
OntologyNode,
OntologyEdge,
OntologyFormat,
OntologyScope,
OntologyContext,
)
logger = get_logger("OntologyProviders")
class RDFOntologyProvider(IOntologyProvider):
"""Provider for RDF/OWL ontologies."""
def __init__(self):
try:
from rdflib import Graph, URIRef, RDF, RDFS, OWL
self.Graph = Graph
self.URIRef = URIRef
self.RDF = RDF
self.RDFS = RDFS
self.OWL = OWL
self.available = True
except ImportError:
logger.warning("rdflib not available, RDF support disabled")
self.available = False
async def load_ontology(
self,
source: Union[str, Dict[str, Any]],
context: Optional[OntologyContext] = None
) -> OntologyGraph:
"""Load ontology from RDF/OWL file."""
if not self.available:
raise ImportError("rdflib is required for RDF ontology support")
if isinstance(source, dict):
file_path = source.get("file_path")
else:
file_path = source
if not file_path or not Path(file_path).exists():
raise FileNotFoundError(f"RDF file not found: {file_path}")
# Parse RDF graph
rdf_graph = self.Graph()
rdf_graph.parse(file_path)
# Convert to our ontology format
nodes = []
edges = []
# Extract classes and individuals
for cls in rdf_graph.subjects(self.RDF.type, self.OWL.Class):
node = OntologyNode(
id=self._uri_to_id(cls),
name=self._extract_name(cls),
type="class",
category="owl_class",
properties=self._extract_node_properties(cls, rdf_graph)
)
nodes.append(node)
        # Extract individuals (deduplicated, since a subject appears once per type triple)
        for individual in set(rdf_graph.subjects(self.RDF.type, None)):
if not any(rdf_graph.triples((individual, self.RDF.type, self.OWL.Class))):
node = OntologyNode(
id=self._uri_to_id(individual),
name=self._extract_name(individual),
type="individual",
category="owl_individual",
properties=self._extract_node_properties(individual, rdf_graph)
)
nodes.append(node)
        # Extract relationships between extracted nodes. Skip rdf:type triples and
        # triples whose object is a literal or an unextracted resource, which would
        # otherwise produce dangling edges that fail validation.
        node_ids = {node.id for node in nodes}
        for s, p, o in rdf_graph:
            if p == self.RDF.type:
                continue
            source_id = self._uri_to_id(s)
            target_id = self._uri_to_id(o)
            if source_id not in node_ids or target_id not in node_ids:
                continue
            edge = OntologyEdge(
                id=f"{source_id}_{self._uri_to_id(p)}_{target_id}",
                source_id=source_id,
                target_id=target_id,
                relationship_type=self._extract_name(p),
                properties={"predicate_uri": str(p)}
            )
            edges.append(edge)
ontology = OntologyGraph(
id=f"rdf_{Path(file_path).stem}",
name=Path(file_path).stem,
description=f"RDF ontology loaded from {file_path}",
format=OntologyFormat.RDF_XML,
scope=OntologyScope.DOMAIN,
nodes=nodes,
edges=edges,
metadata={"source_file": file_path, "triple_count": len(rdf_graph)}
)
logger.info(f"Loaded RDF ontology with {len(nodes)} nodes and {len(edges)} edges")
return ontology
async def save_ontology(
self,
ontology: OntologyGraph,
destination: str,
context: Optional[OntologyContext] = None
) -> bool:
"""Save ontology to RDF/OWL file."""
if not self.available:
raise ImportError("rdflib is required for RDF ontology support")
# Convert back to RDF
rdf_graph = self.Graph()
# Add nodes
for node in ontology.nodes:
uri = self.URIRef(f"http://example.org/ontology#{node.id}")
if node.type == "class":
rdf_graph.add((uri, self.RDF.type, self.OWL.Class))
else:
# Add as individual of some class
class_uri = self.URIRef(f"http://example.org/ontology#{node.type}")
rdf_graph.add((uri, self.RDF.type, class_uri))
# Add edges
for edge in ontology.edges:
s_uri = self.URIRef(f"http://example.org/ontology#{edge.source_id}")
p_uri = self.URIRef(f"http://example.org/ontology#{edge.relationship_type}")
o_uri = self.URIRef(f"http://example.org/ontology#{edge.target_id}")
rdf_graph.add((s_uri, p_uri, o_uri))
# Serialize to file
try:
rdf_graph.serialize(destination=destination, format='xml')
logger.info(f"Saved RDF ontology to {destination}")
return True
except Exception as e:
logger.error(f"Failed to save RDF ontology: {e}")
return False
def supports_format(self, format: OntologyFormat) -> bool:
"""Check if provider supports given format."""
return self.available and format in [OntologyFormat.RDF_XML, OntologyFormat.OWL]
async def validate_ontology(self, ontology: OntologyGraph) -> bool:
"""Validate RDF ontology structure."""
# Basic validation - could be enhanced with OWL reasoning
node_ids = {node.id for node in ontology.nodes}
for edge in ontology.edges:
if edge.source_id not in node_ids or edge.target_id not in node_ids:
return False
return True
def _uri_to_id(self, uri) -> str:
"""Convert URI to simple ID."""
uri_str = str(uri)
if "#" in uri_str:
return uri_str.split("#")[-1]
return uri_str.rstrip("/").split("/")[-1]
def _extract_name(self, uri) -> str:
"""Extract readable name from URI."""
return self._uri_to_id(uri).replace("_", " ").title()
def _extract_node_properties(self, uri, graph) -> Dict[str, Any]:
"""Extract additional properties for a node."""
props = {}
# Get labels
for label in graph.objects(uri, self.RDFS.label):
props["label"] = str(label)
# Get comments
for comment in graph.objects(uri, self.RDFS.comment):
props["comment"] = str(comment)
return props
class JSONOntologyProvider(IOntologyProvider):
"""Provider for JSON-based ontologies."""
async def load_ontology(
self,
source: Union[str, Dict[str, Any]],
context: Optional[OntologyContext] = None
) -> OntologyGraph:
"""Load ontology from JSON file or dict."""
if isinstance(source, str):
# Load from file
with open(source, 'r') as f:
data = json.load(f)
ontology_id = f"json_{Path(source).stem}"
source_file = source
else:
# Use provided dict
data = source
ontology_id = data.get("id", "json_ontology")
source_file = None
# Parse nodes
nodes = []
for node_data in data.get("nodes", []):
node = OntologyNode(
id=node_data["id"],
name=node_data.get("name", node_data["id"]),
type=node_data.get("type", "entity"),
description=node_data.get("description", ""),
category=node_data.get("category", "general"),
properties=node_data.get("properties", {})
)
nodes.append(node)
# Parse edges
edges = []
for edge_data in data.get("edges", []):
edge = OntologyEdge(
id=edge_data.get("id", f"{edge_data['source']}_{edge_data['target']}"),
source_id=edge_data["source"],
target_id=edge_data["target"],
relationship_type=edge_data.get("relationship", "related_to"),
properties=edge_data.get("properties", {}),
weight=edge_data.get("weight")
)
edges.append(edge)
ontology = OntologyGraph(
id=ontology_id,
name=data.get("name", ontology_id),
description=data.get("description", "JSON-based ontology"),
format=OntologyFormat.JSON,
scope=OntologyScope.DOMAIN,
nodes=nodes,
edges=edges,
metadata=data.get("metadata", {"source_file": source_file})
)
logger.info(f"Loaded JSON ontology with {len(nodes)} nodes and {len(edges)} edges")
return ontology
async def save_ontology(
self,
ontology: OntologyGraph,
destination: str,
context: Optional[OntologyContext] = None
) -> bool:
"""Save ontology to JSON file."""
data = {
"id": ontology.id,
"name": ontology.name,
"description": ontology.description,
"format": ontology.format.value,
"scope": ontology.scope.value,
"nodes": [
{
"id": node.id,
"name": node.name,
"type": node.type,
"description": node.description,
"category": node.category,
"properties": node.properties
}
for node in ontology.nodes
],
"edges": [
{
"id": edge.id,
"source": edge.source_id,
"target": edge.target_id,
"relationship": edge.relationship_type,
"properties": edge.properties,
"weight": edge.weight
}
for edge in ontology.edges
],
"metadata": ontology.metadata
}
try:
with open(destination, 'w') as f:
json.dump(data, f, indent=2)
logger.info(f"Saved JSON ontology to {destination}")
return True
except Exception as e:
logger.error(f"Failed to save JSON ontology: {e}")
return False
def supports_format(self, format: OntologyFormat) -> bool:
"""Check if provider supports given format."""
return format == OntologyFormat.JSON
async def validate_ontology(self, ontology: OntologyGraph) -> bool:
"""Validate JSON ontology structure."""
node_ids = {node.id for node in ontology.nodes}
for edge in ontology.edges:
if edge.source_id not in node_ids or edge.target_id not in node_ids:
return False
return True
class CSVOntologyProvider(IOntologyProvider):
"""Provider for CSV-based ontologies."""
async def load_ontology(
self,
source: Union[str, Dict[str, Any]],
context: Optional[OntologyContext] = None
) -> OntologyGraph:
"""Load ontology from CSV files."""
if isinstance(source, dict):
nodes_file = source.get("nodes_file")
edges_file = source.get("edges_file")
else:
# Assume single file or directory
source_path = Path(source)
if source_path.is_dir():
nodes_file = source_path / "nodes.csv"
edges_file = source_path / "edges.csv"
else:
nodes_file = source
edges_file = None
# Load nodes
nodes = []
if nodes_file and Path(nodes_file).exists():
with open(nodes_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
node = OntologyNode(
id=row["id"],
name=row.get("name", row["id"]),
type=row.get("type", "entity"),
description=row.get("description", ""),
category=row.get("category", "general"),
properties={k: v for k, v in row.items()
if k not in ["id", "name", "type", "description", "category"]}
)
nodes.append(node)
# Load edges
edges = []
if edges_file and Path(edges_file).exists():
with open(edges_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
edge = OntologyEdge(
id=row.get("id", f"{row['source']}_{row['target']}"),
source_id=row["source"],
target_id=row["target"],
relationship_type=row.get("relationship", "related_to"),
properties={k: v for k, v in row.items()
if k not in ["id", "source", "target", "relationship"]},
weight=float(row["weight"]) if row.get("weight") else None
)
edges.append(edge)
ontology = OntologyGraph(
id=f"csv_{Path(nodes_file).stem}" if nodes_file else "csv_ontology",
name=f"CSV Ontology",
description="CSV-based ontology",
format=OntologyFormat.CSV,
scope=OntologyScope.DOMAIN,
nodes=nodes,
edges=edges,
metadata={"nodes_file": str(nodes_file), "edges_file": str(edges_file)}
)
logger.info(f"Loaded CSV ontology with {len(nodes)} nodes and {len(edges)} edges")
return ontology
async def save_ontology(
self,
ontology: OntologyGraph,
destination: str,
context: Optional[OntologyContext] = None
) -> bool:
"""Save ontology to CSV files."""
dest_path = Path(destination)
if dest_path.suffix == ".csv":
# Single file - save nodes only
nodes_file = destination
edges_file = None
else:
# Directory - save separate files
dest_path.mkdir(exist_ok=True)
nodes_file = dest_path / "nodes.csv"
edges_file = dest_path / "edges.csv"
try:
# Save nodes
if ontology.nodes:
all_properties = set()
for node in ontology.nodes:
all_properties.update(node.properties.keys())
fieldnames = ["id", "name", "type", "description", "category"] + list(all_properties)
with open(nodes_file, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for node in ontology.nodes:
row = {
"id": node.id,
"name": node.name,
"type": node.type,
"description": node.description,
"category": node.category,
**node.properties
}
writer.writerow(row)
# Save edges
if edges_file and ontology.edges:
all_properties = set()
for edge in ontology.edges:
all_properties.update(edge.properties.keys())
fieldnames = ["id", "source", "target", "relationship", "weight"] + list(all_properties)
with open(edges_file, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for edge in ontology.edges:
row = {
"id": edge.id,
"source": edge.source_id,
"target": edge.target_id,
"relationship": edge.relationship_type,
"weight": edge.weight,
**edge.properties
}
writer.writerow(row)
logger.info(f"Saved CSV ontology to {destination}")
return True
except Exception as e:
logger.error(f"Failed to save CSV ontology: {e}")
return False
def supports_format(self, format: OntologyFormat) -> bool:
"""Check if provider supports given format."""
return format == OntologyFormat.CSV
async def validate_ontology(self, ontology: OntologyGraph) -> bool:
"""Validate CSV ontology structure."""
node_ids = {node.id for node in ontology.nodes}
for edge in ontology.edges:
if edge.source_id not in node_ids or edge.target_id not in node_ids:
return False
return True
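# --- Usage sketch (illustrative only) ----------------------------------------
# Round-tripping a dict-based ontology through the JSON provider; the dict
# schema mirrors what JSONOntologyProvider.load_ontology parses above, and the
# output path is a placeholder.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        provider = JSONOntologyProvider()
        ontology = await provider.load_ontology({
            "id": "demo",
            "name": "Demo",
            "nodes": [
                {"id": "flu", "name": "Influenza", "type": "Disease"},
                {"id": "fever", "name": "Fever", "type": "Symptom"},
            ],
            "edges": [
                {"source": "flu", "target": "fever", "relationship": "causes"},
            ],
        })
        assert await provider.validate_ontology(ontology)
        await provider.save_ontology(ontology, "/tmp/demo_ontology.json")

    asyncio.run(_demo())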

View file

@ -0,0 +1,241 @@
"""Ontology registry implementation."""
from typing import Dict, List, Optional
from uuid import uuid4
from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import (
IOntologyRegistry,
OntologyGraph,
OntologyScope,
OntologyContext,
)
logger = get_logger("OntologyRegistry")
class OntologyRegistry(IOntologyRegistry):
"""In-memory implementation of ontology registry."""
def __init__(self):
self.ontologies: Dict[str, OntologyGraph] = {}
self.scope_index: Dict[OntologyScope, List[str]] = {
scope: [] for scope in OntologyScope
}
self.domain_index: Dict[str, List[str]] = {}
self.user_index: Dict[str, List[str]] = {}
self.dataset_index: Dict[str, List[str]] = {}
self.pipeline_index: Dict[str, List[str]] = {}
async def register_ontology(
self,
ontology: OntologyGraph,
scope: OntologyScope,
context: Optional[OntologyContext] = None
) -> str:
"""Register an ontology."""
ontology_id = ontology.id or str(uuid4())
ontology.id = ontology_id
ontology.scope = scope
self.ontologies[ontology_id] = ontology
# Update scope index
if ontology_id not in self.scope_index[scope]:
self.scope_index[scope].append(ontology_id)
# Update domain index if applicable
if context and context.domain:
if context.domain not in self.domain_index:
self.domain_index[context.domain] = []
if ontology_id not in self.domain_index[context.domain]:
self.domain_index[context.domain].append(ontology_id)
# Update user index if applicable
if context and context.user_id:
if context.user_id not in self.user_index:
self.user_index[context.user_id] = []
if ontology_id not in self.user_index[context.user_id]:
self.user_index[context.user_id].append(ontology_id)
# Update dataset index if applicable
if context and context.dataset_id:
if context.dataset_id not in self.dataset_index:
self.dataset_index[context.dataset_id] = []
if ontology_id not in self.dataset_index[context.dataset_id]:
self.dataset_index[context.dataset_id].append(ontology_id)
# Update pipeline index if applicable
if context and context.pipeline_name:
if context.pipeline_name not in self.pipeline_index:
self.pipeline_index[context.pipeline_name] = []
if ontology_id not in self.pipeline_index[context.pipeline_name]:
self.pipeline_index[context.pipeline_name].append(ontology_id)
logger.info(f"Registered ontology {ontology_id} with scope {scope}")
return ontology_id
async def get_ontology(
self,
ontology_id: str,
context: Optional[OntologyContext] = None
) -> Optional[OntologyGraph]:
"""Get ontology by ID."""
return self.ontologies.get(ontology_id)
async def find_ontologies(
self,
scope: Optional[OntologyScope] = None,
domain: Optional[str] = None,
context: Optional[OntologyContext] = None
) -> List[OntologyGraph]:
"""Find ontologies matching criteria."""
candidate_ids = set()
# Filter by scope
if scope:
candidate_ids.update(self.scope_index.get(scope, []))
else:
# If no scope specified, get all
for scope_ids in self.scope_index.values():
candidate_ids.update(scope_ids)
# Filter by domain
if domain:
domain_ids = set(self.domain_index.get(domain, []))
candidate_ids = candidate_ids.intersection(domain_ids)
# Filter by context
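        # When the requested scope matches a context field, narrow the candidates
        # (intersection); otherwise broaden them, so e.g. a user's ontologies are
        # returned alongside global ones rather than replacing them.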
if context:
if context.user_id:
user_ids = set(self.user_index.get(context.user_id, []))
if scope == OntologyScope.USER:
candidate_ids = candidate_ids.intersection(user_ids)
else:
candidate_ids.update(user_ids)
if context.dataset_id:
dataset_ids = set(self.dataset_index.get(context.dataset_id, []))
if scope == OntologyScope.DATASET:
candidate_ids = candidate_ids.intersection(dataset_ids)
else:
candidate_ids.update(dataset_ids)
if context.pipeline_name:
pipeline_ids = set(self.pipeline_index.get(context.pipeline_name, []))
if scope == OntologyScope.PIPELINE:
candidate_ids = candidate_ids.intersection(pipeline_ids)
else:
candidate_ids.update(pipeline_ids)
# Return matching ontologies
result = []
for ontology_id in candidate_ids:
if ontology_id in self.ontologies:
result.append(self.ontologies[ontology_id])
logger.debug(f"Found {len(result)} ontologies matching criteria")
return result
async def unregister_ontology(
self,
ontology_id: str,
context: Optional[OntologyContext] = None
) -> bool:
"""Unregister an ontology."""
if ontology_id not in self.ontologies:
return False
# Remove from all indices
for scope_ids in self.scope_index.values():
if ontology_id in scope_ids:
scope_ids.remove(ontology_id)
for domain_ids in self.domain_index.values():
if ontology_id in domain_ids:
domain_ids.remove(ontology_id)
for user_ids in self.user_index.values():
if ontology_id in user_ids:
user_ids.remove(ontology_id)
for dataset_ids in self.dataset_index.values():
if ontology_id in dataset_ids:
dataset_ids.remove(ontology_id)
for pipeline_ids in self.pipeline_index.values():
if ontology_id in pipeline_ids:
pipeline_ids.remove(ontology_id)
# Remove from main registry
del self.ontologies[ontology_id]
logger.info(f"Unregistered ontology {ontology_id}")
return True
def get_stats(self) -> Dict[str, int]:
"""Get registry statistics."""
return {
"total_ontologies": len(self.ontologies),
"global_ontologies": len(self.scope_index[OntologyScope.GLOBAL]),
"domain_ontologies": len(self.scope_index[OntologyScope.DOMAIN]),
"pipeline_ontologies": len(self.scope_index[OntologyScope.PIPELINE]),
"user_ontologies": len(self.scope_index[OntologyScope.USER]),
"dataset_ontologies": len(self.scope_index[OntologyScope.DATASET]),
"unique_domains": len(self.domain_index),
"unique_users": len(self.user_index),
"unique_datasets": len(self.dataset_index),
"unique_pipelines": len(self.pipeline_index),
}
class DatabaseOntologyRegistry(IOntologyRegistry):
"""Database-backed ontology registry (placeholder implementation)."""
def __init__(self, db_connection=None):
self.db_connection = db_connection
# This would use actual database operations in a real implementation
self._memory_registry = OntologyRegistry()
async def register_ontology(
self,
ontology: OntologyGraph,
scope: OntologyScope,
context: Optional[OntologyContext] = None
) -> str:
"""Register an ontology in database."""
# TODO: Implement database storage
return await self._memory_registry.register_ontology(ontology, scope, context)
async def get_ontology(
self,
ontology_id: str,
context: Optional[OntologyContext] = None
) -> Optional[OntologyGraph]:
"""Get ontology from database."""
# TODO: Implement database retrieval
return await self._memory_registry.get_ontology(ontology_id, context)
async def find_ontologies(
self,
scope: Optional[OntologyScope] = None,
domain: Optional[str] = None,
context: Optional[OntologyContext] = None
) -> List[OntologyGraph]:
"""Find ontologies in database."""
# TODO: Implement database query
return await self._memory_registry.find_ontologies(scope, domain, context)
async def unregister_ontology(
self,
ontology_id: str,
context: Optional[OntologyContext] = None
) -> bool:
"""Unregister ontology from database."""
# TODO: Implement database deletion
return await self._memory_registry.unregister_ontology(ontology_id, context)
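# --- Usage sketch (illustrative only) ----------------------------------------
# Registering, finding, and unregistering an ontology in the in-memory
# registry. The OntologyGraph literal is a minimal placeholder.
if __name__ == "__main__":
    import asyncio
    from cognee.modules.ontology.interfaces import OntologyFormat

    async def _demo():
        registry = OntologyRegistry()
        ontology = OntologyGraph(
            id="demo",
            name="Demo",
            description="",
            format=OntologyFormat.JSON,
            scope=OntologyScope.DOMAIN,
            nodes=[],
            edges=[],
        )
        context = OntologyContext(domain="medical", user_id="user-1")
        ontology_id = await registry.register_ontology(ontology, OntologyScope.DOMAIN, context)
        matches = await registry.find_ontologies(domain="medical")
        print(f"{len(matches)} match(es); stats: {registry.get_stats()}")
        await registry.unregister_ontology(ontology_id)

    asyncio.run(_demo())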

View file

@ -0,0 +1,275 @@
"""DataPoint resolution implementations."""
import importlib
from typing import Any, Dict, List, Optional, Callable
from uuid import uuid4
from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import (
IDataPointResolver,
OntologyNode,
DataPointMapping,
OntologyContext,
)
logger = get_logger("DataPointResolver")
class DefaultDataPointResolver(IDataPointResolver):
"""Default implementation for DataPoint resolution."""
def __init__(self):
self.custom_resolvers: Dict[str, Callable] = {}
async def resolve_to_datapoint(
self,
ontology_node: OntologyNode,
mapping_config: DataPointMapping,
context: Optional[OntologyContext] = None
) -> Any: # DataPoint
"""Resolve ontology node to DataPoint instance."""
# Use custom resolver if specified
if mapping_config.custom_resolver and mapping_config.custom_resolver in self.custom_resolvers:
return await self._apply_custom_resolver(
ontology_node, mapping_config, context
)
# Use default resolution logic
return await self._default_resolution(ontology_node, mapping_config, context)
async def resolve_from_datapoint(
self,
datapoint: Any, # DataPoint
mapping_config: DataPointMapping,
context: Optional[OntologyContext] = None
) -> OntologyNode:
"""Resolve DataPoint instance to ontology node."""
# Extract properties from DataPoint
datapoint_dict = datapoint.to_dict() if hasattr(datapoint, 'to_dict') else datapoint.__dict__
# Map DataPoint fields to ontology properties
ontology_properties = {}
reverse_mappings = {v: k for k, v in mapping_config.field_mappings.items()}
for datapoint_field, ontology_field in reverse_mappings.items():
if datapoint_field in datapoint_dict:
ontology_properties[ontology_field] = datapoint_dict[datapoint_field]
# Create ontology node
node = OntologyNode(
            id=str(datapoint.id) if hasattr(datapoint, 'id') else str(uuid4()),
            name=ontology_properties.get('name', str(getattr(datapoint, 'id', 'unnamed'))),
type=mapping_config.ontology_node_type,
description=ontology_properties.get('description', ''),
category=ontology_properties.get('category', 'entity'),
properties=ontology_properties
)
return node
async def validate_mapping(
self,
mapping_config: DataPointMapping
) -> bool:
"""Validate mapping configuration."""
try:
# Check if DataPoint class exists
module_path, class_name = mapping_config.datapoint_class.rsplit('.', 1)
module = importlib.import_module(module_path)
datapoint_class = getattr(module, class_name)
# Validate field mappings
if hasattr(datapoint_class, '__annotations__'):
valid_fields = set(datapoint_class.__annotations__.keys())
mapped_fields = set(mapping_config.field_mappings.values())
invalid_fields = mapped_fields - valid_fields
if invalid_fields:
logger.warning(f"Invalid field mappings: {invalid_fields}")
return False
# Validate custom resolver if specified
if mapping_config.custom_resolver:
if mapping_config.custom_resolver not in self.custom_resolvers:
logger.warning(f"Custom resolver not found: {mapping_config.custom_resolver}")
return False
return True
except Exception as e:
logger.error(f"Mapping validation failed: {e}")
return False
def register_custom_resolver(
self,
resolver_name: str,
resolver_func: Callable[[OntologyNode, DataPointMapping], Any]
) -> None:
"""Register a custom resolver function."""
self.custom_resolvers[resolver_name] = resolver_func
logger.info(f"Registered custom resolver: {resolver_name}")
async def _apply_custom_resolver(
self,
ontology_node: OntologyNode,
mapping_config: DataPointMapping,
context: Optional[OntologyContext] = None
) -> Any:
"""Apply custom resolver function."""
resolver_func = self.custom_resolvers[mapping_config.custom_resolver]
if callable(resolver_func):
try:
# Call resolver function
if context:
result = await resolver_func(ontology_node, mapping_config, context)
else:
result = await resolver_func(ontology_node, mapping_config)
return result
except Exception as e:
logger.error(f"Custom resolver failed: {e}")
# Fallback to default resolution
return await self._default_resolution(ontology_node, mapping_config, context)
else:
logger.error(f"Invalid custom resolver: {mapping_config.custom_resolver}")
return await self._default_resolution(ontology_node, mapping_config, context)
async def _default_resolution(
self,
ontology_node: OntologyNode,
mapping_config: DataPointMapping,
context: Optional[OntologyContext] = None
) -> Any:
"""Default resolution logic."""
try:
# Import the DataPoint class
module_path, class_name = mapping_config.datapoint_class.rsplit('.', 1)
module = importlib.import_module(module_path)
datapoint_class = getattr(module, class_name)
# Map ontology properties to DataPoint fields
datapoint_data = {}
# Apply field mappings
for ontology_field, datapoint_field in mapping_config.field_mappings.items():
if ontology_field in ontology_node.properties:
datapoint_data[datapoint_field] = ontology_node.properties[ontology_field]
elif hasattr(ontology_node, ontology_field):
datapoint_data[datapoint_field] = getattr(ontology_node, ontology_field)
# Set default mappings if not provided
if 'id' not in datapoint_data:
datapoint_data['id'] = ontology_node.id
if 'type' not in datapoint_data:
datapoint_data['type'] = ontology_node.type
# Add ontology metadata
if hasattr(datapoint_class, 'metadata'):
datapoint_data['metadata'] = {
'type': ontology_node.type,
'index_fields': list(mapping_config.field_mappings.values()),
'ontology_source': True,
'ontology_node_id': ontology_node.id,
}
# Set ontology_valid flag
datapoint_data['ontology_valid'] = True
# Create DataPoint instance
datapoint = datapoint_class(**datapoint_data)
# Apply validation rules if specified
if mapping_config.validation_rules:
await self._apply_validation_rules(datapoint, mapping_config.validation_rules)
return datapoint
except Exception as e:
logger.error(f"Default resolution failed for node {ontology_node.id}: {e}")
return None
async def _apply_validation_rules(
self,
datapoint: Any,
validation_rules: List[str]
) -> None:
"""Apply validation rules to DataPoint."""
for rule in validation_rules:
try:
# This is a simple implementation - in practice, you'd want
# a more sophisticated rule engine
if rule.startswith("required:"):
field_name = rule.split(":", 1)[1]
if not hasattr(datapoint, field_name) or getattr(datapoint, field_name) is None:
raise ValueError(f"Required field {field_name} is missing")
elif rule.startswith("type:"):
field_name, expected_type = rule.split(":", 2)[1:]
if hasattr(datapoint, field_name):
field_value = getattr(datapoint, field_name)
if field_value is not None and not isinstance(field_value, eval(expected_type)):
raise ValueError(f"Field {field_name} has wrong type")
# Add more validation rules as needed
except Exception as e:
logger.warning(f"Validation rule '{rule}' failed: {e}")
class DomainSpecificResolver(DefaultDataPointResolver):
"""Domain-specific resolver with specialized logic."""
def __init__(self, domain: str):
super().__init__()
self.domain = domain
async def _default_resolution(
self,
ontology_node: OntologyNode,
mapping_config: DataPointMapping,
context: Optional[OntologyContext] = None
) -> Any:
"""Domain-specific resolution logic."""
# Apply domain-specific preprocessing
if self.domain == "medical":
ontology_node = await self._preprocess_medical_node(ontology_node)
elif self.domain == "legal":
ontology_node = await self._preprocess_legal_node(ontology_node)
elif self.domain == "code":
ontology_node = await self._preprocess_code_node(ontology_node)
# Use parent's default resolution
return await super()._default_resolution(ontology_node, mapping_config, context)
async def _preprocess_medical_node(self, node: OntologyNode) -> OntologyNode:
"""Preprocess medical domain nodes."""
# Add medical-specific property transformations
if node.type == "Disease":
node.properties["medical_category"] = "disease"
elif node.type == "Symptom":
node.properties["medical_category"] = "symptom"
return node
async def _preprocess_legal_node(self, node: OntologyNode) -> OntologyNode:
"""Preprocess legal domain nodes."""
# Add legal-specific property transformations
if node.type == "Law":
node.properties["legal_category"] = "legislation"
elif node.type == "Case":
node.properties["legal_category"] = "precedent"
return node
async def _preprocess_code_node(self, node: OntologyNode) -> OntologyNode:
"""Preprocess code domain nodes."""
# Add code-specific property transformations
if node.type == "Function":
node.properties["code_category"] = "function"
elif node.type == "Class":
node.properties["code_category"] = "class"
return node
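# --- Usage sketch (illustrative only) ----------------------------------------
# Registering a custom resolver and routing a node through it. The resolver
# body is a toy; a real one would construct and return a DataPoint instance.
if __name__ == "__main__":
    import asyncio

    async def _uppercase_resolver(node, mapping, context=None):
        node.properties["name_upper"] = node.name.upper()
        return node

    async def _demo():
        resolver = DefaultDataPointResolver()
        resolver.register_custom_resolver("uppercase", _uppercase_resolver)
        mapping = DataPointMapping(
            ontology_node_type="Disease",
            datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
            field_mappings={"name": "name"},
            custom_resolver="uppercase",
        )
        node = OntologyNode(id="flu", name="Influenza", type="Disease")
        print(await resolver.resolve_to_datapoint(node, mapping))

    asyncio.run(_demo())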

View file

@ -0,0 +1,410 @@
"""
Enhanced graph extraction task with ontology awareness.
This demonstrates how to update existing tasks to use the new ontology system.
"""
from typing import Type, List, Optional, Any, Dict
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.tasks.storage.add_data_points import add_data_points
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
from cognee.modules.graph.utils import (
expand_with_nodes_and_edges,
retrieve_existing_edges,
)
from cognee.shared.data_models import KnowledgeGraph
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.shared.logging_utils import get_logger
# New ontology imports
from cognee.modules.ontology.interfaces import (
IOntologyManager,
OntologyContext,
DataPointMapping,
GraphBindingConfig,
)
logger = get_logger("extract_graph_from_data_ontology_aware")
async def extract_graph_from_data_ontology_aware(
data_chunks: List[DocumentChunk],
graph_model: Type[Any] = KnowledgeGraph,
ontology_manager: Optional[IOntologyManager] = None,
ontology_context: Optional[OntologyContext] = None,
datapoint_mappings: Optional[List[DataPointMapping]] = None,
graph_binding_config: Optional[GraphBindingConfig] = None,
entity_extraction_enabled: bool = False,
target_entity_types: Optional[List[str]] = None,
enhanced_content: Optional[Dict[str, Any]] = None,
**kwargs
) -> List[DocumentChunk]:
"""
Enhanced graph extraction with ontology awareness.
Args:
data_chunks: Document chunks to process
graph_model: Graph model type (KnowledgeGraph or custom)
ontology_manager: Ontology manager instance (injected by pipeline)
ontology_context: Ontology context (injected by pipeline)
datapoint_mappings: DataPoint mappings (injected by pipeline)
graph_binding_config: Graph binding configuration (injected by pipeline)
entity_extraction_enabled: Whether to use ontology for entity extraction
target_entity_types: Specific entity types to extract
enhanced_content: Pre-enhanced content with ontological information
**kwargs: Additional parameters
Returns:
Updated document chunks with extracted graph data
"""
logger.info(f"Processing {len(data_chunks)} chunks with ontology awareness")
# Check if ontology integration is available
if ontology_manager and ontology_context:
logger.info(f"Ontology integration enabled for domain: {ontology_context.domain}")
return await _extract_with_ontology(
data_chunks=data_chunks,
graph_model=graph_model,
ontology_manager=ontology_manager,
ontology_context=ontology_context,
datapoint_mappings=datapoint_mappings,
graph_binding_config=graph_binding_config,
entity_extraction_enabled=entity_extraction_enabled,
target_entity_types=target_entity_types,
enhanced_content=enhanced_content,
**kwargs
)
else:
logger.info("No ontology integration, using standard extraction")
return await _extract_standard(data_chunks, graph_model, **kwargs)
async def _extract_with_ontology(
data_chunks: List[DocumentChunk],
graph_model: Type[Any],
ontology_manager: IOntologyManager,
ontology_context: OntologyContext,
datapoint_mappings: Optional[List[DataPointMapping]] = None,
graph_binding_config: Optional[GraphBindingConfig] = None,
entity_extraction_enabled: bool = False,
target_entity_types: Optional[List[str]] = None,
enhanced_content: Optional[Dict[str, Any]] = None,
**kwargs
) -> List[DocumentChunk]:
"""Extract graph data with ontology enhancement."""
# Step 1: Get applicable ontologies
ontologies = await ontology_manager.get_applicable_ontologies(ontology_context)
logger.info(f"Found {len(ontologies)} applicable ontologies")
if not ontologies:
logger.warning("No applicable ontologies found, falling back to standard extraction")
return await _extract_standard(data_chunks, graph_model, **kwargs)
# Step 2: Enhance content with ontological information
chunk_graphs = []
enhanced_datapoints = []
for chunk in data_chunks:
# Enhance chunk content if not already done
if enhanced_content:
chunk_enhanced = enhanced_content
else:
chunk_enhanced = await ontology_manager.enhance_with_ontology(
chunk.text, ontology_context
)
# Extract graph using enhanced information
chunk_graph = await _extract_chunk_graph_with_ontology(
chunk=chunk,
enhanced_content=chunk_enhanced,
graph_model=graph_model,
ontology_manager=ontology_manager,
ontology_context=ontology_context,
target_entity_types=target_entity_types,
)
chunk_graphs.append(chunk_graph)
# Convert ontological entities to DataPoints if mappings available
if datapoint_mappings and chunk_enhanced.get('extracted_entities'):
ontology_nodes = await _convert_entities_to_ontology_nodes(
chunk_enhanced['extracted_entities'], ontologies[0]
)
if ontology_nodes:
datapoints = await ontology_manager.resolve_to_datapoints(
ontology_nodes, ontology_context
)
enhanced_datapoints.extend(datapoints)
# Step 3: Integrate with graph database using custom binding
if graph_model is KnowledgeGraph:
# Use standard integration with ontology-enhanced nodes/edges
await _integrate_ontology_enhanced_graphs(
data_chunks=data_chunks,
chunk_graphs=chunk_graphs,
enhanced_datapoints=enhanced_datapoints,
ontology_manager=ontology_manager,
ontology_context=ontology_context,
graph_binding_config=graph_binding_config,
)
else:
# Custom graph model - just attach graphs to chunks
for chunk_index, chunk_graph in enumerate(chunk_graphs):
data_chunks[chunk_index].contains = chunk_graph
logger.info(f"Completed ontology-aware graph extraction for {len(data_chunks)} chunks")
return data_chunks
async def _extract_chunk_graph_with_ontology(
chunk: DocumentChunk,
enhanced_content: Dict[str, Any],
graph_model: Type[Any],
ontology_manager: IOntologyManager,
ontology_context: OntologyContext,
target_entity_types: Optional[List[str]] = None,
) -> Any:
"""Extract graph for a single chunk using ontological enhancement."""
# Build entity-aware prompt for LLM
extracted_entities = enhanced_content.get('extracted_entities', [])
semantic_relationships = enhanced_content.get('semantic_relationships', [])
# Filter entities by target types if specified
if target_entity_types:
extracted_entities = [
entity for entity in extracted_entities
if entity.get('type') in target_entity_types
]
# Create enhanced prompt with ontological context
ontology_context_prompt = _build_ontology_context_prompt(
extracted_entities, semantic_relationships
)
# Use LLM to extract graph with ontological guidance
full_prompt = f"""
{ontology_context_prompt}
Text to analyze:
{chunk.text}
Extract entities and relationships, giving preference to the ontological entities
and relationships mentioned above when they appear in the text.
"""
# Extract using LLM with ontological context
chunk_graph = await LLMGateway.acreate_structured_output(
full_prompt,
"You are a knowledge graph extractor. Use the ontological context to guide your extraction.",
graph_model
)
# Enhance extracted graph with ontological metadata
if hasattr(chunk_graph, 'nodes'):
for node in chunk_graph.nodes:
# Add ontological metadata to nodes
matching_entity = _find_matching_ontological_entity(node, extracted_entities)
if matching_entity:
node.ontology_source = matching_entity.get('ontology_id')
node.ontology_confidence = matching_entity.get('confidence', 0.0)
if hasattr(node, 'type'):
node.type = matching_entity.get('type', node.type)
return chunk_graph
async def _integrate_ontology_enhanced_graphs(
data_chunks: List[DocumentChunk],
chunk_graphs: List[Any],
enhanced_datapoints: List[Any],
ontology_manager: IOntologyManager,
ontology_context: OntologyContext,
graph_binding_config: Optional[GraphBindingConfig] = None,
):
"""Integrate ontology-enhanced graphs into the graph database."""
graph_engine = await get_graph_engine()
# Get existing edges to avoid duplicates
existing_edges_map = await retrieve_existing_edges(data_chunks, chunk_graphs)
# Apply ontology-aware graph binding if available
if graph_binding_config:
# Transform graphs using custom binding
all_nodes = []
all_edges = []
for chunk_graph in chunk_graphs:
if hasattr(chunk_graph, 'nodes') and hasattr(chunk_graph, 'edges'):
# Convert chunk graph to ontology format for binding
from cognee.modules.ontology.interfaces import OntologyGraph, OntologyNode, OntologyEdge
ontology_nodes = []
ontology_edges = []
for node in chunk_graph.nodes:
ont_node = OntologyNode(
id=node.id,
name=node.name,
type=node.type,
description=getattr(node, 'description', ''),
properties=getattr(node, '__dict__', {})
)
ontology_nodes.append(ont_node)
for edge in chunk_graph.edges:
ont_edge = OntologyEdge(
id=f"{edge.source_node_id}_{edge.target_node_id}_{edge.relationship_name}",
source_id=edge.source_node_id,
target_id=edge.target_node_id,
relationship_type=edge.relationship_name,
properties=getattr(edge, '__dict__', {})
)
ontology_edges.append(ont_edge)
# Create temporary ontology for binding
temp_ontology = OntologyGraph(
id="temp_chunk_ontology",
name="Chunk Ontology",
description="Temporary ontology for chunk graph",
format="llm_generated",
scope="dataset",
nodes=ontology_nodes,
edges=ontology_edges
)
# Apply custom binding
bound_nodes, bound_edges = await ontology_manager.bind_to_graph(
temp_ontology, ontology_context
)
all_nodes.extend(bound_nodes)
all_edges.extend(bound_edges)
# Add bound nodes and edges to graph
if all_nodes:
await add_data_points(all_nodes)
if all_edges:
await graph_engine.add_edges(all_edges)
else:
# Use standard integration
graph_nodes, graph_edges = expand_with_nodes_and_edges(
data_chunks, chunk_graphs, None, existing_edges_map
)
if graph_nodes:
await add_data_points(graph_nodes)
if graph_edges:
await graph_engine.add_edges(graph_edges)
# Add enhanced DataPoints from ontology resolution
if enhanced_datapoints:
await add_data_points(enhanced_datapoints)
logger.info(f"Added {len(enhanced_datapoints)} ontology-resolved DataPoints")
async def _extract_standard(
data_chunks: List[DocumentChunk],
graph_model: Type[Any],
**kwargs
) -> List[DocumentChunk]:
"""Standard graph extraction without ontology (fallback)."""
# Import the original function to maintain compatibility
from cognee.tasks.graph.extract_graph_from_data import integrate_chunk_graphs
# Generate standard graphs using LLM
chunk_graphs = []
for chunk in data_chunks:
system_prompt = LLMGateway.read_query_prompt("generate_graph_prompt_oneshot.txt")
chunk_graph = await LLMGateway.acreate_structured_output(
chunk.text, system_prompt, graph_model
)
chunk_graphs.append(chunk_graph)
# Use standard integration
return await integrate_chunk_graphs(
data_chunks=data_chunks,
chunk_graphs=chunk_graphs,
graph_model=graph_model,
ontology_adapter=None, # No ontology adapter
)
def _build_ontology_context_prompt(
extracted_entities: List[Dict[str, Any]],
semantic_relationships: List[Dict[str, Any]]
) -> str:
"""Build ontological context prompt for LLM."""
if not extracted_entities and not semantic_relationships:
return ""
prompt = "ONTOLOGICAL CONTEXT:\n"
if extracted_entities:
prompt += "Known entities in this domain:\n"
for entity in extracted_entities[:10]: # Limit to top 10
prompt += f"- {entity['name']} (type: {entity['type']})\n"
prompt += "\n"
if semantic_relationships:
prompt += "Known relationships:\n"
for rel in semantic_relationships[:10]: # Limit to top 10
prompt += f"- {rel['source']} {rel['relationship']} {rel['target']}\n"
prompt += "\n"
prompt += "When extracting entities and relationships, prefer these ontological concepts when they appear in the text.\n\n"
return prompt
def _find_matching_ontological_entity(
extracted_node: Any,
ontological_entities: List[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
"""Find matching ontological entity for an extracted node."""
    node_name = getattr(extracted_node, 'name', '').lower()
    if not node_name:
        return None
for entity in ontological_entities:
entity_name = entity.get('name', '').lower()
if node_name == entity_name or node_name in entity_name or entity_name in node_name:
return entity
return None
async def _convert_entities_to_ontology_nodes(
extracted_entities: List[Dict[str, Any]],
ontology: Any
) -> List[Any]:
"""Convert extracted entities to ontology nodes for DataPoint resolution."""
from cognee.modules.ontology.interfaces import OntologyNode
ontology_nodes = []
for entity in extracted_entities:
node = OntologyNode(
id=entity.get('node_id', entity['name']),
name=entity['name'],
type=entity['type'],
description=entity.get('description', ''),
category=entity.get('category', 'entity'),
properties={
'confidence': entity.get('confidence', 0.0),
'ontology_id': entity.get('ontology_id'),
'source': 'llm_extraction'
}
)
ontology_nodes.append(node)
return ontology_nodes
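# --- Usage sketch (illustrative only) ----------------------------------------
# Because the task degrades gracefully when no ontology manager is supplied,
# it can stand in for the standard extraction task in an existing pipeline:
#
#     Task(extract_graph_from_data_ontology_aware, graph_model=KnowledgeGraph)
#
# When an OntologyInjector is configured for the pipeline, it supplies the
# extra kwargs (ontology_manager, ontology_context, datapoint_mappings, ...)
# shown in the signature above.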