Compare commits: main...ontology_i (1 commit, SHA 191a9ed0ee)
21 changed files with 6081 additions and 0 deletions

cognee/modules/ontology/IMPLEMENTATION_TICKET.md (new file, 581 lines)
@@ -0,0 +1,581 @@
# 🎫 Epic: Refactor Ontology System for General Pipeline Usage

**Epic ID:** CGNEE-2024-ONT-001
**Priority:** High
**Story Points:** 21
**Type:** Epic
**Labels:** `refactoring`, `architecture`, `ontology`, `pipeline-integration`

## 📋 Overview

Refactor the current monolithic `OntologyResolver` to follow Cognee's architectural patterns and to be general, extensible, and usable across all pipelines. The current system is tightly coupled to specific tasks and supports only RDF/OWL formats. We need a modular system that follows Cognee's established patterns for configuration, method organization, and module structure.

## 🎯 Business Value

- **Consistency**: Follow established Cognee patterns and conventions
- **Flexibility**: Support multiple ontology formats and domains
- **Reusability**: One ontology system usable across all pipelines
- **Maintainability**: Modular architecture following Cognee's separation of concerns
- **Developer Experience**: Familiar patterns and simple configuration

## 🏗️ Architecture Analysis

Based on examination of existing Cognee patterns:

### **Configuration Pattern** (Following `cognee/base_config.py`, `cognee/modules/cognify/config.py`)

```python
# Follow the BaseSettings pattern with @lru_cache
class OntologyConfig(BaseSettings):
    default_format: OntologyFormat = OntologyFormat.JSON
    enable_semantic_search: bool = False
    registry_type: str = "memory"
    model_config = SettingsConfigDict(env_file=".env", extra="allow")

@lru_cache
def get_ontology_config():
    return OntologyConfig()
```

### **Methods Organization** (Following `cognee/modules/data/methods/`, `cognee/modules/users/methods/`)

```
# Organize methods in separate files:
cognee/modules/ontology/methods/
├── __init__.py            # Export all methods
├── create_ontology.py     # async def create_ontology(...)
├── get_ontology.py        # async def get_ontology(...)
├── load_ontology.py       # async def load_ontology(...)
├── register_ontology.py   # async def register_ontology(...)
└── delete_ontology.py     # async def delete_ontology(...)
```

### **Models Structure** (Following `cognee/modules/data/models/`, `cognee/modules/users/models/`)

```
# Create models with proper inheritance:
cognee/modules/ontology/models/
├── __init__.py            # Export all models
├── OntologyGraph.py       # class OntologyGraph(BaseModel)
├── OntologyNode.py        # class OntologyNode(BaseModel)
├── OntologyContext.py     # class OntologyContext(BaseModel)
└── DataPointMapping.py    # class DataPointMapping(BaseModel)
```

### **Module Organization** (Following existing module patterns)

```
cognee/modules/ontology/
├── __init__.py            # Public API with convenience functions
├── config.py              # OntologyConfig with @lru_cache
├── models/                # Pydantic models
├── methods/               # Async method functions
├── providers/             # Format-specific providers
├── adapters/              # Query and search operations
├── operations/            # Core business logic
└── utils/                 # Utility functions
```
## 📦 Epic Breakdown (Cognee-Style Implementation)

### Story 1: Configuration System & Models
**Story Points:** 3
**Assignee:** Backend Developer

#### Acceptance Criteria
- [ ] Create `OntologyConfig` following the `BaseSettings` pattern with `@lru_cache`
- [ ] Create Pydantic models in the `models/` directory with proper `__init__.py` exports
- [ ] Follow naming conventions: `OntologyNode`, `OntologyEdge`, etc.
- [ ] Use proper type hints and docstring patterns from existing code
- [ ] Support environment variables following existing config patterns

#### Files to Create
```
cognee/modules/ontology/config.py
cognee/modules/ontology/models/__init__.py
cognee/modules/ontology/models/OntologyGraph.py
cognee/modules/ontology/models/OntologyNode.py
cognee/modules/ontology/models/OntologyContext.py
cognee/modules/ontology/models/DataPointMapping.py
```

#### Implementation Notes
```python
# config.py - Follow the existing pattern
@lru_cache
def get_ontology_config():
    return OntologyConfig()

# models/OntologyNode.py - Follow the DataPoint pattern
class OntologyNode(BaseModel):
    id: str = Field(..., description="Unique identifier")
    name: str
    type: str
    # ... follow existing model patterns
```

---

### Story 2: Methods Organization
**Story Points:** 2
**Assignee:** Backend Developer
**Dependencies:** Story 1

#### Acceptance Criteria
- [ ] Create a `methods/` directory following the existing pattern
- [ ] Implement async functions following the `create_dataset` and `get_user` patterns
- [ ] Use proper error handling patterns from existing methods
- [ ] Follow parameter naming and return type conventions
- [ ] Export all methods in `methods/__init__.py`

#### Files to Create
```
cognee/modules/ontology/methods/__init__.py
cognee/modules/ontology/methods/create_ontology.py
cognee/modules/ontology/methods/get_ontology.py
cognee/modules/ontology/methods/load_ontology.py
cognee/modules/ontology/methods/register_ontology.py
cognee/modules/ontology/methods/delete_ontology.py
```

#### Implementation Notes
```python
# methods/create_ontology.py - Follow the create_dataset pattern
async def create_ontology(
    ontology_data: Dict[str, Any],
    user: User,
    scope: OntologyScope = OntologyScope.USER,
) -> OntologyGraph:
    # Follow existing error handling and validation patterns
    ...
```

---
### Story 3: Provider System (Following Existing Patterns)
**Story Points:** 4
**Assignee:** Backend Developer
**Dependencies:** Story 2

#### Acceptance Criteria
- [ ] Create providers following the adapter pattern seen in the retrieval systems
- [ ] Use abstract base classes like `CogneeAbstractGraph`
- [ ] Follow error handling patterns from existing providers
- [ ] Support graceful degradation (like an RDF provider with optional rdflib)
- [ ] Use proper logging patterns with `get_logger()`

#### Files to Create
```
cognee/modules/ontology/providers/__init__.py
cognee/modules/ontology/providers/base.py
cognee/modules/ontology/providers/rdf_provider.py
cognee/modules/ontology/providers/json_provider.py
cognee/modules/ontology/providers/csv_provider.py
```

#### Implementation Notes
```python
# providers/base.py - Follow the CogneeAbstractGraph pattern
class BaseOntologyProvider(ABC):
    @abstractmethod
    async def load_ontology(self, source: str) -> OntologyGraph:
        pass

# providers/rdf_provider.py - Follow the graceful fallback pattern
class RDFOntologyProvider(BaseOntologyProvider):
    def __init__(self):
        try:
            import rdflib
            self.available = True
        except ImportError:
            logger.warning("rdflib not available")
            self.available = False
```

---

### Story 4: Operations Layer (Core Business Logic)
**Story Points:** 4
**Assignee:** Senior Backend Developer
**Dependencies:** Story 3

#### Acceptance Criteria
- [ ] Create an `operations/` directory following the pipeline operations pattern
- [ ] Implement core business logic following the `cognee_pipeline` pattern
- [ ] Use dependency injection patterns seen in existing operations
- [ ] Follow async/await patterns consistently
- [ ] Implement proper error handling and logging

#### Files to Create
```
cognee/modules/ontology/operations/__init__.py
cognee/modules/ontology/operations/ontology_manager.py
cognee/modules/ontology/operations/datapoint_resolver.py
cognee/modules/ontology/operations/graph_binder.py
cognee/modules/ontology/operations/registry.py
```

#### Implementation Notes
```python
# operations/ontology_manager.py - Follow the cognee_pipeline pattern
async def manage_ontology_processing(
    context: OntologyContext,
    providers: Dict[str, BaseOntologyProvider],
    config: OntologyConfig = None,
) -> List[OntologyGraph]:
    # Follow existing operation patterns
    ...
```

---
### Story 5: Pipeline Integration (Following Task Pattern)
**Story Points:** 3
**Assignee:** Backend Developer
**Dependencies:** Story 4

#### Acceptance Criteria
- [ ] Create integration following the `Task` class pattern
- [ ] Support injection into existing pipeline operations
- [ ] Follow parameter passing patterns from `cognee_pipeline`
- [ ] Maintain backward compatibility with existing tasks
- [ ] Use the context variables pattern for configuration

#### Files to Create
```
cognee/modules/ontology/operations/pipeline_integration.py
cognee/modules/ontology/operations/task_enhancer.py
```

#### Implementation Notes
```python
# operations/pipeline_integration.py - Follow the cognee_pipeline pattern
async def inject_ontology_context(
    tasks: list[Task],
    ontology_context: OntologyContext,
    config: OntologyConfig = None,
) -> list[Task]:
    # Follow existing task enhancement patterns
    ...
```

---
### Story 6: Utilities
**Story Points:** 2
**Assignee:** Backend Developer
**Dependencies:** Story 5

#### Acceptance Criteria
- [ ] Create a `utils/` directory for utility functions
- [ ] Follow utility patterns from existing modules
- [ ] Implement helper functions for common operations (see the sketch below)
- [ ] Use proper type hints and error handling

#### Files to Create
```
cognee/modules/ontology/utils/__init__.py
cognee/modules/ontology/utils/ontology_helpers.py
cognee/modules/ontology/utils/mapping_helpers.py
cognee/modules/ontology/utils/validation.py
```
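
#### Implementation Notes
As a point of reference for this story, here is a minimal sketch of what a helper in `utils/validation.py` might look like. The function name and the checks performed are illustrative assumptions, not a prescribed API; it only relies on the `OntologyGraph` node/edge fields defined elsewhere in this epic.

```python
# utils/validation.py - illustrative sketch only; name and checks are assumptions
from cognee.modules.ontology.models import OntologyGraph

def validate_ontology_graph(ontology: OntologyGraph) -> list[str]:
    """Return a list of validation problems (an empty list means valid)."""
    problems = []
    node_ids = {node.id for node in ontology.nodes}

    # Node ids must be unique
    if len(node_ids) != len(ontology.nodes):
        problems.append("Duplicate node ids found")

    # Every edge must reference known nodes
    for edge in ontology.edges:
        if edge.source_id not in node_ids:
            problems.append(f"Edge {edge.id} has unknown source {edge.source_id}")
        if edge.target_id not in node_ids:
            problems.append(f"Edge {edge.id} has unknown target {edge.target_id}")

    return problems
```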
---

### Story 7: Public API and Module Initialization
**Story Points:** 2
**Assignee:** Backend Developer
**Dependencies:** Story 6

#### Acceptance Criteria
- [ ] Create comprehensive `__init__.py` following existing patterns
- [ ] Export key classes and functions following module conventions
- [ ] Create convenience functions following the `get_base_config()` pattern
- [ ] Maintain backward compatibility with `OntologyResolver`
- [ ] Use proper `__all__` exports

#### Files to Update/Create
```
cognee/modules/ontology/__init__.py
```

#### Implementation Notes
```python
# __init__.py - Follow existing module export patterns
from .config import get_ontology_config
from .models import OntologyGraph, OntologyNode, OntologyContext
from .methods import create_ontology, load_ontology, get_ontology

# Convenience functions following the get_base_config pattern
@lru_cache
def get_ontology_manager():
    return OntologyManager()

__all__ = [
    "OntologyGraph",
    "OntologyNode",
    "get_ontology_config",
    "create_ontology",
    # ... follow existing export patterns
]
```

---
### Story 8: Enhanced Task Implementation
**Story Points:** 2
**Assignee:** Backend Developer
**Dependencies:** Story 7

#### Acceptance Criteria
- [ ] Update the existing graph extraction task to be ontology-aware
- [ ] Follow existing task parameter patterns
- [ ] Maintain backward compatibility
- [ ] Use proper error handling and fallback mechanisms (sketched below)
- [ ] Follow existing task documentation patterns

#### Files to Create/Update
```
cognee/tasks/graph/extract_graph_from_data_ontology_aware.py
```
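
The full signature for this task is spelled out under "Task Integration" below. As a rough sketch of the fallback behaviour required here, something along these lines could work; `enrich_chunks_with_ontology` is a hypothetical helper used only for illustration, not an existing cognee API:

```python
# extract_graph_from_data_ontology_aware.py - fallback sketch, assumptions marked
async def extract_graph_from_data_ontology_aware(data_chunks, ontology_config=None, **kwargs):
    config = ontology_config or get_ontology_config()
    try:
        # enrich_chunks_with_ontology is a hypothetical helper, not a real API
        return await enrich_chunks_with_ontology(data_chunks, config, **kwargs)
    except Exception as e:
        logger.warning(f"Ontology enrichment failed, falling back to plain extraction: {e}")
        # Backward compatible: behave like the original task when ontology processing fails
        return await extract_graph_from_data(data_chunks, **kwargs)
```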
---

### Story 9: Documentation and Examples
**Story Points:** 1
**Assignee:** Technical Writer / Backend Developer
**Dependencies:** Story 8

#### Acceptance Criteria
- [ ] Create migration guide following existing documentation patterns
- [ ] Provide examples following existing code patterns
- [ ] Document configuration options following existing config docs
- [ ] Create troubleshooting guide

#### Files to Create
```
cognee/modules/ontology/MIGRATION_GUIDE.md
cognee/modules/ontology/examples/
```

---
### Story 10: Testing Following Cognee Patterns
**Story Points:** 2
**Assignee:** QA Engineer / Backend Developer
**Dependencies:** Story 9

#### Acceptance Criteria
- [ ] Create tests following the existing test structure in `cognee/tests/` (see the sketch below)
- [ ] Use existing test patterns and fixtures
- [ ] Test integration with the existing pipeline system
- [ ] Verify backward compatibility
- [ ] Performance testing following existing benchmarks

#### Files to Create
```
cognee/tests/unit/ontology/
cognee/tests/integration/ontology/
```
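
For orientation, a minimal unit test for the Story 1 models could look like the following; the file name is illustrative and assumes the `models/` layout planned above, and real tests should reuse cognee's existing fixtures:

```python
# cognee/tests/unit/ontology/test_models.py - illustrative sketch
from cognee.modules.ontology.models import OntologyNode

def test_ontology_node_defaults():
    # Defaults taken from the OntologyNode model pattern defined in this epic
    node = OntologyNode(id="disease_001", name="Influenza", type="Disease")
    assert node.description is None
    assert node.properties == {}
```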
---

## 🔧 Technical Requirements (Cognee-Aligned)

### Follow Existing Patterns
- **Configuration**: Use `BaseSettings` with `@lru_cache` pattern
- **Models**: Pydantic models with proper inheritance and validation
- **Methods**: Async functions with consistent error handling
- **Operations**: Business logic separation following existing operations
- **Exports**: Proper `__init__.py` files with `__all__` exports
- **Logging**: Use `get_logger()` pattern consistently
- **Error Handling**: Follow existing exception patterns

### Integration Requirements
- Work seamlessly with existing `Task` system
- Support existing `cognee_pipeline` operations
- Integrate with current configuration management
- Support existing database patterns
- Maintain compatibility with current DataPoint model

### Code Style Requirements
- Follow existing naming conventions (PascalCase for classes, snake_case for functions)
- Use type hints consistently like existing code
- Follow docstring patterns from existing modules
- Use existing import organization patterns
- Follow async/await patterns consistently

## 🚀 Implementation Guidelines

### File Organization (Following Cognee Patterns)
```
cognee/modules/ontology/
├── __init__.py            # Public API, convenience functions
├── config.py              # OntologyConfig with @lru_cache
├── models/                # Pydantic models
│   ├── __init__.py        # Export all models
│   ├── OntologyGraph.py
│   ├── OntologyNode.py
│   └── ...
├── methods/               # Business methods (CRUD operations)
│   ├── __init__.py        # Export all methods
│   ├── create_ontology.py
│   ├── get_ontology.py
│   └── ...
├── providers/             # Format-specific providers
├── operations/            # Core business logic
├── utils/                 # Utility functions
└── examples/              # Usage examples
```
### Code Patterns to Follow

#### Configuration Pattern
```python
# cognee/modules/ontology/config.py
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict

class OntologyConfig(BaseSettings):
    default_format: str = "json"
    enable_semantic_search: bool = False
    model_config = SettingsConfigDict(env_file=".env", extra="allow")

    def to_dict(self) -> dict:
        return {
            "default_format": self.default_format,
            "enable_semantic_search": self.enable_semantic_search,
        }

@lru_cache
def get_ontology_config():
    return OntologyConfig()
```
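
Because `OntologyConfig` extends `BaseSettings` with `env_file=".env"`, fields can be overridden from the environment without code changes. A minimal sketch, assuming pydantic-settings' default field-to-variable mapping (no env prefix configured):

```python
# .env (hypothetical contents):
#   DEFAULT_FORMAT=rdf
from cognee.modules.ontology.config import get_ontology_config

config = get_ontology_config()   # cached after the first call via @lru_cache
print(config.default_format)     # -> "rdf" when the variable above is set
```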
#### Method Pattern
```python
# cognee/modules/ontology/methods/create_ontology.py
from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.models import OntologyGraph

logger = get_logger("ontology.create")

async def create_ontology(
    ontology_data: Dict[str, Any],
    user: User,
    scope: OntologyScope = OntologyScope.USER,
) -> OntologyGraph:
    """Create an ontology following existing method patterns."""
    try:
        # Implementation following existing patterns
        logger.info(f"Creating ontology for user: {user.id}")
        # ...
    except Exception as e:
        logger.error(f"Failed to create ontology: {e}")
        raise
```

#### Model Pattern
```python
# cognee/modules/ontology/models/OntologyNode.py
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any

class OntologyNode(BaseModel):
    """Ontology node following Cognee model patterns."""

    id: str = Field(..., description="Unique identifier")
    name: str
    type: str
    description: Optional[str] = None
    properties: Dict[str, Any] = Field(default_factory=dict)

    class Config:
        """Follow existing model config patterns."""
        arbitrary_types_allowed = True
```

### Integration with Existing Systems

#### Task Integration
```python
# Follow existing Task parameter patterns
async def extract_graph_from_data_ontology_aware(
    data_chunks: List[DocumentChunk],
    graph_model: Type[BaseModel] = KnowledgeGraph,
    ontology_config: OntologyConfig = None,
    **kwargs,
) -> List[DocumentChunk]:
    """Enhanced task following existing task patterns."""
    config = ontology_config or get_ontology_config()
    # Follow existing task implementation patterns
    ...
```

#### Pipeline Integration
```python
# Follow the cognee_pipeline pattern for integration
async def cognee_pipeline_with_ontology(
    tasks: list[Task],
    ontology_context: OntologyContext = None,
    **kwargs,
):
    """Enhanced pipeline following existing pipeline patterns."""
    # Inject ontology context following existing parameter injection
    enhanced_tasks = []
    for task in tasks:
        if ontology_context:
            # Enhance the task following existing enhancement patterns
            enhanced_task = enhance_task_with_ontology(task, ontology_context)
            enhanced_tasks.append(enhanced_task)
        else:
            enhanced_tasks.append(task)

    # Use existing pipeline execution
    return await cognee_pipeline(enhanced_tasks, **kwargs)
```
## 📊 Success Metrics (Aligned with Cognee Standards)

### Code Quality Metrics
- [ ] Follow all existing linting and code style rules
- [ ] Pass all existing code quality checks
- [ ] Maintain or improve test coverage percentage
- [ ] Follow existing documentation standards

### Integration Metrics
- [ ] Zero breaking changes to existing API
- [ ] All existing tests continue to pass
- [ ] Performance meets or exceeds existing benchmarks
- [ ] Memory usage within existing parameters

### Pattern Compliance
- [ ] Configuration follows `BaseSettings` + `@lru_cache` pattern
- [ ] Models follow existing Pydantic patterns
- [ ] Methods follow existing async function patterns
- [ ] Exports follow existing `__init__.py` patterns
- [ ] Error handling follows existing exception patterns

## 🔗 Related Files to Study

### Configuration Patterns
- `cognee/base_config.py`
- `cognee/modules/cognify/config.py`
- `cognee/infrastructure/llm/config.py`

### Model Patterns
- `cognee/modules/data/models/`
- `cognee/modules/users/models/`
- `cognee/infrastructure/engine/models/DataPoint.py`

### Method Patterns
- `cognee/modules/data/methods/`
- `cognee/modules/users/methods/`

### Operation Patterns
- `cognee/modules/pipelines/operations/`
- `cognee/modules/search/methods/`

### Module Organization
- `cognee/modules/pipelines/__init__.py`
- `cognee/modules/data/__init__.py`
- `cognee/modules/users/__init__.py`

---

**Estimated Total Effort:** 21 Story Points (~4-5 Sprints)
**Target Completion:** End of Q2 2024
**Review Required:** Architecture Review, Code Standards Review, Integration Review

**Key Success Factor:** Strict adherence to existing Cognee patterns and conventions to ensure seamless integration and maintainability.
cognee/modules/ontology/MIGRATION_GUIDE.md (new file, 367 lines)
@@ -0,0 +1,367 @@
# Ontology System Migration Guide

This guide explains how to migrate from the old ontology system to the new refactored architecture.

## Overview of Changes

The ontology system has been completely refactored to provide:

1. **Better Separation of Concerns**: Clear interfaces for different components
2. **DataPoint Integration**: Automatic mapping between ontologies and DataPoint instances
3. **Custom Graph Binding**: Configurable binding strategies for different graph types
4. **Pipeline Integration**: Seamless injection into pipeline tasks
5. **Domain Configuration**: Pre-configured setups for common domains
6. **Extensibility**: Plugin system for custom behavior

## Architecture Changes

### Old System
```
OntologyResolver (monolithic)
├── RDF/OWL parsing
├── Basic node/edge extraction
└── Simple graph operations
```

### New System
```
OntologyManager (orchestrator)
├── OntologyRegistry (storage/lookup)
├── OntologyProviders (format-specific loading)
├── OntologyAdapters (query/search operations)
├── DataPointResolver (ontology ↔ DataPoint mapping)
└── GraphBinder (ontology → graph structure binding)
```
## Migration Steps

### 1. Replace OntologyResolver Usage

**Old Code:**
```python
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver

# Old way
ontology_resolver = OntologyResolver(ontology_file="medical.owl")
nodes, edges, root = ontology_resolver.get_subgraph("Disease", "classes")
```

**New Code:**
```python
from cognee.modules.ontology import (
    create_ontology_system,
    OntologyContext,
    configure_domain,
    DataPointMapping,
    GraphBindingConfig,
)

# New way
ontology_manager = await create_ontology_system()

# Configure a domain if needed
mappings = [
    DataPointMapping(
        ontology_node_type="Disease",
        datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
        field_mappings={"name": "name", "description": "description"},
    )
]
binding = GraphBindingConfig(
    node_type_mapping={"Disease": "medical_condition"}
)
configure_domain("medical", mappings, binding)

# Use with a context
context = OntologyContext(domain="medical")
ontologies = await ontology_manager.get_applicable_ontologies(context)
```

### 2. Update Pipeline Task Integration

**Old Code:**
```python
from cognee.api.v1.cognify.cognify import get_default_tasks

# Old way - hardcoded ontology adapter
tasks = await get_default_tasks(
    ontology_file_path="path/to/ontology.owl"
)
```

**New Code:**
```python
from cognee.modules.ontology import create_pipeline_injector
from cognee.modules.pipelines.tasks.task import Task

# New way - configurable ontology injection
ontology_manager = await create_ontology_system()
injector = create_pipeline_injector(ontology_manager, "my_pipeline", "medical")

# Inject into tasks
original_task = Task(extract_graph_from_data)
context = OntologyContext(domain="medical", pipeline_name="my_pipeline")
enhanced_task = await injector.inject_into_task(original_task, context)
```
### 3. Update DataPoint Creation

**Old Code:**
```python
# Old way - manual DataPoint creation
from cognee.infrastructure.engine.models.DataPoint import DataPoint

datapoint = DataPoint(
    id="disease_001",
    type="Disease",
    # Manual field mapping...
)
```

**New Code:**
```python
# New way - automatic resolution from the ontology
ontology_nodes = [...]  # Nodes from the ontology
context = OntologyContext(domain="medical")

# Automatically resolve to DataPoints
datapoints = await ontology_manager.resolve_to_datapoints(ontology_nodes, context)
```

### 4. Update Graph Binding

**Old Code:**
```python
# Old way - hardcoded graph node creation
graph_nodes = []
for ontology_node in nodes:
    graph_node = (
        ontology_node.id,
        {
            "name": ontology_node.name,
            "type": ontology_node.type,
            # Manual property mapping...
        },
    )
    graph_nodes.append(graph_node)
```

**New Code:**
```python
# New way - configurable binding
ontology = ...  # a loaded ontology
context = OntologyContext(domain="medical")

# Automatically bind to a graph structure
graph_nodes, graph_edges = await ontology_manager.bind_to_graph(ontology, context)
```
## Domain-Specific Configurations

### Medical Domain

**Old Code:**
```python
# Old way - manual setup for each pipeline
ontology_resolver = OntologyResolver("medical_ontology.owl")
# ... manual entity extraction
# ... manual graph binding
```

**New Code:**
```python
# New way - use the pre-configured medical domain
from cognee.modules.ontology import create_medical_pipeline_config

ontology_manager = await create_ontology_system()
injector = create_pipeline_injector(ontology_manager, "medical_pipeline", "medical")

# Automatically configured for:
# - Disease, Symptom, Treatment entities
# - Medical-specific DataPoint mappings
# - Clinical graph relationships
```

### Legal Domain

```python
# Pre-configured for legal documents
injector = create_pipeline_injector(ontology_manager, "legal_pipeline", "legal")

# Automatically handles:
# - Law, Case, Court entities
# - Legal citation relationships
# - Jurisdiction-aware processing
```

### Code Analysis Domain

```python
# Pre-configured for code analysis
injector = create_pipeline_injector(ontology_manager, "code_pipeline", "code")

# Automatically handles:
# - Function, Class, Module entities
# - Code dependency relationships
# - Language-specific processing
```
## Custom Resolvers and Binding

### Custom DataPoint Resolver

```python
# Define a custom resolver for special entity types
async def custom_medical_resolver(ontology_node, mapping_config, context=None):
    # Custom logic for creating DataPoints
    datapoint = DataPoint(
        id=ontology_node.id,
        type="medical_entity",
        # Custom field mapping and validation
    )
    return datapoint

# Register the resolver
ontology_manager.register_custom_resolver("medical_resolver", custom_medical_resolver)

# Use it in a mapping configuration
mapping = DataPointMapping(
    ontology_node_type="SpecialDisease",
    datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
    custom_resolver="medical_resolver",
)
```

### Custom Graph Binding

```python
# Define a custom binding strategy
async def custom_graph_binding(ontology, binding_config, context=None):
    # Custom logic for graph structure creation
    graph_nodes = []
    graph_edges = []

    for node in ontology.nodes:
        # Custom node transformation
        transformed_node = transform_node(node)
        graph_nodes.append(transformed_node)

    return graph_nodes, graph_edges

# Register the strategy
ontology_manager.register_binding_strategy("custom_binding", custom_graph_binding)

# Use it in a binding configuration
binding = GraphBindingConfig(
    custom_binding_strategy="custom_binding"
)
```
## Configuration Files

### Create Configuration File

```python
from cognee.modules.ontology import create_example_config_file

# Generate an example configuration
create_example_config_file("ontology_config.json")
```

### Load Configuration

```python
from cognee.modules.ontology import load_ontology_config

# Load from a file
load_ontology_config("ontology_config.json")

# Or create the system with a config
ontology_manager = await create_ontology_system(config_file="ontology_config.json")
```

## Backward Compatibility

The old `OntologyResolver` is still available for backward compatibility:

```python
from cognee.modules.ontology import OntologyResolver

# The old interface still works
resolver = OntologyResolver(ontology_file="medical.owl")
```

However, it's recommended to migrate to the new system for better flexibility and features.
## Performance Improvements

The new system provides several performance benefits (the first two are sketched below):

1. **Lazy Loading**: Ontologies are loaded only when needed
2. **Caching**: The registry caches frequently accessed ontologies
3. **Parallel Processing**: Multiple ontologies can be processed simultaneously
4. **Semantic Search**: Optional semantic similarity for better entity matching
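
To make points 1 and 2 concrete, the combination of lazy loading and caching boils down to something like this minimal sketch; the class and method names here are illustrative, not the actual `OntologyRegistry` API:

```python
class CachingRegistry:
    """Illustrative sketch of lazy loading plus caching; not the real API."""

    def __init__(self, provider):
        self._provider = provider
        self._cache = {}

    async def get(self, ontology_id: str, source: str):
        # Lazy loading: parse the ontology only on first access...
        if ontology_id not in self._cache:
            self._cache[ontology_id] = await self._provider.load_ontology(source)
        # ...caching: every later access is served from memory.
        return self._cache[ontology_id]
```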
## Testing Your Migration

Use the example usage script to test your migration:

```python
from cognee.modules.ontology.example_usage import main

# Run comprehensive examples
await main()
```

## Common Migration Issues

### 1. Import Errors

**Problem**: `ImportError: cannot import name 'OntologyResolver'`

**Solution**: Update imports to use the new interfaces:
```python
# Old
from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver

# New
from cognee.modules.ontology import create_ontology_system, OntologyContext
```

### 2. Task Parameter Changes

**Problem**: Tasks expecting an `ontology_adapter` parameter

**Solution**: Use ontology injection:
```python
# Old
Task(extract_graph_from_data, ontology_adapter=resolver)

# New
injector = create_pipeline_injector(ontology_manager, "pipeline", "domain")
enhanced_task = await injector.inject_into_task(original_task, context)
```

### 3. DataPoint Creation Changes

**Problem**: Manual DataPoint creation with ontology data

**Solution**: Use automatic resolution:
```python
# Old
datapoint = DataPoint(id=node.id, type=node.type, ...)

# New
datapoints = await ontology_manager.resolve_to_datapoints(nodes, context)
```

## Support

For additional support with migration:

1. Check the example usage file: `cognee/modules/ontology/example_usage.py`
2. Review the interface documentation in `cognee/modules/ontology/interfaces.py`
3. Use the pre-configured domain setups for common use cases
4. Test with the provided configuration examples

The new system is designed to be more powerful while maintaining ease of use. Most migrations can be completed by updating imports and using the convenience functions provided.
cognee/modules/ontology/__init__.py (new file, 196 lines)
@@ -0,0 +1,196 @@
"""
|
||||
Ontology module for Cognee.
|
||||
|
||||
Provides ontology management capabilities including loading, processing,
|
||||
and integration with pipelines following Cognee's architectural patterns.
|
||||
"""
|
||||
|
||||
# Configuration (following Cognee pattern)
|
||||
from .config import get_ontology_config
|
||||
|
||||
# Core data models
|
||||
from .interfaces import (
|
||||
OntologyNode,
|
||||
OntologyEdge,
|
||||
OntologyGraph,
|
||||
OntologyContext,
|
||||
DataPointMapping,
|
||||
GraphBindingConfig,
|
||||
OntologyFormat,
|
||||
OntologyScope,
|
||||
)
|
||||
|
||||
# Core implementations
|
||||
from .manager import OntologyManager, create_ontology_manager
|
||||
from .registry import OntologyRegistry
|
||||
from .providers import (
|
||||
RDFOntologyProvider,
|
||||
JSONOntologyProvider,
|
||||
CSVOntologyProvider,
|
||||
)
|
||||
|
||||
# Methods (following Cognee pattern)
|
||||
from .methods import (
|
||||
create_ontology,
|
||||
get_ontology,
|
||||
load_ontology,
|
||||
register_ontology,
|
||||
delete_ontology,
|
||||
)
|
||||
|
||||
# Legacy compatibility
|
||||
from .rdf_xml.OntologyResolver import OntologyResolver
|
||||
|
||||
|
||||
# Convenience functions for quick setup
async def create_ontology_system(
    config_file: str = None,
    use_database_registry: bool = False,
    enable_semantic_search: bool = False,
) -> OntologyManager:
    """
    Create a fully configured ontology system.

    Args:
        config_file: Optional configuration file to load
        use_database_registry: Whether to use a database-backed registry
        enable_semantic_search: Whether to enable semantic search capabilities

    Returns:
        Configured OntologyManager instance
    """
    # Create the registry
    if use_database_registry:
        registry = DatabaseOntologyRegistry()
    else:
        registry = OntologyRegistry()

    # Create the providers
    providers = {
        "json_provider": JSONOntologyProvider(),
        "csv_provider": CSVOntologyProvider(),
    }

    # Add the RDF provider if available
    rdf_provider = RDFOntologyProvider()
    if rdf_provider.available:
        providers["rdf_provider"] = rdf_provider

    # Create the adapters
    adapters = {
        "default_adapter": DefaultOntologyAdapter(),
        "graph_adapter": GraphOntologyAdapter(),
    }

    # Add the semantic adapter if requested and available
    if enable_semantic_search:
        semantic_adapter = SemanticOntologyAdapter()
        if semantic_adapter.embeddings_available:
            adapters["semantic_adapter"] = semantic_adapter

    # Create the resolver and binder
    datapoint_resolver = DefaultDataPointResolver()
    graph_binder = DefaultGraphBinder()

    # Create the manager
    manager = await create_ontology_manager(
        registry=registry,
        providers=providers,
        adapters=adapters,
        datapoint_resolver=datapoint_resolver,
        graph_binder=graph_binder,
    )

    # Load configuration if provided
    if config_file:
        config = get_ontology_config()
        config.load_from_file(config_file)

        # Apply configurations to the manager
        for domain, domain_config in config.domain_configs.items():
            manager.configure_datapoint_mapping(
                domain, domain_config["datapoint_mappings"]
            )
            manager.configure_graph_binding(
                domain, domain_config["graph_binding"]
            )

    return manager
def create_pipeline_injector(
    ontology_manager: OntologyManager,
    pipeline_name: str,
    domain: str = None,
) -> OntologyInjector:
    """
    Create an ontology injector for a specific pipeline.

    Args:
        ontology_manager: The ontology manager instance
        pipeline_name: Name of the pipeline
        domain: Domain for the pipeline (optional)

    Returns:
        Configured OntologyInjector
    """
    configurator = PipelineOntologyConfigurator(ontology_manager)

    # Use a pre-configured domain setup if available
    if domain in ["medical", "legal", "code"]:
        if domain == "medical":
            config = create_medical_pipeline_config()
        elif domain == "legal":
            config = create_legal_pipeline_config()
        elif domain == "code":
            config = create_code_pipeline_config()

        configurator.configure_pipeline(
            pipeline_name=pipeline_name,
            domain=config["domain"],
            datapoint_mappings=config["datapoint_mappings"],
            graph_binding=config["graph_binding"],
            task_specific_configs=config["task_configs"],
        )

    return configurator.create_ontology_injector(pipeline_name)
# Export following Cognee pattern
__all__ = [
    # Configuration
    "get_ontology_config",

    # Core classes
    "OntologyManager",
    "OntologyRegistry",

    # Data models
    "OntologyNode",
    "OntologyEdge",
    "OntologyGraph",
    "OntologyContext",
    "DataPointMapping",
    "GraphBindingConfig",
    "OntologyFormat",
    "OntologyScope",

    # Providers
    "RDFOntologyProvider",
    "JSONOntologyProvider",
    "CSVOntologyProvider",

    # Methods
    "create_ontology",
    "get_ontology",
    "load_ontology",
    "register_ontology",
    "delete_ontology",

    # Convenience functions
    "create_ontology_system",
    "create_pipeline_injector",

    # Legacy compatibility
    "OntologyResolver",
]
cognee/modules/ontology/adapters.py (new file, 397 lines)
@@ -0,0 +1,397 @@
"""Ontology adapter implementations."""
|
||||
|
||||
import difflib
|
||||
from typing import List, Tuple, Optional
|
||||
from collections import deque
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.ontology.interfaces import (
|
||||
IOntologyAdapter,
|
||||
OntologyGraph,
|
||||
OntologyNode,
|
||||
OntologyEdge,
|
||||
)
|
||||
|
||||
logger = get_logger("OntologyAdapters")
|
||||
|
||||
|
||||
class DefaultOntologyAdapter(IOntologyAdapter):
    """Default implementation of the ontology adapter."""

    async def find_matching_nodes(
        self,
        query_text: str,
        ontology: OntologyGraph,
        similarity_threshold: float = 0.8,
    ) -> List[OntologyNode]:
        """Find nodes matching the query text using simple string similarity."""

        matching_nodes = []
        query_lower = query_text.lower()

        for node in ontology.nodes:
            # Check name similarity
            name_similarity = self._calculate_similarity(query_lower, node.name.lower())

            # Check description similarity
            desc_similarity = 0.0
            if node.description:
                desc_similarity = self._calculate_similarity(query_lower, node.description.lower())

            # Check properties similarity
            props_similarity = 0.0
            for prop_value in node.properties.values():
                if isinstance(prop_value, str):
                    prop_sim = self._calculate_similarity(query_lower, prop_value.lower())
                    props_similarity = max(props_similarity, prop_sim)

            # Take the maximum similarity
            max_similarity = max(name_similarity, desc_similarity, props_similarity)

            if max_similarity >= similarity_threshold:
                # Add the similarity score to node properties for ranking
                node_copy = OntologyNode(**node.dict())
                node_copy.properties["_similarity_score"] = max_similarity
                matching_nodes.append(node_copy)

        # Sort by similarity score
        matching_nodes.sort(key=lambda n: n.properties.get("_similarity_score", 0), reverse=True)

        logger.debug(f"Found {len(matching_nodes)} nodes matching '{query_text}'")
        return matching_nodes

    async def get_node_relationships(
        self,
        node_id: str,
        ontology: OntologyGraph,
        max_depth: int = 2,
    ) -> List[OntologyEdge]:
        """Get relationships for a specific node."""

        relationships = []
        visited = set()
        queue = deque([(node_id, 0)])  # (node_id, depth)

        while queue:
            current_id, depth = queue.popleft()

            if current_id in visited or depth > max_depth:
                continue

            visited.add(current_id)

            # Find edges where this node is the source or the target
            for edge in ontology.edges:
                if edge.source_id == current_id:
                    relationships.append(edge)
                    if depth < max_depth:
                        queue.append((edge.target_id, depth + 1))

                elif edge.target_id == current_id:
                    relationships.append(edge)
                    if depth < max_depth:
                        queue.append((edge.source_id, depth + 1))

        # Remove duplicates
        unique_relationships = []
        seen_edges = set()
        for rel in relationships:
            edge_key = (rel.source_id, rel.target_id, rel.relationship_type)
            if edge_key not in seen_edges:
                seen_edges.add(edge_key)
                unique_relationships.append(rel)

        logger.debug(f"Found {len(unique_relationships)} relationships for node {node_id}")
        return unique_relationships

    async def expand_subgraph(
        self,
        node_ids: List[str],
        ontology: OntologyGraph,
        directed: bool = True,
    ) -> Tuple[List[OntologyNode], List[OntologyEdge]]:
        """Expand the subgraph around the given nodes."""

        subgraph_nodes = []
        subgraph_edges = []

        # Map all nodes by id and track which ones are included
        node_map = {node.id: node for node in ontology.nodes}
        included_node_ids = set(node_ids)

        # Add the initial nodes
        for node_id in node_ids:
            if node_id in node_map:
                subgraph_nodes.append(node_map[node_id])

        # Find connected edges and nodes
        for edge in ontology.edges:
            include_edge = False

            if directed:
                # Include the edge if its source is in our set
                if edge.source_id in included_node_ids:
                    include_edge = True
                    # Add the target node if not already included
                    if edge.target_id not in included_node_ids and edge.target_id in node_map:
                        subgraph_nodes.append(node_map[edge.target_id])
                        included_node_ids.add(edge.target_id)
            else:
                # Include the edge if either its source or its target is in our set
                if edge.source_id in included_node_ids or edge.target_id in included_node_ids:
                    include_edge = True
                    # Add both nodes if not already included
                    for node_id in [edge.source_id, edge.target_id]:
                        if node_id not in included_node_ids and node_id in node_map:
                            subgraph_nodes.append(node_map[node_id])
                            included_node_ids.add(node_id)

            if include_edge:
                subgraph_edges.append(edge)

        logger.debug(f"Expanded subgraph with {len(subgraph_nodes)} nodes and {len(subgraph_edges)} edges")
        return subgraph_nodes, subgraph_edges

    async def merge_ontologies(
        self,
        ontologies: List[OntologyGraph],
    ) -> OntologyGraph:
        """Merge multiple ontologies."""

        if not ontologies:
            raise ValueError("No ontologies to merge")

        if len(ontologies) == 1:
            return ontologies[0]

        # Create the merged ontology
        merged_nodes = []
        merged_edges = []
        merged_metadata = {}

        # Track node ids and edge keys to avoid duplicates
        seen_node_ids = set()
        seen_edge_keys = set()

        # Merge nodes
        for ontology in ontologies:
            for node in ontology.nodes:
                if node.id not in seen_node_ids:
                    merged_nodes.append(node)
                    seen_node_ids.add(node.id)
                else:
                    # Handle duplicate nodes by merging properties
                    existing_node = next(n for n in merged_nodes if n.id == node.id)
                    existing_node.properties.update(node.properties)

        # Merge edges
        for ontology in ontologies:
            for edge in ontology.edges:
                edge_key = (edge.source_id, edge.target_id, edge.relationship_type)
                if edge_key not in seen_edge_keys:
                    merged_edges.append(edge)
                    seen_edge_keys.add(edge_key)

        # Merge metadata
        for ontology in ontologies:
            merged_metadata.update(ontology.metadata)

        merged_metadata["merged_from"] = [ont.id for ont in ontologies]
        from datetime import datetime
        merged_metadata["merge_timestamp"] = datetime.now().isoformat()

        merged_ontology = OntologyGraph(
            id=f"merged_{'_'.join([ont.id for ont in ontologies[:3]])}",
            name=f"Merged Ontology ({len(ontologies)} sources)",
            description=f"Merged from: {', '.join([ont.name for ont in ontologies])}",
            format=ontologies[0].format,  # Use the format of the first ontology
            scope=ontologies[0].scope,  # Use the scope of the first ontology
            nodes=merged_nodes,
            edges=merged_edges,
            metadata=merged_metadata,
        )

        logger.info(
            f"Merged {len(ontologies)} ontologies into one with "
            f"{len(merged_nodes)} nodes and {len(merged_edges)} edges"
        )
        return merged_ontology

    def _calculate_similarity(self, text1: str, text2: str) -> float:
        """Calculate similarity between two text strings."""
        if not text1 or not text2:
            return 0.0

        # Use difflib for sequence similarity
        return difflib.SequenceMatcher(None, text1, text2).ratio()
class SemanticOntologyAdapter(DefaultOntologyAdapter):
    """Semantic-aware ontology adapter using embeddings (if available)."""

    def __init__(self):
        super().__init__()
        self.embeddings_available = False
        try:
            # Try to import embedding functionality
            from cognee.infrastructure.llm import get_embedding_engine
            self.get_embedding_engine = get_embedding_engine
            self.embeddings_available = True
        except ImportError:
            logger.warning("Embedding engine not available, falling back to string similarity")

    async def find_matching_nodes(
        self,
        query_text: str,
        ontology: OntologyGraph,
        similarity_threshold: float = 0.8,
    ) -> List[OntologyNode]:
        """Find nodes using semantic similarity if embeddings are available."""

        if not self.embeddings_available:
            return await super().find_matching_nodes(query_text, ontology, similarity_threshold)

        try:
            # Get the embedding for the query
            embedding_engine = await self.get_embedding_engine()
            query_embedding = await embedding_engine.embed_text(query_text)

            matching_nodes = []

            for node in ontology.nodes:
                # Build the node text for embedding
                node_text = f"{node.name} {node.description or ''}"
                for prop_value in node.properties.values():
                    if isinstance(prop_value, str):
                        node_text += f" {prop_value}"

                # Get the node embedding
                node_embedding = await embedding_engine.embed_text(node_text)

                # Calculate cosine similarity
                similarity = self._cosine_similarity(query_embedding, node_embedding)

                if similarity >= similarity_threshold:
                    node_copy = OntologyNode(**node.dict())
                    node_copy.properties["_similarity_score"] = similarity
                    matching_nodes.append(node_copy)

            # Sort by similarity
            matching_nodes.sort(key=lambda n: n.properties.get("_similarity_score", 0), reverse=True)

            logger.debug(f"Found {len(matching_nodes)} nodes using semantic similarity")
            return matching_nodes

        except Exception as e:
            logger.warning(f"Semantic similarity failed, falling back to string matching: {e}")
            return await super().find_matching_nodes(query_text, ontology, similarity_threshold)

    def _cosine_similarity(self, vec1, vec2) -> float:
        """Calculate cosine similarity between two vectors."""
        import numpy as np

        vec1 = np.array(vec1)
        vec2 = np.array(vec2)

        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)

        if norm1 == 0 or norm2 == 0:
            return 0.0

        return dot_product / (norm1 * norm2)
class GraphOntologyAdapter(DefaultOntologyAdapter):
    """Adapter specialized for graph-based operations."""

    async def get_node_relationships(
        self,
        node_id: str,
        ontology: OntologyGraph,
        max_depth: int = 2,
    ) -> List[OntologyEdge]:
        """Enhanced relationship discovery with graph algorithms."""

        # Build adjacency lists for faster traversal
        outgoing = {}
        incoming = {}

        for edge in ontology.edges:
            if edge.source_id not in outgoing:
                outgoing[edge.source_id] = []
            outgoing[edge.source_id].append(edge)

            if edge.target_id not in incoming:
                incoming[edge.target_id] = []
            incoming[edge.target_id].append(edge)

        # BFS traversal
        relationships = []
        visited = set()
        queue = deque([(node_id, 0)])

        while queue:
            current_id, depth = queue.popleft()

            if current_id in visited or depth > max_depth:
                continue

            visited.add(current_id)

            # Add outgoing edges
            for edge in outgoing.get(current_id, []):
                relationships.append(edge)
                if depth < max_depth:
                    queue.append((edge.target_id, depth + 1))

            # Add incoming edges
            for edge in incoming.get(current_id, []):
                relationships.append(edge)
                if depth < max_depth:
                    queue.append((edge.source_id, depth + 1))

        # Remove duplicates and sort by relevance
        unique_relationships = list({edge.id: edge for edge in relationships}.values())

        # Sort by edge weight if available, then by relationship type
        unique_relationships.sort(
            key=lambda e: (e.weight or 0, e.relationship_type),
            reverse=True,
        )

        return unique_relationships

    async def find_shortest_path(
        self,
        source_id: str,
        target_id: str,
        ontology: OntologyGraph,
    ) -> List[OntologyEdge]:
        """Find the shortest path between two nodes."""

        # Build the graph as adjacency lists
        graph = {}
        for edge in ontology.edges:
            if edge.source_id not in graph:
                graph[edge.source_id] = []
            graph[edge.source_id].append((edge.target_id, edge))

        # BFS for the shortest path
        queue = deque([(source_id, [])])
        visited = set()

        while queue:
            current_id, path = queue.popleft()

            if current_id == target_id:
                return path

            if current_id in visited:
                continue

            visited.add(current_id)

            for neighbor_id, edge in graph.get(current_id, []):
                if neighbor_id not in visited:
                    queue.append((neighbor_id, path + [edge]))

        return []  # No path found
cognee/modules/ontology/binders.py (new file, 438 lines)
@@ -0,0 +1,438 @@
"""Graph binding implementations for ontologies."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Callable
|
||||
from uuid import uuid4
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.ontology.interfaces import (
|
||||
IGraphBinder,
|
||||
OntologyGraph,
|
||||
OntologyNode,
|
||||
OntologyEdge,
|
||||
GraphBindingConfig,
|
||||
OntologyContext,
|
||||
)
|
||||
|
||||
logger = get_logger("GraphBinder")
|
||||
|
||||
|
||||
class DefaultGraphBinder(IGraphBinder):
    """Default implementation for binding an ontology to graph structures."""

    def __init__(self):
        self.custom_strategies: Dict[str, Callable] = {}

    async def bind_ontology_to_graph(
        self,
        ontology: OntologyGraph,
        binding_config: GraphBindingConfig,
        context: Optional[OntologyContext] = None,
    ) -> Tuple[List[Any], List[Any]]:  # (graph_nodes, graph_edges)
        """Bind an ontology to a graph structure."""

        # Use a custom binding strategy if specified
        if (
            binding_config.custom_binding_strategy
            and binding_config.custom_binding_strategy in self.custom_strategies
        ):
            return await self._apply_custom_strategy(ontology, binding_config, context)

        # Use the default binding logic
        return await self._default_binding(ontology, binding_config, context)

    async def transform_node_properties(
        self,
        node: OntologyNode,
        transformations: Dict[str, Callable[[Any], Any]],
    ) -> Dict[str, Any]:
        """Transform node properties according to the binding config."""

        transformed_props = {}

        # Start with the node's base properties
        base_props = {
            "id": node.id,
            "name": node.name,
            "type": node.type,
            "description": node.description,
            "category": node.category,
        }

        # Add custom properties
        base_props.update(node.properties)

        # Apply transformations
        for prop_name, prop_value in base_props.items():
            if prop_name in transformations:
                try:
                    transformed_props[prop_name] = transformations[prop_name](prop_value)
                except Exception as e:
                    logger.warning(f"Transformation failed for property {prop_name}: {e}")
                    transformed_props[prop_name] = prop_value
            else:
                transformed_props[prop_name] = prop_value

        # Add standard graph properties
        transformed_props.update({
            "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
            "ontology_source": True,
            "ontology_id": node.id,
        })

        return transformed_props

    async def transform_edge_properties(
        self,
        edge: OntologyEdge,
        transformations: Dict[str, Callable[[Any], Any]],
    ) -> Dict[str, Any]:
        """Transform edge properties according to the binding config."""

        transformed_props = {}

        # Start with the edge's base properties
        base_props = {
            "source_node_id": edge.source_id,
            "target_node_id": edge.target_id,
            "relationship_name": edge.relationship_type,
            "weight": edge.weight,
        }

        # Add custom properties
        base_props.update(edge.properties)

        # Apply transformations
        for prop_name, prop_value in base_props.items():
            if prop_name in transformations:
                try:
                    transformed_props[prop_name] = transformations[prop_name](prop_value)
                except Exception as e:
                    logger.warning(f"Transformation failed for edge property {prop_name}: {e}")
                    transformed_props[prop_name] = prop_value
            else:
                transformed_props[prop_name] = prop_value

        # Add standard graph properties
        transformed_props.update({
            "updated_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
            "ontology_source": True,
            "ontology_edge_id": edge.id,
        })

        return transformed_props
def register_binding_strategy(
|
||||
self,
|
||||
strategy_name: str,
|
||||
strategy_func: Callable[[OntologyGraph, GraphBindingConfig], Tuple[List[Any], List[Any]]]
|
||||
) -> None:
|
||||
"""Register a custom binding strategy."""
|
||||
self.custom_strategies[strategy_name] = strategy_func
|
||||
logger.info(f"Registered custom binding strategy: {strategy_name}")
|
||||
|
||||
async def _apply_custom_strategy(
|
||||
self,
|
||||
ontology: OntologyGraph,
|
||||
binding_config: GraphBindingConfig,
|
||||
context: Optional[OntologyContext] = None
|
||||
) -> Tuple[List[Any], List[Any]]:
|
||||
"""Apply custom binding strategy."""
|
||||
|
||||
strategy_func = self.custom_strategies[binding_config.custom_binding_strategy]
|
||||
|
||||
try:
|
||||
if context:
|
||||
return await strategy_func(ontology, binding_config, context)
|
||||
else:
|
||||
return await strategy_func(ontology, binding_config)
|
||||
except Exception as e:
|
||||
logger.error(f"Custom binding strategy failed: {e}")
|
||||
# Fallback to default binding
|
||||
return await self._default_binding(ontology, binding_config, context)
|
||||
|
||||
async def _default_binding(
|
||||
self,
|
||||
ontology: OntologyGraph,
|
||||
binding_config: GraphBindingConfig,
|
||||
context: Optional[OntologyContext] = None
|
||||
) -> Tuple[List[Any], List[Any]]:
|
||||
"""Default binding logic."""
|
||||
|
||||
graph_nodes = []
|
||||
graph_edges = []
|
||||
|
||||
# Process nodes
|
||||
for node in ontology.nodes:
|
||||
try:
|
||||
graph_node = await self._bind_node_to_graph(node, binding_config)
|
||||
if graph_node:
|
||||
graph_nodes.append(graph_node)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to bind node {node.id}: {e}")
|
||||
|
||||
# Process edges
|
||||
for edge in ontology.edges:
|
||||
try:
|
||||
graph_edge = await self._bind_edge_to_graph(edge, binding_config)
|
||||
if graph_edge:
|
||||
graph_edges.append(graph_edge)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to bind edge {edge.id}: {e}")
|
||||
|
||||
logger.info(f"Bound {len(graph_nodes)} nodes and {len(graph_edges)} edges to graph")
|
||||
return graph_nodes, graph_edges
|
||||
|
||||
async def _bind_node_to_graph(
|
||||
self,
|
||||
node: OntologyNode,
|
||||
binding_config: GraphBindingConfig
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
"""Bind a single node to graph format."""
|
||||
|
||||
# Map node type if configured
|
||||
graph_node_type = binding_config.node_type_mapping.get(node.type, node.type)
|
||||
|
||||
# Transform properties
|
||||
node_properties = await self.transform_node_properties(
|
||||
node, binding_config.property_transformations
|
||||
)
|
||||
|
||||
# Set the mapped type
|
||||
node_properties["type"] = graph_node_type
|
||||
|
||||
# Generate node ID for graph (use ontology ID as base)
|
||||
node_id = node.id
|
||||
|
||||
return (node_id, node_properties)
|
||||
|
||||
async def _bind_edge_to_graph(
|
||||
self,
|
||||
edge: OntologyEdge,
|
||||
binding_config: GraphBindingConfig
|
||||
) -> Tuple[str, str, str, Dict[str, Any]]:
|
||||
"""Bind a single edge to graph format."""
|
||||
|
||||
# Map edge type if configured
|
||||
graph_edge_type = binding_config.edge_type_mapping.get(
|
||||
edge.relationship_type, edge.relationship_type
|
||||
)
|
||||
|
||||
# Transform properties
|
||||
edge_properties = await self.transform_edge_properties(
|
||||
edge, binding_config.property_transformations
|
||||
)
|
||||
|
||||
# Set the mapped relationship name
|
||||
edge_properties["relationship_name"] = graph_edge_type
|
||||
|
||||
return (
|
||||
edge.source_id,
|
||||
edge.target_id,
|
||||
graph_edge_type,
|
||||
edge_properties
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeGraphBinder(DefaultGraphBinder):
|
||||
"""Specialized binder for KnowledgeGraph format."""
|
||||
|
||||
async def _bind_node_to_graph(
|
||||
self,
|
||||
node: OntologyNode,
|
||||
binding_config: GraphBindingConfig
|
||||
) -> Dict[str, Any]:
|
||||
"""Bind node to KnowledgeGraph Node format."""
|
||||
|
||||
# Transform properties
|
||||
node_properties = await self.transform_node_properties(
|
||||
node, binding_config.property_transformations
|
||||
)
|
||||
|
||||
# Create KnowledgeGraph-compatible node
|
||||
from cognee.shared.data_models import Node
|
||||
|
||||
kg_node = Node(
|
||||
id=node.id,
|
||||
name=node.name,
|
||||
type=binding_config.node_type_mapping.get(node.type, node.type),
|
||||
description=node.description or "",
|
||||
)
|
||||
|
||||
return kg_node
|
||||
|
||||
async def _bind_edge_to_graph(
|
||||
self,
|
||||
edge: OntologyEdge,
|
||||
binding_config: GraphBindingConfig
|
||||
) -> Dict[str, Any]:
|
||||
"""Bind edge to KnowledgeGraph Edge format."""
|
||||
|
||||
from cognee.shared.data_models import Edge
|
||||
|
||||
# Map edge type
|
||||
relationship_name = binding_config.edge_type_mapping.get(
|
||||
edge.relationship_type, edge.relationship_type
|
||||
)
|
||||
|
||||
kg_edge = Edge(
|
||||
source_node_id=edge.source_id,
|
||||
target_node_id=edge.target_id,
|
||||
relationship_name=relationship_name,
|
||||
)
|
||||
|
||||
return kg_edge
|
||||
|
||||
|
||||
class DataPointGraphBinder(DefaultGraphBinder):
|
||||
"""Specialized binder for DataPoint-based graphs."""
|
||||
|
||||
async def _bind_node_to_graph(
|
||||
self,
|
||||
node: OntologyNode,
|
||||
binding_config: GraphBindingConfig
|
||||
) -> Any: # DataPoint
|
||||
"""Bind node to DataPoint instance."""
|
||||
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
|
||||
# Transform properties
|
||||
node_properties = await self.transform_node_properties(
|
||||
node, binding_config.property_transformations
|
||||
)
|
||||
|
||||
# Create DataPoint instance
|
||||
datapoint = DataPoint(
|
||||
id=node.id,
|
||||
type=binding_config.node_type_mapping.get(node.type, node.type),
|
||||
ontology_valid=True,
|
||||
metadata={
|
||||
"type": node.type,
|
||||
"index_fields": ["name", "type"],
|
||||
"ontology_source": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Add custom attributes
|
||||
for prop_name, prop_value in node_properties.items():
|
||||
if not hasattr(datapoint, prop_name):
|
||||
setattr(datapoint, prop_name, prop_value)
|
||||
|
||||
return datapoint
|
||||
|
||||
|
||||
class DomainSpecificBinder(DefaultGraphBinder):
|
||||
"""Domain-specific graph binder with specialized transformation logic."""
|
||||
|
||||
def __init__(self, domain: str):
|
||||
super().__init__()
|
||||
self.domain = domain
|
||||
|
||||
async def transform_node_properties(
|
||||
self,
|
||||
node: OntologyNode,
|
||||
transformations: Dict[str, Callable[[Any], Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply domain-specific node transformations."""
|
||||
|
||||
# Apply parent transformations first
|
||||
props = await super().transform_node_properties(node, transformations)
|
||||
|
||||
# Apply domain-specific transformations
|
||||
if self.domain == "medical":
|
||||
props = await self._apply_medical_node_transforms(node, props)
|
||||
elif self.domain == "legal":
|
||||
props = await self._apply_legal_node_transforms(node, props)
|
||||
elif self.domain == "code":
|
||||
props = await self._apply_code_node_transforms(node, props)
|
||||
|
||||
return props
|
||||
|
||||
async def transform_edge_properties(
|
||||
self,
|
||||
edge: OntologyEdge,
|
||||
transformations: Dict[str, Callable[[Any], Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply domain-specific edge transformations."""
|
||||
|
||||
# Apply parent transformations first
|
||||
props = await super().transform_edge_properties(edge, transformations)
|
||||
|
||||
# Apply domain-specific transformations
|
||||
if self.domain == "medical":
|
||||
props = await self._apply_medical_edge_transforms(edge, props)
|
||||
elif self.domain == "legal":
|
||||
props = await self._apply_legal_edge_transforms(edge, props)
|
||||
elif self.domain == "code":
|
||||
props = await self._apply_code_edge_transforms(edge, props)
|
||||
|
||||
return props
|
||||
|
||||
async def _apply_medical_node_transforms(
|
||||
self, node: OntologyNode, props: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply medical domain node transformations."""
|
||||
if node.type == "Disease":
|
||||
props["medical_category"] = "pathology"
|
||||
props["severity_level"] = node.properties.get("severity", "unknown")
|
||||
elif node.type == "Symptom":
|
||||
props["medical_category"] = "clinical_sign"
|
||||
props["frequency"] = node.properties.get("frequency", "unknown")
|
||||
|
||||
return props
|
||||
|
||||
async def _apply_legal_node_transforms(
|
||||
self, node: OntologyNode, props: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply legal domain node transformations."""
|
||||
if node.type == "Law":
|
||||
props["legal_authority"] = node.properties.get("jurisdiction", "unknown")
|
||||
props["enforcement_level"] = node.properties.get("level", "federal")
|
||||
elif node.type == "Case":
|
||||
props["court_level"] = node.properties.get("court", "unknown")
|
||||
props["precedent_value"] = node.properties.get("binding", "persuasive")
|
||||
|
||||
return props
|
||||
|
||||
async def _apply_code_node_transforms(
|
||||
self, node: OntologyNode, props: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply code domain node transformations."""
|
||||
if node.type == "Function":
|
||||
props["complexity"] = len(node.properties.get("parameters", []))
|
||||
props["visibility"] = node.properties.get("access_modifier", "public")
|
||||
elif node.type == "Class":
|
||||
props["inheritance_depth"] = node.properties.get("depth", 0)
|
||||
props["method_count"] = len(node.properties.get("methods", []))
|
||||
|
||||
return props
|
||||
|
||||
async def _apply_medical_edge_transforms(
|
||||
self, edge: OntologyEdge, props: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply medical domain edge transformations."""
|
||||
if edge.relationship_type == "causes":
|
||||
props["causality_strength"] = edge.properties.get("strength", "unknown")
|
||||
elif edge.relationship_type == "treats":
|
||||
props["efficacy"] = edge.properties.get("effectiveness", "unknown")
|
||||
|
||||
return props
|
||||
|
||||
async def _apply_legal_edge_transforms(
|
||||
self, edge: OntologyEdge, props: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply legal domain edge transformations."""
|
||||
if edge.relationship_type == "cites":
|
||||
props["citation_type"] = edge.properties.get("type", "supporting")
|
||||
elif edge.relationship_type == "overrules":
|
||||
props["authority_level"] = edge.properties.get("level", "same")
|
||||
|
||||
return props
|
||||
|
||||
async def _apply_code_edge_transforms(
|
||||
self, edge: OntologyEdge, props: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply code domain edge transformations."""
|
||||
if edge.relationship_type == "calls":
|
||||
props["call_frequency"] = edge.properties.get("frequency", 1)
|
||||
elif edge.relationship_type == "inherits":
|
||||
props["inheritance_type"] = edge.properties.get("type", "extends")
|
||||
|
||||
return props
|
||||
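A minimal usage sketch for `DefaultGraphBinder` above. The classes come from this branch; the sample ontology and the `severity` transformation are made up for illustration:

```python
import asyncio

from cognee.modules.ontology.binders import DefaultGraphBinder
from cognee.modules.ontology.interfaces import (
    GraphBindingConfig,
    OntologyFormat,
    OntologyGraph,
    OntologyNode,
    OntologyScope,
)


async def demo():
    ontology = OntologyGraph(
        id="demo",
        name="Demo",
        format=OntologyFormat.JSON,
        scope=OntologyScope.DOMAIN,
        nodes=[OntologyNode(id="n1", name="Diabetes", type="Disease",
                            properties={"severity": "chronic"})],
        edges=[],
    )
    binding = GraphBindingConfig(
        node_type_mapping={"Disease": "medical_condition"},
        # Hypothetical transformation: upper-case the severity property
        property_transformations={"severity": lambda value: str(value).upper()},
    )
    nodes, edges = await DefaultGraphBinder().bind_ontology_to_graph(ontology, binding)
    # In the default binding, each node is a (node_id, properties) tuple
    print(nodes[0][1]["type"])      # "medical_condition"
    print(nodes[0][1]["severity"])  # "CHRONIC"


asyncio.run(demo())
```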
95  cognee/modules/ontology/config.py  Normal file
@@ -0,0 +1,95 @@
"""Ontology configuration following Cognee patterns."""

import os
from typing import Optional
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.shared.logging_utils import get_logger

logger = get_logger("ontology.config")


class OntologyConfig(BaseSettings):
    """
    Configuration settings for the ontology system.

    Follows Cognee's BaseSettings pattern with environment variable support.
    """

    # Default ontology settings
    default_format: str = "json"
    enable_semantic_search: bool = False
    registry_type: str = "memory"  # "memory" or "database"

    # Provider settings
    rdf_provider_enabled: bool = True
    json_provider_enabled: bool = True
    csv_provider_enabled: bool = True

    # Performance settings
    cache_ontologies: bool = True
    max_cache_size: int = 100
    similarity_threshold: float = 0.8

    # Domain-specific settings
    medical_domain_enabled: bool = True
    legal_domain_enabled: bool = True
    code_domain_enabled: bool = True

    # File paths
    ontology_data_directory: str = os.path.join(
        os.getenv("COGNEE_DATA_ROOT", ".data_storage"), "ontologies"
    )
    default_config_file: Optional[str] = None

    # Environment variables
    ontology_api_key: Optional[str] = os.getenv("ONTOLOGY_API_KEY")
    ontology_endpoint: Optional[str] = os.getenv("ONTOLOGY_ENDPOINT")

    model_config = SettingsConfigDict(env_file=".env", extra="allow")

    def to_dict(self) -> dict:
        """Convert configuration to dictionary following Cognee pattern."""
        return {
            "default_format": self.default_format,
            "enable_semantic_search": self.enable_semantic_search,
            "registry_type": self.registry_type,
            "rdf_provider_enabled": self.rdf_provider_enabled,
            "json_provider_enabled": self.json_provider_enabled,
            "csv_provider_enabled": self.csv_provider_enabled,
            "cache_ontologies": self.cache_ontologies,
            "max_cache_size": self.max_cache_size,
            "similarity_threshold": self.similarity_threshold,
            "medical_domain_enabled": self.medical_domain_enabled,
            "legal_domain_enabled": self.legal_domain_enabled,
            "code_domain_enabled": self.code_domain_enabled,
            "ontology_data_directory": self.ontology_data_directory,
        }


@lru_cache
def get_ontology_config():
    """Get ontology configuration instance following Cognee pattern."""
    return OntologyConfig()


# Configuration helpers following existing patterns
def set_ontology_data_directory(directory: str):
    """Set ontology data directory."""
    config = get_ontology_config()
    config.ontology_data_directory = directory
    logger.info(f"Set ontology data directory to: {directory}")


def enable_semantic_search(enabled: bool = True):
    """Enable or disable semantic search."""
    config = get_ontology_config()
    config.enable_semantic_search = enabled
    logger.info(f"Semantic search {'enabled' if enabled else 'disabled'}")


def set_similarity_threshold(threshold: float):
    """Set similarity threshold for entity matching."""
    config = get_ontology_config()
    config.similarity_threshold = threshold
    logger.info(f"Set similarity threshold to: {threshold}")
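A short sketch of how this config resolves in practice, assuming standard pydantic-settings behaviour (environment variables override field defaults and are matched case-insensitively) plus the `@lru_cache` singleton above:

```python
import os

# Hypothetical override; any OntologyConfig field can be set via the environment
os.environ["SIMILARITY_THRESHOLD"] = "0.9"

from cognee.modules.ontology.config import get_ontology_config

config = get_ontology_config()
print(config.similarity_threshold)  # 0.9, parsed from the environment
print(config.default_format)        # "json", the declared default

# Cached: repeated calls return the same instance
assert get_ontology_config() is config
```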
364  cognee/modules/ontology/configuration.py  Normal file
@@ -0,0 +1,364 @@
"""Configuration system for ontology integration."""

from typing import Dict, List, Optional, Any, Callable
from pathlib import Path
import json
import yaml

from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import (
    DataPointMapping,
    GraphBindingConfig,
    OntologyFormat,
    OntologyScope,
)
from cognee.modules.ontology.pipeline_integration import (
    create_medical_pipeline_config,
    create_legal_pipeline_config,
    create_code_pipeline_config,
)

logger = get_logger("OntologyConfiguration")


class OntologyConfiguration:
    """Central configuration management for ontology system."""

    def __init__(self):
        self.domain_configs: Dict[str, Dict[str, Any]] = {}
        self.pipeline_configs: Dict[str, Dict[str, Any]] = {}
        self.custom_resolvers: Dict[str, Callable] = {}
        self.custom_binding_strategies: Dict[str, Callable] = {}

        # Load default configurations
        self._load_default_configs()

    def _load_default_configs(self):
        """Load default domain configurations."""
        self.domain_configs.update({
            "medical": create_medical_pipeline_config(),
            "legal": create_legal_pipeline_config(),
            "code": create_code_pipeline_config(),
        })

    def register_domain_config(
        self,
        domain: str,
        datapoint_mappings: List[DataPointMapping],
        graph_binding: GraphBindingConfig,
        task_configs: Optional[Dict[str, Dict[str, Any]]] = None
    ) -> None:
        """Register configuration for a domain."""

        self.domain_configs[domain] = {
            "domain": domain,
            "datapoint_mappings": datapoint_mappings,
            "graph_binding": graph_binding,
            "task_configs": task_configs or {},
        }

        logger.info(f"Registered domain configuration: {domain}")

    def register_pipeline_config(
        self,
        pipeline_name: str,
        domain: str,
        custom_mappings: Optional[List[DataPointMapping]] = None,
        custom_binding: Optional[GraphBindingConfig] = None,
        task_configs: Optional[Dict[str, Dict[str, Any]]] = None
    ) -> None:
        """Register configuration for a specific pipeline."""

        # Start with domain config if available
        base_config = self.domain_configs.get(domain, {})

        pipeline_config = {
            "pipeline_name": pipeline_name,
            "domain": domain,
            "datapoint_mappings": custom_mappings or base_config.get("datapoint_mappings", []),
            "graph_binding": custom_binding or base_config.get("graph_binding", GraphBindingConfig()),
            "task_configs": {**base_config.get("task_configs", {}), **(task_configs or {})},
        }

        self.pipeline_configs[pipeline_name] = pipeline_config
        logger.info(f"Registered pipeline configuration: {pipeline_name}")

    def get_domain_config(self, domain: str) -> Optional[Dict[str, Any]]:
        """Get configuration for a domain."""
        return self.domain_configs.get(domain)

    def get_pipeline_config(self, pipeline_name: str) -> Optional[Dict[str, Any]]:
        """Get configuration for a pipeline."""
        return self.pipeline_configs.get(pipeline_name)

    def load_from_file(self, config_file: str) -> None:
        """Load configuration from file."""

        config_path = Path(config_file)
        if not config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_file}")

        if config_path.suffix.lower() == '.json':
            with open(config_path) as f:
                config_data = json.load(f)
        elif config_path.suffix.lower() in ['.yml', '.yaml']:
            with open(config_path) as f:
                config_data = yaml.safe_load(f)
        else:
            raise ValueError(f"Unsupported configuration file format: {config_path.suffix}")

        self._parse_config_data(config_data)
        logger.info(f"Loaded configuration from {config_file}")

    def save_to_file(self, config_file: str, format: str = "json") -> None:
        """Save configuration to file."""

        config_data = {
            "domains": self._serialize_domain_configs(),
            "pipelines": self._serialize_pipeline_configs(),
        }

        config_path = Path(config_file)

        if format.lower() == "json":
            with open(config_path, 'w') as f:
                json.dump(config_data, f, indent=2)
        elif format.lower() in ["yml", "yaml"]:
            with open(config_path, 'w') as f:
                yaml.dump(config_data, f, default_flow_style=False)
        else:
            raise ValueError(f"Unsupported format: {format}")

        logger.info(f"Saved configuration to {config_file}")

    def register_custom_resolver(
        self,
        resolver_name: str,
        resolver_func: Callable
    ) -> None:
        """Register a custom DataPoint resolver."""
        self.custom_resolvers[resolver_name] = resolver_func
        logger.info(f"Registered custom resolver: {resolver_name}")

    def register_custom_binding_strategy(
        self,
        strategy_name: str,
        strategy_func: Callable
    ) -> None:
        """Register a custom graph binding strategy."""
        self.custom_binding_strategies[strategy_name] = strategy_func
        logger.info(f"Registered custom binding strategy: {strategy_name}")

    def create_datapoint_mapping(
        self,
        ontology_node_type: str,
        datapoint_class: str,
        field_mappings: Optional[Dict[str, str]] = None,
        custom_resolver: Optional[str] = None,
        validation_rules: Optional[List[str]] = None
    ) -> DataPointMapping:
        """Create a DataPointMapping configuration."""

        return DataPointMapping(
            ontology_node_type=ontology_node_type,
            datapoint_class=datapoint_class,
            field_mappings=field_mappings or {},
            custom_resolver=custom_resolver,
            validation_rules=validation_rules or []
        )

    def create_graph_binding_config(
        self,
        node_type_mapping: Optional[Dict[str, str]] = None,
        edge_type_mapping: Optional[Dict[str, str]] = None,
        property_transformations: Optional[Dict[str, Callable]] = None,
        custom_binding_strategy: Optional[str] = None
    ) -> GraphBindingConfig:
        """Create a GraphBindingConfig."""

        return GraphBindingConfig(
            node_type_mapping=node_type_mapping or {},
            edge_type_mapping=edge_type_mapping or {},
            property_transformations=property_transformations or {},
            custom_binding_strategy=custom_binding_strategy
        )

    def _parse_config_data(self, config_data: Dict[str, Any]) -> None:
        """Parse configuration data from file."""

        # Parse domain configurations
        for domain, domain_config in config_data.get("domains", {}).items():
            mappings = []
            for mapping_data in domain_config.get("datapoint_mappings", []):
                mapping = DataPointMapping(**mapping_data)
                mappings.append(mapping)

            binding_data = domain_config.get("graph_binding", {})
            binding = GraphBindingConfig(**binding_data)

            task_configs = domain_config.get("task_configs", {})

            self.register_domain_config(domain, mappings, binding, task_configs)

        # Parse pipeline configurations
        for pipeline_name, pipeline_config in config_data.get("pipelines", {}).items():
            domain = pipeline_config.get("domain")

            custom_mappings = None
            if "datapoint_mappings" in pipeline_config:
                custom_mappings = []
                for mapping_data in pipeline_config["datapoint_mappings"]:
                    mapping = DataPointMapping(**mapping_data)
                    custom_mappings.append(mapping)

            custom_binding = None
            if "graph_binding" in pipeline_config:
                binding_data = pipeline_config["graph_binding"]
                custom_binding = GraphBindingConfig(**binding_data)

            task_configs = pipeline_config.get("task_configs", {})

            self.register_pipeline_config(
                pipeline_name, domain, custom_mappings, custom_binding, task_configs
            )

    def _serialize_domain_configs(self) -> Dict[str, Any]:
        """Serialize domain configurations for saving."""
        serialized = {}

        for domain, config in self.domain_configs.items():
            serialized[domain] = {
                "domain": config["domain"],
                "datapoint_mappings": [
                    mapping.dict() for mapping in config["datapoint_mappings"]
                ],
                "graph_binding": config["graph_binding"].dict(),
                "task_configs": config["task_configs"],
            }

        return serialized

    def _serialize_pipeline_configs(self) -> Dict[str, Any]:
        """Serialize pipeline configurations for saving."""
        serialized = {}

        for pipeline_name, config in self.pipeline_configs.items():
            serialized[pipeline_name] = {
                "pipeline_name": config["pipeline_name"],
                "domain": config["domain"],
                "datapoint_mappings": [
                    mapping.dict() for mapping in config["datapoint_mappings"]
                ],
                "graph_binding": config["graph_binding"].dict(),
                "task_configs": config["task_configs"],
            }

        return serialized


# Global configuration instance
_global_ontology_config = None


def get_ontology_config() -> OntologyConfiguration:
    """Get the global ontology configuration instance."""
    global _global_ontology_config
    if _global_ontology_config is None:
        _global_ontology_config = OntologyConfiguration()
    return _global_ontology_config


def configure_domain(
    domain: str,
    datapoint_mappings: List[DataPointMapping],
    graph_binding: GraphBindingConfig,
    task_configs: Optional[Dict[str, Dict[str, Any]]] = None
) -> None:
    """Configure ontology for a domain (convenience function)."""
    config = get_ontology_config()
    config.register_domain_config(domain, datapoint_mappings, graph_binding, task_configs)


def configure_pipeline(
    pipeline_name: str,
    domain: str,
    custom_mappings: Optional[List[DataPointMapping]] = None,
    custom_binding: Optional[GraphBindingConfig] = None,
    task_configs: Optional[Dict[str, Dict[str, Any]]] = None
) -> None:
    """Configure ontology for a pipeline (convenience function)."""
    config = get_ontology_config()
    config.register_pipeline_config(
        pipeline_name, domain, custom_mappings, custom_binding, task_configs
    )


def load_ontology_config(config_file: str) -> None:
    """Load ontology configuration from file (convenience function)."""
    config = get_ontology_config()
    config.load_from_file(config_file)


# Example configuration templates
def create_example_config_file(output_file: str) -> None:
    """Create an example configuration file."""

    example_config = {
        "domains": {
            "example_domain": {
                "domain": "example_domain",
                "datapoint_mappings": [
                    {
                        "ontology_node_type": "Entity",
                        "datapoint_class": "cognee.infrastructure.engine.models.DataPoint.DataPoint",
                        "field_mappings": {
                            "name": "name",
                            "description": "description",
                            "category": "entity_type"
                        },
                        "custom_resolver": None,
                        "validation_rules": ["required:name"]
                    }
                ],
                "graph_binding": {
                    "node_type_mapping": {
                        "Entity": "domain_entity",
                        "Concept": "domain_concept"
                    },
                    "edge_type_mapping": {
                        "related_to": "domain_relation",
                        "part_of": "composition"
                    },
                    "property_transformations": {},
                    "custom_binding_strategy": None
                },
                "task_configs": {
                    "extract_graph_from_data": {
                        "enhance_with_entities": True,
                        "inject_datapoint_mappings": True,
                        "inject_graph_binding": True,
                        "target_entity_types": ["Entity", "Concept"]
                    }
                }
            }
        },
        "pipelines": {
            "example_pipeline": {
                "pipeline_name": "example_pipeline",
                "domain": "example_domain",
                "datapoint_mappings": [],  # Use domain defaults
                "graph_binding": {},  # Use domain defaults
                "task_configs": {
                    "summarize_text": {
                        "enhance_with_entities": True,
                        "enable_ontology_validation": True
                    }
                }
            }
        }
    }

    with open(output_file, 'w') as f:
        json.dump(example_config, f, indent=2)

    logger.info(f"Created example configuration file: {output_file}")
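A round-trip sketch of the file-based configuration API above. The file name is arbitrary, `create_example_config_file` writes the JSON template shown in the code, and the sketch assumes the default domain configs in `pipeline_integration` import cleanly:

```python
from cognee.modules.ontology.configuration import (
    create_example_config_file,
    get_ontology_config,
    load_ontology_config,
)

# Write the JSON template to disk, then load it back through the parser
create_example_config_file("ontology_config.json")  # hypothetical path
load_ontology_config("ontology_config.json")

config = get_ontology_config()
pipeline = config.get_pipeline_config("example_pipeline")
print(pipeline["domain"])  # "example_domain"
```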
552  cognee/modules/ontology/example_usage.py  Normal file
@@ -0,0 +1,552 @@
"""
Example usage of the refactored ontology system.

This demonstrates how to:
1. Set up ontology providers and registry
2. Configure domain-specific mappings
3. Integrate with pipelines
4. Use custom resolvers and binding strategies
"""

import asyncio
from pathlib import Path
from typing import List, Any, Dict

from cognee.modules.ontology.interfaces import (
    OntologyContext,
    OntologyScope,
    DataPointMapping,
    GraphBindingConfig,
)
from cognee.modules.ontology.manager import create_ontology_manager
from cognee.modules.ontology.registry import OntologyRegistry
from cognee.modules.ontology.providers import JSONOntologyProvider, RDFOntologyProvider
from cognee.modules.ontology.adapters import DefaultOntologyAdapter
from cognee.modules.ontology.resolvers import DefaultDataPointResolver, DomainSpecificResolver
from cognee.modules.ontology.binders import DefaultGraphBinder, DomainSpecificBinder
from cognee.modules.ontology.pipeline_integration import (
    PipelineOntologyConfigurator,
    OntologyInjector,
)
from cognee.modules.ontology.configuration import (
    get_ontology_config,
    configure_domain,
    configure_pipeline,
)
from cognee.modules.pipelines.tasks.task import Task
from cognee.shared.logging_utils import get_logger

logger = get_logger("OntologyExample")


async def example_basic_setup():
    """Example: Basic ontology system setup."""

    logger.info("=== Basic Ontology System Setup ===")

    # 1. Create registry and providers
    registry = OntologyRegistry()

    providers = {
        "json_provider": JSONOntologyProvider(),
        "rdf_provider": RDFOntologyProvider(),
    }

    adapters = {
        "default_adapter": DefaultOntologyAdapter(),
    }

    # 2. Create resolvers and binders
    datapoint_resolver = DefaultDataPointResolver()
    graph_binder = DefaultGraphBinder()

    # 3. Create ontology manager
    ontology_manager = await create_ontology_manager(
        registry=registry,
        providers=providers,
        adapters=adapters,
        datapoint_resolver=datapoint_resolver,
        graph_binder=graph_binder,
    )

    logger.info("Ontology system initialized successfully")
    return ontology_manager


async def example_load_ontologies(ontology_manager):
    """Example: Loading different types of ontologies."""

    logger.info("=== Loading Ontologies ===")

    # 1. Load JSON ontology
    json_ontology_data = {
        "id": "medical_ontology",
        "name": "Medical Knowledge Base",
        "description": "Basic medical ontology",
        "nodes": [
            {
                "id": "disease_001",
                "name": "Diabetes",
                "type": "Disease",
                "description": "A group of metabolic disorders",
                "category": "medical_condition",
                "properties": {"icd_code": "E11", "severity": "chronic"}
            },
            {
                "id": "symptom_001",
                "name": "Fatigue",
                "type": "Symptom",
                "description": "Extreme tiredness",
                "category": "clinical_finding",
                "properties": {"frequency": "common"}
            }
        ],
        "edges": [
            {
                "id": "rel_001",
                "source": "disease_001",
                "target": "symptom_001",
                "relationship": "causes",
                "properties": {"strength": "moderate"}
            }
        ]
    }

    json_provider = ontology_manager.providers["json_provider"]
    medical_ontology = await json_provider.load_ontology(json_ontology_data)

    # 2. Register ontology in registry
    ontology_id = await ontology_manager.registry.register_ontology(
        medical_ontology,
        OntologyScope.DOMAIN,
        OntologyContext(domain="medical")
    )

    logger.info(f"Loaded and registered medical ontology: {ontology_id}")

    return medical_ontology


async def example_configure_domain_mappings(ontology_manager):
    """Example: Configure domain-specific DataPoint mappings."""

    logger.info("=== Configuring Domain Mappings ===")

    # 1. Define DataPoint mappings for medical domain
    medical_mappings = [
        DataPointMapping(
            ontology_node_type="Disease",
            datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
            field_mappings={
                "name": "name",
                "description": "description",
                "icd_code": "medical_code",
                "severity": "severity_level",
            },
            validation_rules=["required:name", "required:medical_code"]
        ),
        DataPointMapping(
            ontology_node_type="Symptom",
            datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
            field_mappings={
                "name": "name",
                "description": "description",
                "frequency": "occurrence_rate",
            }
        )
    ]

    # 2. Define graph binding configuration
    medical_binding = GraphBindingConfig(
        node_type_mapping={
            "Disease": "medical_condition",
            "Symptom": "clinical_finding",
            "Treatment": "therapeutic_procedure",
        },
        edge_type_mapping={
            "causes": "causality",
            "treats": "therapeutic_relationship",
            "associated_with": "clinical_association",
        }
    )

    # 3. Configure the domain
    ontology_manager.configure_datapoint_mapping("medical", medical_mappings)
    ontology_manager.configure_graph_binding("medical", medical_binding)

    logger.info("Configured medical domain mappings and bindings")


async def example_custom_resolver(ontology_manager):
    """Example: Register and use custom DataPoint resolver."""

    logger.info("=== Custom DataPoint Resolver ===")

    # 1. Define custom resolver function
    async def medical_disease_resolver(ontology_node, mapping_config, context=None):
        """Custom resolver for medical disease entities."""
        from cognee.infrastructure.engine.models.DataPoint import DataPoint

        # Create DataPoint with medical-specific logic
        datapoint = DataPoint(
            id=ontology_node.id,
            type="medical_disease",
            ontology_valid=True,
            metadata={
                "type": "medical_entity",
                "index_fields": ["name", "medical_code"],
                "domain": "medical",
                "ontology_node_id": ontology_node.id,
            }
        )

        # Map ontology properties with domain-specific processing
        datapoint.name = ontology_node.name
        datapoint.description = ontology_node.description
        datapoint.medical_code = ontology_node.properties.get("icd_code", "")
        datapoint.severity_level = ontology_node.properties.get("severity", "unknown")

        # Add computed properties
        datapoint.medical_category = "disease"
        datapoint.risk_level = "high" if datapoint.severity_level == "chronic" else "low"

        logger.info(f"Custom resolver created DataPoint for disease: {datapoint.name}")
        return datapoint

    # 2. Register custom resolver
    ontology_manager.register_custom_resolver("medical_disease_resolver", medical_disease_resolver)

    # 3. Update mapping to use custom resolver
    updated_mapping = DataPointMapping(
        ontology_node_type="Disease",
        datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
        field_mappings={},  # Handled by custom resolver
        custom_resolver="medical_disease_resolver"
    )

    ontology_manager.configure_datapoint_mapping("medical", [updated_mapping])

    logger.info("Registered custom medical disease resolver")


async def example_custom_binding_strategy(ontology_manager):
    """Example: Register and use custom graph binding strategy."""

    logger.info("=== Custom Graph Binding Strategy ===")

    # 1. Define custom binding strategy
    async def medical_graph_binding_strategy(ontology, binding_config, context=None):
        """Custom binding strategy for medical graphs."""
        from datetime import datetime

        graph_nodes = []
        graph_edges = []

        # Process nodes with medical-specific transformations
        for node in ontology.nodes:
            if node.type == "Disease":
                # Create medical condition node
                node_props = {
                    "id": node.id,
                    "name": node.name,
                    "type": "medical_condition",
                    "description": node.description,
                    "medical_code": node.properties.get("icd_code", ""),
                    "severity": node.properties.get("severity", "unknown"),
                    "category": "pathology",
                    "updated_at": datetime.now().isoformat(),
                    "ontology_source": True,
                }
                graph_nodes.append((node.id, node_props))

            elif node.type == "Symptom":
                # Create clinical finding node
                node_props = {
                    "id": node.id,
                    "name": node.name,
                    "type": "clinical_finding",
                    "description": node.description,
                    "frequency": node.properties.get("frequency", "unknown"),
                    "category": "symptom",
                    "updated_at": datetime.now().isoformat(),
                    "ontology_source": True,
                }
                graph_nodes.append((node.id, node_props))

        # Process edges with medical-specific relationships
        for edge in ontology.edges:
            edge_props = {
                "source_node_id": edge.source_id,
                "target_node_id": edge.target_id,
                "relationship_name": "medical_" + edge.relationship_type,
                "strength": edge.properties.get("strength", "unknown"),
                "updated_at": datetime.now().isoformat(),
                "ontology_source": True,
            }

            graph_edges.append((
                edge.source_id,
                edge.target_id,
                "medical_" + edge.relationship_type,
                edge_props
            ))

        logger.info(f"Custom binding strategy created {len(graph_nodes)} nodes and {len(graph_edges)} edges")
        return graph_nodes, graph_edges

    # 2. Register custom binding strategy
    ontology_manager.register_binding_strategy("medical_graph_binding", medical_graph_binding_strategy)

    # 3. Update binding config to use custom strategy
    updated_binding = GraphBindingConfig(
        custom_binding_strategy="medical_graph_binding"
    )

    ontology_manager.configure_graph_binding("medical", updated_binding)

    logger.info("Registered custom medical graph binding strategy")


async def example_pipeline_integration(ontology_manager):
    """Example: Integrate ontology with pipeline tasks."""

    logger.info("=== Pipeline Integration ===")

    # 1. Create pipeline configurator
    pipeline_configurator = PipelineOntologyConfigurator(ontology_manager)

    # 2. Configure medical pipeline
    medical_mappings = ontology_manager.domain_datapoint_mappings.get("medical", [])
    medical_binding = ontology_manager.domain_graph_bindings.get("medical")

    task_configs = {
        "extract_graph_from_data": {
            "enhance_with_entities": True,
            "inject_datapoint_mappings": True,
            "inject_graph_binding": True,
            "target_entity_types": ["Disease", "Symptom", "Treatment"],
        },
        "summarize_text": {
            "enhance_with_entities": True,
            "enable_ontology_validation": True,
            "validation_threshold": 0.85,
        }
    }

    pipeline_configurator.configure_pipeline(
        pipeline_name="medical_cognify_pipeline",
        domain="medical",
        datapoint_mappings=medical_mappings,
        graph_binding=medical_binding,
        task_specific_configs=task_configs
    )

    # 3. Create ontology injector for the pipeline
    injector = pipeline_configurator.create_ontology_injector("medical_cognify_pipeline")

    # 4. Create sample task and inject ontology
    async def sample_extract_task(data_chunks, **kwargs):
        """Sample extraction task."""
        logger.info("Executing extract task with ontology context")
        ontology_context = kwargs.get("ontology_context")
        datapoint_mappings = kwargs.get("datapoint_mappings", [])

        logger.info(f"Task received ontology context for domain: {ontology_context.domain}")
        logger.info(f"Task has {len(datapoint_mappings)} DataPoint mappings available")

        # Simulate task processing with ontology enhancement
        enhanced_results = []
        for chunk in data_chunks:
            # In a real implementation, this would use the ontology for entity extraction
            enhanced_results.append({
                "chunk": chunk,
                "ontology_enhanced": True,
                "domain": ontology_context.domain,
                "entities_found": ["Diabetes", "Fatigue"]  # Simulated
            })

        return enhanced_results

    sample_task = Task(sample_extract_task)

    # 5. Get pipeline context and inject ontology
    context = pipeline_configurator.get_pipeline_context(
        "medical_cognify_pipeline",
        user_id="user123",
        dataset_id="medical_dataset_001"
    )

    enhanced_task = await injector.inject_into_task(sample_task, context)

    # 6. Execute enhanced task
    sample_data = ["Patient reports fatigue and frequent urination", "Diagnosis: Type 2 Diabetes"]
    results = await enhanced_task.run(sample_data)

    logger.info(f"Enhanced task completed with {len(results)} results")

    return results


async def example_content_enhancement(ontology_manager):
    """Example: Enhance content with ontological information."""

    logger.info("=== Content Enhancement ===")

    # 1. Create context for medical domain
    context = OntologyContext(
        domain="medical",
        pipeline_name="medical_analysis",
        user_id="doctor123"
    )

    # 2. Sample medical text
    medical_text = """
    The patient presents with chronic fatigue and excessive thirst.
    Blood glucose levels are elevated, indicating possible diabetes mellitus.
    Further testing for HbA1c is recommended to confirm diagnosis.
    """

    # 3. Enhance content with ontology
    enhanced_content = await ontology_manager.enhance_with_ontology(medical_text, context)

    logger.info("Enhanced content:")
    logger.info(f"  Original content: {enhanced_content['original_content'][:50]}...")
    logger.info(f"  Extracted entities: {len(enhanced_content['extracted_entities'])}")
    logger.info(f"  Semantic relationships: {len(enhanced_content['semantic_relationships'])}")

    for entity in enhanced_content['extracted_entities']:
        logger.info(f"  Found entity: {entity['name']} (type: {entity['type']})")

    for relationship in enhanced_content['semantic_relationships']:
        logger.info(f"  Relationship: {relationship['source']} -> {relationship['relationship']} -> {relationship['target']}")

    return enhanced_content


async def example_datapoint_resolution(ontology_manager):
    """Example: Resolve ontology nodes to DataPoint instances."""

    logger.info("=== DataPoint Resolution ===")

    # 1. Get medical ontology
    context = OntologyContext(domain="medical")
    ontologies = await ontology_manager.get_applicable_ontologies(context)

    if not ontologies:
        logger.warning("No medical ontologies found")
        return []

    medical_ontology = ontologies[0]

    # 2. Filter disease nodes
    disease_nodes = [node for node in medical_ontology.nodes if node.type == "Disease"]

    if not disease_nodes:
        logger.warning("No disease nodes found in ontology")
        return []

    # 3. Resolve to DataPoint instances
    datapoints = await ontology_manager.resolve_to_datapoints(disease_nodes, context)

    logger.info(f"Resolved {len(datapoints)} disease nodes to DataPoints:")
    for dp in datapoints:
        logger.info(f"  DataPoint: {dp.type} - {getattr(dp, 'name', 'Unnamed')}")
        logger.info(f"    ID: {dp.id}")
        logger.info(f"    Ontology valid: {dp.ontology_valid}")
        if hasattr(dp, 'medical_code'):
            logger.info(f"    Medical code: {dp.medical_code}")

    return datapoints


async def example_graph_binding(ontology_manager):
    """Example: Bind ontology to graph structure."""

    logger.info("=== Graph Binding ===")

    # 1. Get medical ontology
    context = OntologyContext(domain="medical")
    ontologies = await ontology_manager.get_applicable_ontologies(context)

    if not ontologies:
        logger.warning("No medical ontologies found")
        return [], []

    medical_ontology = ontologies[0]

    # 2. Bind to graph structure
    graph_nodes, graph_edges = await ontology_manager.bind_to_graph(medical_ontology, context)

    logger.info("Bound ontology to graph structure:")
    logger.info(f"  Graph nodes: {len(graph_nodes)}")
    logger.info(f"  Graph edges: {len(graph_edges)}")

    # Display sample nodes
    for i, node in enumerate(graph_nodes[:3]):  # Show first 3
        if isinstance(node, tuple):
            node_id, node_props = node
            logger.info(f"  Node {i+1}: {node_id} (type: {node_props.get('type', 'unknown')})")
        else:
            logger.info(f"  Node {i+1}: {getattr(node, 'id', 'unknown')} (type: {getattr(node, 'type', 'unknown')})")

    # Display sample edges
    for i, edge in enumerate(graph_edges[:3]):  # Show first 3
        if isinstance(edge, tuple) and len(edge) >= 4:
            source, target, rel_type, props = edge[:4]
            logger.info(f"  Edge {i+1}: {source} -> {rel_type} -> {target}")
        else:
            logger.info(f"  Edge {i+1}: {edge}")

    return graph_nodes, graph_edges


async def main():
    """Run all examples."""

    logger.info("Starting ontology system examples...")

    try:
        # 1. Basic setup
        ontology_manager = await example_basic_setup()

        # 2. Load ontologies
        medical_ontology = await example_load_ontologies(ontology_manager)

        # 3. Configure domain mappings
        await example_configure_domain_mappings(ontology_manager)

        # 4. Custom resolver
        await example_custom_resolver(ontology_manager)

        # 5. Custom binding strategy
        await example_custom_binding_strategy(ontology_manager)

        # 6. Pipeline integration
        pipeline_results = await example_pipeline_integration(ontology_manager)

        # 7. Content enhancement
        enhanced_content = await example_content_enhancement(ontology_manager)

        # 8. DataPoint resolution
        datapoints = await example_datapoint_resolution(ontology_manager)

        # 9. Graph binding
        graph_nodes, graph_edges = await example_graph_binding(ontology_manager)

        logger.info("All examples completed successfully!")

        # Print summary
        logger.info("\n=== Summary ===")
        logger.info("Loaded ontologies: 1")
        logger.info(f"Pipeline results: {len(pipeline_results) if pipeline_results else 0}")
        logger.info(f"Enhanced entities: {len(enhanced_content.get('extracted_entities', [])) if enhanced_content else 0}")
        logger.info(f"DataPoints created: {len(datapoints) if datapoints else 0}")
        logger.info(f"Graph nodes: {len(graph_nodes) if graph_nodes else 0}")
        logger.info(f"Graph edges: {len(graph_edges) if graph_edges else 0}")

    except Exception as e:
        logger.error(f"Example execution failed: {e}", exc_info=True)


if __name__ == "__main__":
    asyncio.run(main())
351  cognee/modules/ontology/interfaces.py  Normal file
@@ -0,0 +1,351 @@
"""Abstract interfaces for ontology system components."""

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union, Type, Callable
from pydantic import BaseModel
from enum import Enum


class OntologyFormat(str, Enum):
    """Supported ontology formats."""
    RDF_XML = "rdf_xml"
    OWL = "owl"
    JSON = "json"
    CSV = "csv"
    YAML = "yaml"
    DATABASE = "database"
    LLM_GENERATED = "llm_generated"


class OntologyScope(str, Enum):
    """Ontology scopes for different use cases."""
    GLOBAL = "global"      # Applies to all pipelines
    DOMAIN = "domain"      # Applies to specific domain (medical, legal, etc.)
    PIPELINE = "pipeline"  # Applies to specific pipeline type
    USER = "user"          # User-specific ontologies
    DATASET = "dataset"    # Dataset-specific ontologies


class OntologyNode(BaseModel):
    """Standard ontology node representation."""
    id: str
    name: str
    type: str
    description: Optional[str] = None
    category: Optional[str] = None
    properties: Dict[str, Any] = {}
    labels: List[str] = []


class OntologyEdge(BaseModel):
    """Standard ontology edge representation."""
    id: str
    source_id: str
    target_id: str
    relationship_type: str
    properties: Dict[str, Any] = {}
    weight: Optional[float] = None


class OntologyGraph(BaseModel):
    """Standard ontology graph representation."""
    id: str
    name: str
    description: Optional[str] = None
    format: OntologyFormat
    scope: OntologyScope
    nodes: List[OntologyNode]
    edges: List[OntologyEdge]
    metadata: Dict[str, Any] = {}


class OntologyContext(BaseModel):
    """Context for ontology operations."""
    user_id: Optional[str] = None
    dataset_id: Optional[str] = None
    pipeline_name: Optional[str] = None
    domain: Optional[str] = None
    custom_properties: Dict[str, Any] = {}


class DataPointMapping(BaseModel):
    """Mapping configuration between ontology and DataPoint."""
    ontology_node_type: str
    datapoint_class: str
    field_mappings: Dict[str, str] = {}  # ontology_field -> datapoint_field
    custom_resolver: Optional[str] = None  # Function name for custom resolution
    validation_rules: List[str] = []


class GraphBindingConfig(BaseModel):
    """Configuration for how ontology binds to graph structures."""
    node_type_mapping: Dict[str, str] = {}  # ontology_type -> graph_node_type
    edge_type_mapping: Dict[str, str] = {}  # ontology_relation -> graph_edge_type
    property_transformations: Dict[str, Callable[[Any], Any]] = {}
    custom_binding_strategy: Optional[str] = None


class IOntologyProvider(ABC):
    """Abstract interface for ontology providers."""

    @abstractmethod
    async def load_ontology(
        self,
        source: Union[str, Dict[str, Any]],
        context: Optional[OntologyContext] = None
    ) -> OntologyGraph:
        """Load ontology from source."""
        pass

    @abstractmethod
    async def save_ontology(
        self,
        ontology: OntologyGraph,
        destination: str,
        context: Optional[OntologyContext] = None
    ) -> bool:
        """Save ontology to destination."""
        pass

    @abstractmethod
    def supports_format(self, format: OntologyFormat) -> bool:
        """Check if provider supports given format."""
        pass

    @abstractmethod
    async def validate_ontology(self, ontology: OntologyGraph) -> bool:
        """Validate ontology structure."""
        pass


class IOntologyAdapter(ABC):
    """Abstract interface for ontology adapters."""

    @abstractmethod
    async def find_matching_nodes(
        self,
        query_text: str,
        ontology: OntologyGraph,
        similarity_threshold: float = 0.8
    ) -> List[OntologyNode]:
        """Find nodes matching query text."""
        pass

    @abstractmethod
    async def get_node_relationships(
        self,
        node_id: str,
        ontology: OntologyGraph,
        max_depth: int = 2
    ) -> List[OntologyEdge]:
        """Get relationships for a specific node."""
        pass

    @abstractmethod
    async def expand_subgraph(
        self,
        node_ids: List[str],
        ontology: OntologyGraph,
        directed: bool = True
    ) -> Tuple[List[OntologyNode], List[OntologyEdge]]:
        """Expand subgraph around given nodes."""
        pass

    @abstractmethod
    async def merge_ontologies(
        self,
        ontologies: List[OntologyGraph]
    ) -> OntologyGraph:
        """Merge multiple ontologies."""
        pass


class IOntologyRegistry(ABC):
    """Abstract interface for ontology registry."""

    @abstractmethod
    async def register_ontology(
        self,
        ontology: OntologyGraph,
        scope: OntologyScope,
        context: Optional[OntologyContext] = None
    ) -> str:
        """Register an ontology."""
        pass

    @abstractmethod
    async def get_ontology(
        self,
        ontology_id: str,
        context: Optional[OntologyContext] = None
    ) -> Optional[OntologyGraph]:
        """Get ontology by ID."""
        pass

    @abstractmethod
    async def find_ontologies(
        self,
        scope: Optional[OntologyScope] = None,
        domain: Optional[str] = None,
        context: Optional[OntologyContext] = None
    ) -> List[OntologyGraph]:
        """Find ontologies matching criteria."""
        pass

    @abstractmethod
    async def unregister_ontology(
        self,
        ontology_id: str,
        context: Optional[OntologyContext] = None
    ) -> bool:
        """Unregister an ontology."""
        pass


class IDataPointResolver(ABC):
    """Abstract interface for resolving ontology to DataPoint instances."""

    @abstractmethod
    async def resolve_to_datapoint(
        self,
        ontology_node: OntologyNode,
        mapping_config: DataPointMapping,
        context: Optional[OntologyContext] = None
    ) -> Any:  # Should be DataPoint but avoiding circular import
        """Resolve ontology node to DataPoint instance."""
        pass

    @abstractmethod
    async def resolve_from_datapoint(
        self,
        datapoint: Any,  # DataPoint
        mapping_config: DataPointMapping,
        context: Optional[OntologyContext] = None
    ) -> OntologyNode:
        """Resolve DataPoint instance to ontology node."""
        pass

    @abstractmethod
    async def validate_mapping(
        self,
        mapping_config: DataPointMapping
    ) -> bool:
        """Validate mapping configuration."""
        pass

    @abstractmethod
    def register_custom_resolver(
        self,
        resolver_name: str,
        resolver_func: Callable[[OntologyNode, DataPointMapping], Any]
    ) -> None:
        """Register a custom resolver function."""
        pass


class IGraphBinder(ABC):
    """Abstract interface for binding ontology to graph structures."""

    @abstractmethod
    async def bind_ontology_to_graph(
        self,
        ontology: OntologyGraph,
        binding_config: GraphBindingConfig,
        context: Optional[OntologyContext] = None
    ) -> Tuple[List[Any], List[Any]]:  # (graph_nodes, graph_edges)
        """Bind ontology to graph structure."""
        pass

    @abstractmethod
    async def transform_node_properties(
        self,
        node: OntologyNode,
        transformations: Dict[str, Callable[[Any], Any]]
    ) -> Dict[str, Any]:
        """Transform node properties according to binding config."""
        pass

    @abstractmethod
    async def transform_edge_properties(
        self,
        edge: OntologyEdge,
        transformations: Dict[str, Callable[[Any], Any]]
    ) -> Dict[str, Any]:
        """Transform edge properties according to binding config."""
        pass

    @abstractmethod
    def register_binding_strategy(
        self,
        strategy_name: str,
        strategy_func: Callable[[OntologyGraph, GraphBindingConfig], Tuple[List[Any], List[Any]]]
    ) -> None:
        """Register a custom binding strategy."""
        pass


class IOntologyManager(ABC):
    """Abstract interface for ontology manager."""

    @abstractmethod
    async def get_applicable_ontologies(
        self,
        context: OntologyContext
    ) -> List[OntologyGraph]:
        """Get ontologies applicable to given context."""
        pass

    @abstractmethod
    async def enhance_with_ontology(
        self,
        content: str,
        context: OntologyContext
    ) -> Dict[str, Any]:
        """Enhance content with ontological information."""
        pass

    @abstractmethod
    async def inject_ontology_into_task(
        self,
        task_name: str,
        task_params: Dict[str, Any],
        context: OntologyContext
    ) -> Dict[str, Any]:
        """Inject ontological context into task parameters."""
        pass

    @abstractmethod
    async def resolve_to_datapoints(
        self,
        ontology_nodes: List[OntologyNode],
        context: OntologyContext
    ) -> List[Any]:  # List[DataPoint]
        """Resolve ontology nodes to DataPoint instances."""
        pass

    @abstractmethod
    async def bind_to_graph(
        self,
        ontology: OntologyGraph,
        context: OntologyContext
    ) -> Tuple[List[Any], List[Any]]:  # (graph_nodes, graph_edges)
        """Bind ontology to graph structure using configured binding."""
        pass

    @abstractmethod
    def configure_datapoint_mapping(
        self,
        domain: str,
        mappings: List[DataPointMapping]
    ) -> None:
        """Configure DataPoint mappings for a domain."""
        pass

    @abstractmethod
    def configure_graph_binding(
        self,
        domain: str,
        binding_config: GraphBindingConfig
    ) -> None:
        """Configure graph binding for a domain."""
        pass
363 cognee/modules/ontology/manager.py Normal file
@@ -0,0 +1,363 @@
"""Core ontology manager implementation."""

from typing import Any, Dict, List, Optional, Tuple, Callable

from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import (
    IOntologyManager,
    IOntologyRegistry,
    IOntologyProvider,
    IOntologyAdapter,
    IDataPointResolver,
    IGraphBinder,
    OntologyGraph,
    OntologyNode,
    OntologyContext,
    DataPointMapping,
    GraphBindingConfig,
    OntologyScope,
)

logger = get_logger("OntologyManager")


class OntologyManager(IOntologyManager):
    """Core implementation of ontology management."""

    def __init__(
        self,
        registry: IOntologyRegistry,
        providers: Dict[str, IOntologyProvider],
        adapters: Dict[str, IOntologyAdapter],
        datapoint_resolver: IDataPointResolver,
        graph_binder: IGraphBinder,
    ):
        self.registry = registry
        self.providers = providers
        self.adapters = adapters
        self.datapoint_resolver = datapoint_resolver
        self.graph_binder = graph_binder

        # Domain-specific configurations
        self.domain_datapoint_mappings: Dict[str, List[DataPointMapping]] = {}
        self.domain_graph_bindings: Dict[str, GraphBindingConfig] = {}

        # Custom resolvers and binding strategies
        self.custom_resolvers: Dict[str, Callable] = {}
        self.custom_binding_strategies: Dict[str, Callable] = {}

    async def get_applicable_ontologies(
        self,
        context: OntologyContext
    ) -> List[OntologyGraph]:
        """Get ontologies applicable to given context."""
        ontologies = []

        # Get global ontologies
        global_ontologies = await self.registry.find_ontologies(
            scope=OntologyScope.GLOBAL,
            context=context
        )
        ontologies.extend(global_ontologies)

        # Get domain-specific ontologies
        if context.domain:
            domain_ontologies = await self.registry.find_ontologies(
                scope=OntologyScope.DOMAIN,
                domain=context.domain,
                context=context
            )
            ontologies.extend(domain_ontologies)

        # Get pipeline-specific ontologies
        if context.pipeline_name:
            pipeline_ontologies = await self.registry.find_ontologies(
                scope=OntologyScope.PIPELINE,
                context=context
            )
            ontologies.extend(pipeline_ontologies)

        # Get user-specific ontologies
        if context.user_id:
            user_ontologies = await self.registry.find_ontologies(
                scope=OntologyScope.USER,
                context=context
            )
            ontologies.extend(user_ontologies)

        # Get dataset-specific ontologies
        if context.dataset_id:
            dataset_ontologies = await self.registry.find_ontologies(
                scope=OntologyScope.DATASET,
                context=context
            )
            ontologies.extend(dataset_ontologies)

        # Remove duplicates and prioritize by scope
        unique_ontologies = self._prioritize_ontologies(ontologies)

        logger.info(f"Found {len(unique_ontologies)} applicable ontologies for context")
        return unique_ontologies

    async def enhance_with_ontology(
        self,
        content: str,
        context: OntologyContext
    ) -> Dict[str, Any]:
        """Enhance content with ontological information."""
        applicable_ontologies = await self.get_applicable_ontologies(context)

        enhanced_data = {
            "original_content": content,
            "ontological_annotations": [],
            "extracted_entities": [],
            "semantic_relationships": [],
        }

        for ontology in applicable_ontologies:
            adapter_name = self._get_adapter_for_ontology(ontology)
            if adapter_name not in self.adapters:
                logger.warning(f"No adapter found for ontology {ontology.id}")
                continue

            adapter = self.adapters[adapter_name]

            # Find matching nodes in content
            matching_nodes = await adapter.find_matching_nodes(
                content, ontology, similarity_threshold=0.7
            )

            for node in matching_nodes:
                enhanced_data["extracted_entities"].append({
                    "node_id": node.id,
                    "name": node.name,
                    "type": node.type,
                    "category": node.category,
                    "ontology_id": ontology.id,
                    "confidence": 0.8,  # This should come from the adapter
                })

                # Get relationships for this node
                relationships = await adapter.get_node_relationships(
                    node.id, ontology, max_depth=1
                )

                for rel in relationships:
                    enhanced_data["semantic_relationships"].append({
                        "source": rel.source_id,
                        "target": rel.target_id,
                        "relationship": rel.relationship_type,
                        "ontology_id": ontology.id,
                    })

        return enhanced_data

    async def inject_ontology_into_task(
        self,
        task_name: str,
        task_params: Dict[str, Any],
        context: OntologyContext
    ) -> Dict[str, Any]:
        """Inject ontological context into task parameters."""
        applicable_ontologies = await self.get_applicable_ontologies(context)

        # Merge all applicable ontologies
        if len(applicable_ontologies) > 1:
            primary_adapter = list(self.adapters.values())[0]  # Use first available adapter
            merged_ontology = await primary_adapter.merge_ontologies(applicable_ontologies)
            ontologies_to_inject = [merged_ontology]
        else:
            ontologies_to_inject = applicable_ontologies

        # Inject ontology-specific parameters
        enhanced_params = task_params.copy()
        enhanced_params["ontology_context"] = {
            "ontologies": [ont.id for ont in ontologies_to_inject],
            "domain": context.domain,
            "pipeline_name": context.pipeline_name,
        }

        # Add ontology-aware configurations
        if context.domain and context.domain in self.domain_datapoint_mappings:
            enhanced_params["datapoint_mappings"] = self.domain_datapoint_mappings[context.domain]

        if context.domain and context.domain in self.domain_graph_bindings:
            enhanced_params["graph_binding_config"] = self.domain_graph_bindings[context.domain]

        return enhanced_params

    async def resolve_to_datapoints(
        self,
        ontology_nodes: List[OntologyNode],
        context: OntologyContext
    ) -> List[Any]:  # List[DataPoint]
        """Resolve ontology nodes to DataPoint instances."""
        datapoints = []

        # Get domain-specific mappings
        mappings = self.domain_datapoint_mappings.get(context.domain, [])
        if not mappings:
            logger.warning(f"No DataPoint mappings configured for domain: {context.domain}")
            return datapoints

        for node in ontology_nodes:
            # Find appropriate mapping for this node type
            mapping = self._find_mapping_for_node(node, mappings)
            if not mapping:
                logger.debug(f"No mapping found for node type: {node.type}")
                continue

            try:
                datapoint = await self.datapoint_resolver.resolve_to_datapoint(
                    node, mapping, context
                )
                if datapoint:
                    datapoints.append(datapoint)
            except Exception as e:
                logger.error(f"Failed to resolve node {node.id} to DataPoint: {e}")

        logger.info(f"Resolved {len(datapoints)} DataPoints from {len(ontology_nodes)} nodes")
        return datapoints

    async def bind_to_graph(
        self,
        ontology: OntologyGraph,
        context: OntologyContext
    ) -> Tuple[List[Any], List[Any]]:  # (graph_nodes, graph_edges)
        """Bind ontology to graph structure using configured binding."""
        binding_config = self.domain_graph_bindings.get(context.domain)
        if not binding_config:
            logger.warning(f"No graph binding configured for domain: {context.domain}")
            # Use default binding
            binding_config = GraphBindingConfig()

        return await self.graph_binder.bind_ontology_to_graph(
            ontology, binding_config, context
        )

    def configure_datapoint_mapping(
        self,
        domain: str,
        mappings: List[DataPointMapping]
    ) -> None:
        """Configure DataPoint mappings for a domain."""
        self.domain_datapoint_mappings[domain] = mappings
        logger.info(f"Configured {len(mappings)} DataPoint mappings for domain: {domain}")

    def configure_graph_binding(
        self,
        domain: str,
        binding_config: GraphBindingConfig
    ) -> None:
        """Configure graph binding for a domain."""
        self.domain_graph_bindings[domain] = binding_config
        logger.info(f"Configured graph binding for domain: {domain}")

    def register_custom_resolver(
        self,
        resolver_name: str,
        resolver_func: Callable
    ) -> None:
        """Register a custom DataPoint resolver."""
        self.custom_resolvers[resolver_name] = resolver_func
        self.datapoint_resolver.register_custom_resolver(resolver_name, resolver_func)
        logger.info(f"Registered custom resolver: {resolver_name}")

    def register_binding_strategy(
        self,
        strategy_name: str,
        strategy_func: Callable
    ) -> None:
        """Register a custom graph binding strategy."""
        self.custom_binding_strategies[strategy_name] = strategy_func
        self.graph_binder.register_binding_strategy(strategy_name, strategy_func)
        logger.info(f"Registered custom binding strategy: {strategy_name}")

    def _prioritize_ontologies(self, ontologies: List[OntologyGraph]) -> List[OntologyGraph]:
        """Prioritize ontologies by scope (dataset > user > pipeline > domain > global)."""
        priority_order = [
            OntologyScope.DATASET,
            OntologyScope.USER,
            OntologyScope.PIPELINE,
            OntologyScope.DOMAIN,
            OntologyScope.GLOBAL,
        ]

        seen_ids = set()
        prioritized = []

        for scope in priority_order:
            for ontology in ontologies:
                if ontology.scope == scope and ontology.id not in seen_ids:
                    prioritized.append(ontology)
                    seen_ids.add(ontology.id)

        return prioritized

    def _get_adapter_for_ontology(self, ontology: OntologyGraph) -> str:
        """Get the appropriate adapter name for an ontology."""
        # This could be made more sophisticated with adapter selection logic
        format_to_adapter = {
            "rdf_xml": "rdf_adapter",
            "owl": "rdf_adapter",
            "json": "json_adapter",
            "llm_generated": "llm_adapter",
        }
        return format_to_adapter.get(ontology.format.value, "default_adapter")

    def _find_mapping_for_node(
        self,
        node: OntologyNode,
        mappings: List[DataPointMapping]
    ) -> Optional[DataPointMapping]:
        """Find the appropriate DataPoint mapping for a node."""
        for mapping in mappings:
            if mapping.ontology_node_type == node.type:
                return mapping
        return None


async def create_ontology_manager(
    registry: IOntologyRegistry,
    providers: Optional[Dict[str, IOntologyProvider]] = None,
    adapters: Optional[Dict[str, IOntologyAdapter]] = None,
    datapoint_resolver: Optional[IDataPointResolver] = None,
    graph_binder: Optional[IGraphBinder] = None,
) -> OntologyManager:
    """Factory function to create an OntologyManager with default implementations."""

    # Import default implementations
    from cognee.modules.ontology.providers import RDFOntologyProvider, JSONOntologyProvider
    from cognee.modules.ontology.adapters import DefaultOntologyAdapter
    from cognee.modules.ontology.resolvers import DefaultDataPointResolver
    from cognee.modules.ontology.binders import DefaultGraphBinder

    if providers is None:
        providers = {
            "rdf_provider": RDFOntologyProvider(),
            "json_provider": JSONOntologyProvider(),
        }

    if adapters is None:
        adapters = {
            "default_adapter": DefaultOntologyAdapter(),
            "rdf_adapter": DefaultOntologyAdapter(),
            "json_adapter": DefaultOntologyAdapter(),
        }

    if datapoint_resolver is None:
        datapoint_resolver = DefaultDataPointResolver()

    if graph_binder is None:
        graph_binder = DefaultGraphBinder()

    return OntologyManager(
        registry=registry,
        providers=providers,
        adapters=adapters,
        datapoint_resolver=datapoint_resolver,
        graph_binder=graph_binder,
    )
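Reviewer note: a minimal usage sketch of the factory above, assuming the default implementations it imports (`OntologyRegistry`, `DefaultOntologyAdapter`, `DefaultDataPointResolver`, `DefaultGraphBinder`) land in this branch as laid out, and that `OntologyContext` accepts the keyword fields used throughout the interfaces file.

```python
import asyncio

from cognee.modules.ontology.manager import create_ontology_manager
from cognee.modules.ontology.registry import OntologyRegistry
from cognee.modules.ontology.interfaces import OntologyContext


async def main():
    # Only the registry is required; providers, adapters, resolver,
    # and binder all fall back to the defaults.
    manager = await create_ontology_manager(registry=OntologyRegistry())

    # Scope resolution order: dataset > user > pipeline > domain > global.
    context = OntologyContext(domain="medical", pipeline_name="cognify")
    ontologies = await manager.get_applicable_ontologies(context)
    print([ont.id for ont in ontologies])


asyncio.run(main())
```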
27 cognee/modules/ontology/methods/__init__.py Normal file
@@ -0,0 +1,27 @@
# Ontology CRUD operations following Cognee methods pattern

# Create operations
from .create_ontology import create_ontology
from .create_ontology_from_file import create_ontology_from_file

# Read operations
from .get_ontology import get_ontology
from .get_ontologies import get_ontologies
from .get_ontology_by_domain import get_ontology_by_domain
from .load_ontology import load_ontology

# Update operations
from .update_ontology import update_ontology
from .register_ontology import register_ontology

# Delete operations
from .delete_ontology import delete_ontology
from .unregister_ontology import unregister_ontology

# Search operations
from .search_ontologies import search_ontologies
from .find_matching_nodes import find_matching_nodes

# Utility operations
from .validate_ontology import validate_ontology
from .merge_ontologies import merge_ontologies
193 cognee/modules/ontology/methods/create_ontology.py Normal file
@@ -0,0 +1,193 @@
"""Create ontology method following Cognee patterns."""

from typing import Dict, Any, Optional
from uuid import uuid4
from datetime import datetime

from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import (
    OntologyGraph,
    OntologyNode,
    OntologyEdge,
    OntologyScope,
    OntologyFormat,
    OntologyContext,
)
from cognee.modules.ontology.config import get_ontology_config
from cognee.modules.users.models import User

logger = get_logger("ontology.create")


async def create_ontology(
    ontology_data: Dict[str, Any],
    user: User,
    scope: OntologyScope = OntologyScope.USER,
    context: Optional[OntologyContext] = None
) -> OntologyGraph:
    """
    Create a new ontology from provided data.

    Args:
        ontology_data: Dictionary containing the ontology structure
        user: User creating the ontology
        scope: Scope for the ontology (user, domain, global, etc.)
        context: Optional context for the ontology

    Returns:
        Created OntologyGraph instance

    Raises:
        ValueError: If ontology_data is invalid
        RuntimeError: If ontology creation fails
    """

    try:
        config = get_ontology_config()

        # Validate required fields
        if "nodes" not in ontology_data:
            raise ValueError("Ontology data must contain a 'nodes' field")

        # Extract basic information
        ontology_id = ontology_data.get("id", str(uuid4()))
        ontology_name = ontology_data.get("name", f"ontology_{ontology_id}")
        description = ontology_data.get("description", "")
        format_type = OntologyFormat(ontology_data.get("format", config.default_format))

        # Parse nodes
        nodes = []
        for node_data in ontology_data["nodes"]:
            if not isinstance(node_data, dict):
                logger.warning(f"Skipping invalid node data: {node_data}")
                continue

            try:
                node = OntologyNode(
                    id=node_data.get("id", str(uuid4())),
                    name=node_data.get("name", "unnamed_node"),
                    type=node_data.get("type", "entity"),
                    description=node_data.get("description", ""),
                    category=node_data.get("category", "general"),
                    properties=node_data.get("properties", {})
                )
                nodes.append(node)
            except Exception as e:
                logger.warning(f"Failed to parse node {node_data.get('id', 'unknown')}: {e}")
                continue

        # Parse edges
        edges = []
        for edge_data in ontology_data.get("edges", []):
            if not isinstance(edge_data, dict):
                logger.warning(f"Skipping invalid edge data: {edge_data}")
                continue

            try:
                edge = OntologyEdge(
                    id=edge_data.get("id", str(uuid4())),
                    source_id=edge_data["source"],
                    target_id=edge_data["target"],
                    relationship_type=edge_data.get("relationship", "related_to"),
                    properties=edge_data.get("properties", {}),
                    weight=edge_data.get("weight")
                )
                edges.append(edge)
            except KeyError as e:
                logger.warning(f"Edge missing required field {e}: {edge_data}")
                continue
            except Exception as e:
                logger.warning(f"Failed to parse edge: {e}")
                continue

        # Create metadata
        metadata = ontology_data.get("metadata", {})
        metadata.update({
            "created_by": str(user.id),
            "created_at": datetime.now().isoformat(),
            "scope": scope.value,
            "format": format_type.value,
        })

        if context:
            metadata.update({
                "domain": context.domain,
                "pipeline_name": context.pipeline_name,
                "dataset_id": context.dataset_id,
            })

        # Create ontology
        ontology = OntologyGraph(
            id=ontology_id,
            name=ontology_name,
            description=description,
            format=format_type,
            scope=scope,
            nodes=nodes,
            edges=edges,
            metadata=metadata
        )

        logger.info(
            f"Created ontology '{ontology_name}' with {len(nodes)} nodes "
            f"and {len(edges)} edges for user {user.id}"
        )

        return ontology

    except ValueError:
        # Re-raise validation errors
        raise
    except Exception as e:
        error_msg = f"Failed to create ontology: {e}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e


async def create_empty_ontology(
    name: str,
    user: User,
    scope: OntologyScope = OntologyScope.USER,
    domain: Optional[str] = None,
    description: str = ""
) -> OntologyGraph:
    """
    Create an empty ontology with basic structure.

    Args:
        name: Name for the ontology
        user: User creating the ontology
        scope: Scope for the ontology
        domain: Optional domain for the ontology
        description: Optional description

    Returns:
        Empty OntologyGraph instance
    """

    config = get_ontology_config()
    ontology_id = str(uuid4())

    metadata = {
        "created_by": str(user.id),
        "created_at": datetime.now().isoformat(),
        "scope": scope.value,
        "format": OntologyFormat(config.default_format).value,
    }

    if domain:
        metadata["domain"] = domain

    ontology = OntologyGraph(
        id=ontology_id,
        name=name,
        description=description,
        format=OntologyFormat(config.default_format),
        scope=scope,
        nodes=[],
        edges=[],
        metadata=metadata
    )

    logger.info(f"Created empty ontology '{name}' for user {user.id}")
    return ontology
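Reviewer note: a small sketch of what `create_ontology` expects as input, under the module layout proposed above. `get_default_user` is assumed to be the existing helper in `cognee.modules.users.methods`; the node/edge dict keys match the parsing code in this file.

```python
import asyncio

from cognee.modules.ontology.methods.create_ontology import create_ontology
from cognee.modules.users.methods import get_default_user  # assumed existing helper


async def main():
    user = await get_default_user()

    # "nodes" is required; "edges" is optional and references nodes by id.
    ontology = await create_ontology(
        {
            "name": "mini_medical",
            "nodes": [
                {"id": "flu", "name": "Influenza", "type": "Disease"},
                {"id": "fever", "name": "Fever", "type": "Symptom"},
            ],
            "edges": [
                {"source": "flu", "target": "fever", "relationship": "causes"},
            ],
        },
        user=user,
    )
    print(ontology.id, len(ontology.nodes), len(ontology.edges))


asyncio.run(main())
```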
56 cognee/modules/ontology/methods/delete_ontology.py Normal file
@@ -0,0 +1,56 @@
"""Delete ontology method following Cognee patterns."""

from typing import Optional

from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import OntologyContext
from cognee.modules.ontology.registry import OntologyRegistry
from cognee.modules.users.models import User

logger = get_logger("ontology.delete")


async def delete_ontology(
    ontology_id: str,
    user: User,
    context: Optional[OntologyContext] = None
) -> bool:
    """
    Delete ontology following Cognee delete patterns.

    Args:
        ontology_id: ID of the ontology to delete
        user: User requesting deletion
        context: Optional context for access control

    Returns:
        True if deletion succeeded, False otherwise
    """

    try:
        # This would use dependency injection in a real implementation
        registry = OntologyRegistry()

        # Get the ontology first to check existence and permissions
        ontology = await registry.get_ontology(ontology_id, context)

        if ontology is None:
            logger.warning(f"Ontology {ontology_id} not found for deletion")
            return False

        # TODO: Add access control check
        # For now, assume the user has access

        # Perform the deletion
        success = await registry.unregister_ontology(ontology_id, context)

        if success:
            logger.info(f"Deleted ontology {ontology_id} for user {user.id}")
        else:
            logger.error(f"Failed to delete ontology {ontology_id}")

        return success

    except Exception as e:
        logger.error(f"Error deleting ontology {ontology_id}: {e}")
        return False
48 cognee/modules/ontology/methods/get_ontology.py Normal file
@@ -0,0 +1,48 @@
"""Get ontology method following Cognee patterns."""

from typing import Optional

from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import OntologyGraph, OntologyContext
from cognee.modules.ontology.registry import OntologyRegistry
from cognee.modules.users.models import User

logger = get_logger("ontology.get")


async def get_ontology(
    ontology_id: str,
    user: User,
    context: Optional[OntologyContext] = None
) -> Optional[OntologyGraph]:
    """
    Get ontology by ID following Cognee get patterns.

    Args:
        ontology_id: ID of the ontology to retrieve
        user: User requesting the ontology
        context: Optional context for access control

    Returns:
        OntologyGraph if found and accessible, None otherwise
    """

    try:
        # This would use dependency injection in a real implementation
        registry = OntologyRegistry()

        ontology = await registry.get_ontology(ontology_id, context)

        if ontology is None:
            logger.info(f"Ontology {ontology_id} not found")
            return None

        # TODO: Add access control check based on user and ontology ownership
        # For now, assume the user has access

        logger.info(f"Retrieved ontology {ontology_id} for user {user.id}")
        return ontology

    except Exception as e:
        logger.error(f"Failed to get ontology {ontology_id}: {e}")
        return None
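Reviewer note: the read/delete methods compose as below; a sketch under the same assumptions as the previous example (`get_default_user` from `cognee.modules.users.methods`), with a placeholder ontology id.

```python
import asyncio

from cognee.modules.ontology.methods.get_ontology import get_ontology
from cognee.modules.ontology.methods.delete_ontology import delete_ontology
from cognee.modules.users.methods import get_default_user  # assumed existing helper


async def main():
    user = await get_default_user()

    ontology = await get_ontology("some-ontology-id", user)  # placeholder id
    if ontology is not None:
        deleted = await delete_ontology(ontology.id, user)
        print("deleted:", deleted)


asyncio.run(main())
```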
83 cognee/modules/ontology/methods/load_ontology.py Normal file
@@ -0,0 +1,83 @@
"""Load ontology method following Cognee patterns."""

from typing import Union, Dict, Any, Optional

from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import OntologyGraph, OntologyContext
from cognee.modules.ontology.providers import JSONOntologyProvider, RDFOntologyProvider, CSVOntologyProvider
from cognee.modules.ontology.config import get_ontology_config
from cognee.modules.users.models import User

logger = get_logger("ontology.load")


async def load_ontology(
    source: Union[str, Dict[str, Any]],
    user: User,
    context: Optional[OntologyContext] = None
) -> Optional[OntologyGraph]:
    """
    Load ontology from various sources following Cognee patterns.

    Args:
        source: File path, URL, or data dictionary
        user: User loading the ontology
        context: Optional context for the ontology

    Returns:
        Loaded OntologyGraph, or None if loading failed
    """

    try:
        config = get_ontology_config()

        # Determine the provider based on the source
        provider = None

        if isinstance(source, str):
            # File path or URL
            if source.endswith(('.owl', '.rdf', '.xml')) and config.rdf_provider_enabled:
                provider = RDFOntologyProvider()
                if not provider.available:
                    logger.warning("RDF provider not available, falling back to JSON")
                    provider = JSONOntologyProvider()
            elif source.endswith('.json') and config.json_provider_enabled:
                provider = JSONOntologyProvider()
            elif source.endswith('.csv') and config.csv_provider_enabled:
                provider = CSVOntologyProvider()
            else:
                # Default to the JSON provider
                provider = JSONOntologyProvider()
        else:
            # Dictionary data - use the JSON provider
            provider = JSONOntologyProvider()

        if provider is None:
            logger.error(f"No suitable provider found for source: {source}")
            return None

        # Load the ontology
        ontology = await provider.load_ontology(source, context)

        # Validate the ontology
        if not await provider.validate_ontology(ontology):
            logger.error(f"Ontology validation failed for source: {source}")
            return None

        # Add metadata about loading
        ontology.metadata.update({
            "loaded_by": str(user.id),
            "source": str(source),
            "provider": provider.__class__.__name__,
        })

        logger.info(
            f"Loaded ontology '{ontology.name}' with {len(ontology.nodes)} nodes "
            f"from source: {source}"
        )

        return ontology

    except Exception as e:
        logger.error(f"Failed to load ontology from {source}: {e}")
        return None
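Note: the original draft set `provider = None` on the RDF fallback path, which contradicted its own log message and made the function return `None`; the version above actually falls back to the JSON provider. Usage is extension-driven, as in this sketch (placeholder path, same assumed `get_default_user` helper):

```python
import asyncio

from cognee.modules.ontology.methods.load_ontology import load_ontology
from cognee.modules.users.methods import get_default_user  # assumed existing helper


async def main():
    user = await get_default_user()

    # Extension-based dispatch: .owl/.rdf/.xml -> RDF, .json -> JSON, .csv -> CSV.
    ontology = await load_ontology("ontologies/medical.json", user)  # placeholder path
    if ontology:
        print(ontology.metadata["provider"])  # e.g. "JSONOntologyProvider"


asyncio.run(main())
```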
73 cognee/modules/ontology/methods/register_ontology.py Normal file
@@ -0,0 +1,73 @@
"""Register ontology method following Cognee patterns."""

from typing import Optional

from cognee.shared.logging_utils import get_logger
from cognee.modules.ontology.interfaces import OntologyGraph, OntologyScope, OntologyContext
from cognee.modules.ontology.registry import OntologyRegistry
from cognee.modules.users.models import User

logger = get_logger("ontology.register")


async def register_ontology(
    ontology: OntologyGraph,
    user: User,
    scope: OntologyScope = OntologyScope.USER,
    context: Optional[OntologyContext] = None
) -> str:
    """
    Register ontology in the registry following Cognee patterns.

    Args:
        ontology: OntologyGraph to register
        user: User registering the ontology
        scope: Scope for the ontology
        context: Optional context for registration

    Returns:
        Ontology ID if registration succeeded

    Raises:
        ValueError: If the ontology is invalid
        RuntimeError: If registration fails
    """

    try:
        # Validate the ontology
        if not ontology.nodes:
            raise ValueError("Cannot register an empty ontology")

        # Update metadata
        ontology.metadata.update({
            "registered_by": str(user.id),
            "scope": scope.value,
        })

        if context:
            ontology.metadata.update({
                "domain": context.domain,
                "pipeline_name": context.pipeline_name,
                "dataset_id": context.dataset_id,
            })

        # This would use dependency injection in a real implementation
        registry = OntologyRegistry()

        # Register in the registry
        ontology_id = await registry.register_ontology(ontology, scope, context)

        logger.info(
            f"Registered ontology '{ontology.name}' (ID: {ontology_id}) "
            f"with scope {scope.value} for user {user.id}"
        )

        return ontology_id

    except ValueError:
        # Re-raise validation errors
        raise
    except Exception as e:
        error_msg = f"Failed to register ontology '{ontology.name}': {e}"
        logger.error(error_msg)
        raise RuntimeError(error_msg) from e
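Reviewer note: create and register are intended to chain; a sketch under the same assumptions as the earlier examples.

```python
import asyncio

from cognee.modules.ontology.methods.create_ontology import create_ontology
from cognee.modules.ontology.methods.register_ontology import register_ontology
from cognee.modules.ontology.interfaces import OntologyScope
from cognee.modules.users.methods import get_default_user  # assumed existing helper


async def main():
    user = await get_default_user()

    ontology = await create_ontology(
        {"name": "demo", "nodes": [{"id": "a", "name": "A"}]},
        user=user,
    )

    # Registration stamps registered_by/scope metadata and returns the id.
    ontology_id = await register_ontology(ontology, user, scope=OntologyScope.DOMAIN)
    print(ontology_id)


asyncio.run(main())
```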
496 cognee/modules/ontology/pipeline_integration.py Normal file
@@ -0,0 +1,496 @@
"""Pipeline integration for the ontology system."""

from typing import Any, Dict, List, Optional

from cognee.shared.logging_utils import get_logger
from cognee.modules.pipelines.tasks.task import Task
from cognee.modules.ontology.interfaces import (
    IOntologyManager,
    OntologyContext,
    DataPointMapping,
    GraphBindingConfig,
)

logger = get_logger("OntologyPipelineIntegration")


class OntologyInjector:
    """Handles injection of ontology context into pipeline tasks."""

    def __init__(self, ontology_manager: IOntologyManager):
        self.ontology_manager = ontology_manager
        self.task_ontology_configs: Dict[str, Dict[str, Any]] = {}

    def configure_task_ontology(
        self,
        task_name: str,
        ontology_config: Dict[str, Any]
    ) -> None:
        """Configure ontology settings for a specific task."""
        self.task_ontology_configs[task_name] = ontology_config
        logger.info(f"Configured ontology for task: {task_name}")

    async def inject_into_task(
        self,
        task: Task,
        context: OntologyContext
    ) -> Task:
        """Inject ontology context into a task."""

        task_name = self._get_task_name(task)

        # Check if task has ontology configuration
        if task_name not in self.task_ontology_configs:
            # No specific configuration, use default behavior
            return await self._apply_default_injection(task, context)

        # Apply configured ontology injection
        config = self.task_ontology_configs[task_name]
        return await self._apply_configured_injection(task, context, config)

    async def enhance_task_params(
        self,
        task_params: Dict[str, Any],
        context: OntologyContext,
        task_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Enhance task parameters with ontological context."""

        enhanced_params = await self.ontology_manager.inject_ontology_into_task(
            task_name or "unknown_task",
            task_params,
            context
        )

        return enhanced_params

    def _get_task_name(self, task: Task) -> str:
        """Extract task name from a Task object."""
        if hasattr(task.executable, '__name__'):
            return task.executable.__name__
        elif hasattr(task.executable, '__class__'):
            return task.executable.__class__.__name__
        else:
            return str(task.executable)

    async def _apply_default_injection(
        self,
        task: Task,
        context: OntologyContext
    ) -> Task:
        """Apply default ontology injection to a task."""

        # Get applicable ontologies
        ontologies = await self.ontology_manager.get_applicable_ontologies(context)

        if not ontologies:
            logger.debug("No applicable ontologies found for task")
            return task

        # Enhance task parameters
        enhanced_params = task.default_params.copy()
        enhanced_params["kwargs"]["ontology_context"] = context
        enhanced_params["kwargs"]["available_ontologies"] = [ont.id for ont in ontologies]

        # Create new task with enhanced parameters
        enhanced_task = Task(
            task.executable,
            *enhanced_params["args"],
            task_config=task.task_config,
            **enhanced_params["kwargs"]
        )

        return enhanced_task

    async def _apply_configured_injection(
        self,
        task: Task,
        context: OntologyContext,
        config: Dict[str, Any]
    ) -> Task:
        """Apply configured ontology injection to a task."""

        enhanced_params = task.default_params.copy()

        # Apply ontology-specific enhancements based on config
        if config.get("enhance_with_entities", False):
            enhanced_params = await self._inject_entity_enhancement(
                enhanced_params, context, config
            )

        if config.get("inject_datapoint_mappings", False):
            enhanced_params = await self._inject_datapoint_mappings(
                enhanced_params, context, config
            )

        if config.get("inject_graph_binding", False):
            enhanced_params = await self._inject_graph_binding(
                enhanced_params, context, config
            )

        if config.get("enable_ontology_validation", False):
            enhanced_params = await self._inject_validation_config(
                enhanced_params, context, config
            )

        # Create new task with enhanced parameters
        enhanced_task = Task(
            task.executable,
            *enhanced_params["args"],
            task_config=task.task_config,
            **enhanced_params["kwargs"]
        )

        return enhanced_task

    async def _inject_entity_enhancement(
        self,
        params: Dict[str, Any],
        context: OntologyContext,
        config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Inject entity enhancement capabilities."""

        params["kwargs"]["ontology_manager"] = self.ontology_manager
        params["kwargs"]["ontology_context"] = context
        params["kwargs"]["entity_extraction_enabled"] = True

        # Add specific entity types to extract if configured
        if "target_entity_types" in config:
            params["kwargs"]["target_entity_types"] = config["target_entity_types"]

        return params

    async def _inject_datapoint_mappings(
        self,
        params: Dict[str, Any],
        context: OntologyContext,
        config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Inject DataPoint mapping configurations."""

        # Get domain-specific mappings
        if context.domain:
            domain_mappings = getattr(
                self.ontology_manager, 'domain_datapoint_mappings', {}
            ).get(context.domain, [])

            if domain_mappings:
                params["kwargs"]["datapoint_mappings"] = domain_mappings
                params["kwargs"]["datapoint_resolver"] = self.ontology_manager.datapoint_resolver

        # Add custom mappings from config
        if "custom_mappings" in config:
            params["kwargs"]["custom_datapoint_mappings"] = config["custom_mappings"]

        return params

    async def _inject_graph_binding(
        self,
        params: Dict[str, Any],
        context: OntologyContext,
        config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Inject graph binding configurations."""

        # Get domain-specific binding config
        if context.domain:
            domain_binding = getattr(
                self.ontology_manager, 'domain_graph_bindings', {}
            ).get(context.domain)

            if domain_binding:
                params["kwargs"]["graph_binding_config"] = domain_binding
                params["kwargs"]["graph_binder"] = self.ontology_manager.graph_binder

        # Add custom binding from config
        if "custom_binding" in config:
            params["kwargs"]["custom_graph_binding"] = config["custom_binding"]

        return params

    async def _inject_validation_config(
        self,
        params: Dict[str, Any],
        context: OntologyContext,
        config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Inject ontology validation configurations."""

        params["kwargs"]["ontology_validation_enabled"] = True
        params["kwargs"]["validation_threshold"] = config.get("validation_threshold", 0.8)
        params["kwargs"]["strict_validation"] = config.get("strict_validation", False)

        return params


class OntologyAwareTaskWrapper:
    """Wrapper to make existing tasks ontology-aware."""

    def __init__(
        self,
        original_task: Task,
        ontology_manager: IOntologyManager,
        context: OntologyContext
    ):
        self.original_task = original_task
        self.ontology_manager = ontology_manager
        self.context = context

    async def execute_with_ontology(self, *args, **kwargs):
        """Execute the task with ontology enhancements."""

        # Enhance content if provided
        if "content" in kwargs:
            enhanced_content = await self.ontology_manager.enhance_with_ontology(
                kwargs["content"], self.context
            )
            kwargs["enhanced_content"] = enhanced_content

        # Add ontology context
        kwargs["ontology_context"] = self.context
        kwargs["ontology_manager"] = self.ontology_manager

        # Execute original task
        return await self.original_task.run(*args, **kwargs)


class PipelineOntologyConfigurator:
    """Configures ontology integration for entire pipelines."""

    def __init__(self, ontology_manager: IOntologyManager):
        self.ontology_manager = ontology_manager
        self.pipeline_configs: Dict[str, Dict[str, Any]] = {}

    def configure_pipeline(
        self,
        pipeline_name: str,
        domain: str,
        datapoint_mappings: List[DataPointMapping],
        graph_binding: GraphBindingConfig,
        task_specific_configs: Optional[Dict[str, Dict[str, Any]]] = None
    ) -> None:
        """Configure ontology for an entire pipeline."""

        # Configure domain mappings
        self.ontology_manager.configure_datapoint_mapping(domain, datapoint_mappings)
        self.ontology_manager.configure_graph_binding(domain, graph_binding)

        # Store pipeline configuration
        self.pipeline_configs[pipeline_name] = {
            "domain": domain,
            "datapoint_mappings": datapoint_mappings,
            "graph_binding": graph_binding,
            "task_configs": task_specific_configs or {},
        }

        logger.info(f"Configured ontology for pipeline: {pipeline_name}")

    def get_pipeline_context(
        self,
        pipeline_name: str,
        user_id: Optional[str] = None,
        dataset_id: Optional[str] = None,
        custom_properties: Optional[Dict[str, Any]] = None
    ) -> OntologyContext:
        """Get the ontology context for a pipeline."""

        config = self.pipeline_configs.get(pipeline_name, {})

        return OntologyContext(
            user_id=user_id,
            dataset_id=dataset_id,
            pipeline_name=pipeline_name,
            domain=config.get("domain"),
            custom_properties=custom_properties or {}
        )

    def create_ontology_injector(self, pipeline_name: str) -> OntologyInjector:
        """Create an ontology injector configured for a specific pipeline."""

        injector = OntologyInjector(self.ontology_manager)

        # Apply pipeline-specific task configurations
        if pipeline_name in self.pipeline_configs:
            task_configs = self.pipeline_configs[pipeline_name].get("task_configs", {})
            for task_name, config in task_configs.items():
                injector.configure_task_ontology(task_name, config)

        return injector


# Pre-configured pipeline setups for common domains
def create_medical_pipeline_config() -> Dict[str, Any]:
    """Create a pre-configured ontology setup for medical pipelines."""

    datapoint_mappings = [
        DataPointMapping(
            ontology_node_type="Disease",
            datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
            field_mappings={
                "name": "name",
                "description": "description",
                "icd_code": "medical_code",
                "severity": "severity_level",
            },
            validation_rules=["required:name", "type:severity_level:str"]
        ),
        DataPointMapping(
            ontology_node_type="Symptom",
            datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
            field_mappings={
                "name": "name",
                "description": "description",
                "frequency": "occurrence_rate",
            }
        ),
    ]

    graph_binding = GraphBindingConfig(
        node_type_mapping={
            "Disease": "medical_entity",
            "Symptom": "clinical_finding",
            "Treatment": "therapeutic_procedure",
        },
        edge_type_mapping={
            "treats": "therapeutic_relationship",
            "causes": "causality",
            "associated_with": "clinical_association",
        }
    )

    task_configs = {
        "extract_graph_from_data": {
            "enhance_with_entities": True,
            "inject_datapoint_mappings": True,
            "inject_graph_binding": True,
            "target_entity_types": ["Disease", "Symptom", "Treatment"],
        },
        "summarize_text": {
            "enhance_with_entities": True,
            "enable_ontology_validation": True,
            "validation_threshold": 0.85,
        }
    }

    return {
        "domain": "medical",
        "datapoint_mappings": datapoint_mappings,
        "graph_binding": graph_binding,
        "task_configs": task_configs,
    }


def create_legal_pipeline_config() -> Dict[str, Any]:
    """Create a pre-configured ontology setup for legal pipelines."""

    datapoint_mappings = [
        DataPointMapping(
            ontology_node_type="Law",
            datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
            field_mappings={
                "name": "name",
                "description": "description",
                "jurisdiction": "legal_authority",
                "citation": "legal_citation",
            }
        ),
        DataPointMapping(
            ontology_node_type="Case",
            datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
            field_mappings={
                "name": "name",
                "description": "description",
                "court": "court_level",
                "date": "decision_date",
            }
        ),
    ]

    graph_binding = GraphBindingConfig(
        node_type_mapping={
            "Law": "legal_statute",
            "Case": "legal_precedent",
            "Court": "judicial_body",
        },
        edge_type_mapping={
            "cites": "legal_citation",
            "overrules": "legal_override",
            "applies": "legal_application",
        }
    )

    task_configs = {
        "extract_graph_from_data": {
            "enhance_with_entities": True,
            "inject_datapoint_mappings": True,
            "inject_graph_binding": True,
            "target_entity_types": ["Law", "Case", "Court", "Legal_Concept"],
        }
    }

    return {
        "domain": "legal",
        "datapoint_mappings": datapoint_mappings,
        "graph_binding": graph_binding,
        "task_configs": task_configs,
    }


def create_code_pipeline_config() -> Dict[str, Any]:
    """Create a pre-configured ontology setup for code analysis pipelines."""

    datapoint_mappings = [
        DataPointMapping(
            ontology_node_type="Function",
            datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
            field_mappings={
                "name": "name",
                "description": "description",
                "parameters": "function_parameters",
                "return_type": "return_type",
            }
        ),
        DataPointMapping(
            ontology_node_type="Class",
            datapoint_class="cognee.infrastructure.engine.models.DataPoint.DataPoint",
            field_mappings={
                "name": "name",
                "description": "description",
                "methods": "class_methods",
                "inheritance": "parent_classes",
            }
        ),
    ]

    graph_binding = GraphBindingConfig(
        node_type_mapping={
            "Function": "code_function",
            "Class": "code_class",
            "Module": "code_module",
            "Variable": "code_variable",
        },
        edge_type_mapping={
            "calls": "function_call",
            "inherits": "inheritance",
            "imports": "module_import",
            "defines": "definition",
        }
    )

    task_configs = {
        "extract_graph_from_code": {
            "enhance_with_entities": True,
            "inject_datapoint_mappings": True,
            "inject_graph_binding": True,
            "target_entity_types": ["Function", "Class", "Module", "Variable"],
        }
    }

    return {
        "domain": "code",
        "datapoint_mappings": datapoint_mappings,
        "graph_binding": graph_binding,
        "task_configs": task_configs,
    }
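Reviewer note: the pieces above compose as in this sketch, which wires the canned medical setup into a named pipeline. It assumes the factory and registry from `manager.py`/`registry.py` in this branch; the pipeline name and user id are placeholders.

```python
import asyncio

from cognee.modules.ontology.manager import create_ontology_manager
from cognee.modules.ontology.registry import OntologyRegistry
from cognee.modules.ontology.pipeline_integration import (
    PipelineOntologyConfigurator,
    create_medical_pipeline_config,
)


async def main():
    manager = await create_ontology_manager(registry=OntologyRegistry())
    configurator = PipelineOntologyConfigurator(manager)

    # Register the medical domain mappings and bindings for a pipeline.
    medical = create_medical_pipeline_config()
    configurator.configure_pipeline(
        pipeline_name="cognify",  # placeholder pipeline name
        domain=medical["domain"],
        datapoint_mappings=medical["datapoint_mappings"],
        graph_binding=medical["graph_binding"],
        task_specific_configs=medical["task_configs"],
    )

    # Build a context and an injector, then enhance a task's parameters.
    context = configurator.get_pipeline_context("cognify", user_id="user-123")
    injector = configurator.create_ontology_injector("cognify")
    params = await injector.enhance_task_params({}, context, task_name="extract_graph_from_data")
    print(params.get("ontology_context"))


asyncio.run(main())
```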
475
cognee/modules/ontology/providers.py
Normal file
475
cognee/modules/ontology/providers.py
Normal file
|
|
@ -0,0 +1,475 @@
|
|||
"""Ontology provider implementations."""
|
||||
|
||||
import json
|
||||
import csv
|
||||
from typing import Dict, Any, Union, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.ontology.interfaces import (
|
||||
IOntologyProvider,
|
||||
OntologyGraph,
|
||||
OntologyNode,
|
||||
OntologyEdge,
|
||||
OntologyFormat,
|
||||
OntologyScope,
|
||||
OntologyContext,
|
||||
)
|
||||
|
||||
logger = get_logger("OntologyProviders")
|
||||
|
||||
|
||||
class RDFOntologyProvider(IOntologyProvider):
|
||||
"""Provider for RDF/OWL ontologies."""
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
from rdflib import Graph, URIRef, RDF, RDFS, OWL
|
||||
self.Graph = Graph
|
||||
self.URIRef = URIRef
|
||||
self.RDF = RDF
|
||||
self.RDFS = RDFS
|
||||
self.OWL = OWL
|
||||
self.available = True
|
||||
except ImportError:
|
||||
logger.warning("rdflib not available, RDF support disabled")
|
||||
self.available = False
|
||||
|
||||
async def load_ontology(
|
||||
self,
|
||||
source: Union[str, Dict[str, Any]],
|
||||
context: Optional[OntologyContext] = None
|
||||
) -> OntologyGraph:
|
||||
"""Load ontology from RDF/OWL file."""
|
||||
|
||||
if not self.available:
|
||||
raise ImportError("rdflib is required for RDF ontology support")
|
||||
|
||||
if isinstance(source, dict):
|
||||
file_path = source.get("file_path")
|
||||
else:
|
||||
file_path = source
|
||||
|
||||
if not file_path or not Path(file_path).exists():
|
||||
raise FileNotFoundError(f"RDF file not found: {file_path}")
|
||||
|
||||
# Parse RDF graph
|
||||
rdf_graph = self.Graph()
|
||||
rdf_graph.parse(file_path)
|
||||
|
||||
# Convert to our ontology format
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
# Extract classes and individuals
|
||||
for cls in rdf_graph.subjects(self.RDF.type, self.OWL.Class):
|
||||
node = OntologyNode(
|
||||
id=self._uri_to_id(cls),
|
||||
name=self._extract_name(cls),
|
||||
type="class",
|
||||
category="owl_class",
|
||||
properties=self._extract_node_properties(cls, rdf_graph)
|
||||
)
|
||||
nodes.append(node)
|
||||
|
||||
# Extract individuals
|
||||
for individual in rdf_graph.subjects(self.RDF.type, None):
|
||||
if not any(rdf_graph.triples((individual, self.RDF.type, self.OWL.Class))):
|
||||
node = OntologyNode(
|
||||
id=self._uri_to_id(individual),
|
||||
name=self._extract_name(individual),
|
||||
type="individual",
|
||||
category="owl_individual",
|
||||
properties=self._extract_node_properties(individual, rdf_graph)
|
||||
)
|
||||
nodes.append(node)
|
||||
|
||||
# Extract relationships
|
||||
for s, p, o in rdf_graph:
|
||||
if p != self.RDF.type: # Skip type relationships
|
||||
edge = OntologyEdge(
|
||||
id=f"{self._uri_to_id(s)}_{self._uri_to_id(p)}_{self._uri_to_id(o)}",
|
||||
source_id=self._uri_to_id(s),
|
||||
target_id=self._uri_to_id(o),
|
||||
relationship_type=self._extract_name(p),
|
||||
properties={"predicate_uri": str(p)}
|
||||
)
|
||||
edges.append(edge)
|
||||
|
||||
ontology = OntologyGraph(
|
||||
id=f"rdf_{Path(file_path).stem}",
|
||||
name=Path(file_path).stem,
|
||||
description=f"RDF ontology loaded from {file_path}",
|
||||
format=OntologyFormat.RDF_XML,
|
||||
scope=OntologyScope.DOMAIN,
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
metadata={"source_file": file_path, "triple_count": len(rdf_graph)}
|
||||
)
|
||||
|
||||
logger.info(f"Loaded RDF ontology with {len(nodes)} nodes and {len(edges)} edges")
|
||||
return ontology
|
||||
|
||||
async def save_ontology(
|
||||
self,
|
||||
ontology: OntologyGraph,
|
||||
destination: str,
|
||||
context: Optional[OntologyContext] = None
|
||||
) -> bool:
|
||||
"""Save ontology to RDF/OWL file."""
|
||||
|
||||
if not self.available:
|
||||
raise ImportError("rdflib is required for RDF ontology support")
|
||||
|
||||
# Convert back to RDF
|
||||
rdf_graph = self.Graph()
|
||||
|
||||
# Add nodes
|
||||
for node in ontology.nodes:
|
||||
uri = self.URIRef(f"http://example.org/ontology#{node.id}")
|
||||
if node.type == "class":
|
||||
rdf_graph.add((uri, self.RDF.type, self.OWL.Class))
|
||||
else:
|
||||
# Add as individual of some class
|
||||
class_uri = self.URIRef(f"http://example.org/ontology#{node.type}")
|
||||
rdf_graph.add((uri, self.RDF.type, class_uri))
|
||||
|
||||
# Add edges
|
||||
for edge in ontology.edges:
|
||||
s_uri = self.URIRef(f"http://example.org/ontology#{edge.source_id}")
|
||||
p_uri = self.URIRef(f"http://example.org/ontology#{edge.relationship_type}")
|
||||
o_uri = self.URIRef(f"http://example.org/ontology#{edge.target_id}")
|
||||
rdf_graph.add((s_uri, p_uri, o_uri))
|
||||
|
||||
# Serialize to file
|
||||
try:
|
||||
rdf_graph.serialize(destination=destination, format='xml')
|
||||
logger.info(f"Saved RDF ontology to {destination}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save RDF ontology: {e}")
|
||||
return False
|
||||
|
||||
def supports_format(self, format: OntologyFormat) -> bool:
|
||||
"""Check if provider supports given format."""
|
||||
return self.available and format in [OntologyFormat.RDF_XML, OntologyFormat.OWL]
|
||||
|
||||
async def validate_ontology(self, ontology: OntologyGraph) -> bool:
|
||||
"""Validate RDF ontology structure."""
|
||||
# Basic validation - could be enhanced with OWL reasoning
|
||||
node_ids = {node.id for node in ontology.nodes}
|
||||
|
||||
for edge in ontology.edges:
|
||||
if edge.source_id not in node_ids or edge.target_id not in node_ids:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _uri_to_id(self, uri) -> str:
|
||||
"""Convert URI to simple ID."""
|
||||
uri_str = str(uri)
|
||||
if "#" in uri_str:
|
||||
return uri_str.split("#")[-1]
|
||||
return uri_str.rstrip("/").split("/")[-1]
|
||||
|
||||
def _extract_name(self, uri) -> str:
|
||||
"""Extract readable name from URI."""
|
||||
return self._uri_to_id(uri).replace("_", " ").title()
|
||||
|
||||
def _extract_node_properties(self, uri, graph) -> Dict[str, Any]:
|
||||
"""Extract additional properties for a node."""
|
||||
props = {}
|
||||
|
||||
# Get labels
|
||||
for label in graph.objects(uri, self.RDFS.label):
|
||||
props["label"] = str(label)
|
||||
|
||||
# Get comments
|
||||
for comment in graph.objects(uri, self.RDFS.comment):
|
||||
props["comment"] = str(comment)
|
||||
|
||||
return props
|
||||
|
||||
|
||||
class JSONOntologyProvider(IOntologyProvider):
    """Provider for JSON-based ontologies."""

    async def load_ontology(
        self,
        source: Union[str, Dict[str, Any]],
        context: Optional[OntologyContext] = None
    ) -> OntologyGraph:
        """Load ontology from JSON file or dict."""

        if isinstance(source, str):
            # Load from file
            with open(source, 'r') as f:
                data = json.load(f)
            ontology_id = f"json_{Path(source).stem}"
            source_file = source
        else:
            # Use provided dict
            data = source
            ontology_id = data.get("id", "json_ontology")
            source_file = None

        # Parse nodes
        nodes = []
        for node_data in data.get("nodes", []):
            node = OntologyNode(
                id=node_data["id"],
                name=node_data.get("name", node_data["id"]),
                type=node_data.get("type", "entity"),
                description=node_data.get("description", ""),
                category=node_data.get("category", "general"),
                properties=node_data.get("properties", {})
            )
            nodes.append(node)

        # Parse edges
        edges = []
        for edge_data in data.get("edges", []):
            edge = OntologyEdge(
                id=edge_data.get("id", f"{edge_data['source']}_{edge_data['target']}"),
                source_id=edge_data["source"],
                target_id=edge_data["target"],
                relationship_type=edge_data.get("relationship", "related_to"),
                properties=edge_data.get("properties", {}),
                weight=edge_data.get("weight")
            )
            edges.append(edge)

        ontology = OntologyGraph(
            id=ontology_id,
            name=data.get("name", ontology_id),
            description=data.get("description", "JSON-based ontology"),
            format=OntologyFormat.JSON,
            scope=OntologyScope.DOMAIN,
            nodes=nodes,
            edges=edges,
            metadata=data.get("metadata", {"source_file": source_file})
        )

        logger.info(f"Loaded JSON ontology with {len(nodes)} nodes and {len(edges)} edges")
        return ontology

    async def save_ontology(
        self,
        ontology: OntologyGraph,
        destination: str,
        context: Optional[OntologyContext] = None
    ) -> bool:
        """Save ontology to JSON file."""

        data = {
            "id": ontology.id,
            "name": ontology.name,
            "description": ontology.description,
            "format": ontology.format.value,
            "scope": ontology.scope.value,
            "nodes": [
                {
                    "id": node.id,
                    "name": node.name,
                    "type": node.type,
                    "description": node.description,
                    "category": node.category,
                    "properties": node.properties
                }
                for node in ontology.nodes
            ],
            "edges": [
                {
                    "id": edge.id,
                    "source": edge.source_id,
                    "target": edge.target_id,
                    "relationship": edge.relationship_type,
                    "properties": edge.properties,
                    "weight": edge.weight
                }
                for edge in ontology.edges
            ],
            "metadata": ontology.metadata
        }

        try:
            with open(destination, 'w') as f:
                json.dump(data, f, indent=2)
            logger.info(f"Saved JSON ontology to {destination}")
            return True
        except Exception as e:
            logger.error(f"Failed to save JSON ontology: {e}")
            return False

    def supports_format(self, format: OntologyFormat) -> bool:
        """Check if provider supports given format."""
        return format == OntologyFormat.JSON

    async def validate_ontology(self, ontology: OntologyGraph) -> bool:
        """Validate JSON ontology structure."""
        node_ids = {node.id for node in ontology.nodes}

        for edge in ontology.edges:
            if edge.source_id not in node_ids or edge.target_id not in node_ids:
                return False

        return True
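
A minimal usage sketch for the JSON provider (illustrative only; the import path for the provider classes is assumed from the `providers/` layout described earlier):

```python
import asyncio

# Assumed import path, based on the providers/ module layout.
from cognee.modules.ontology.providers import JSONOntologyProvider

# Dict-based ontology matching the schema load_ontology parses above.
ontology_dict = {
    "id": "demo_ontology",
    "name": "Demo Ontology",
    "nodes": [
        {"id": "aspirin", "name": "Aspirin", "type": "Drug"},
        {"id": "headache", "name": "Headache", "type": "Symptom"},
    ],
    "edges": [
        {"source": "aspirin", "target": "headache", "relationship": "treats"},
    ],
}

async def main():
    provider = JSONOntologyProvider()
    ontology = await provider.load_ontology(ontology_dict)
    # Validation checks that every edge endpoint is a known node ID.
    assert await provider.validate_ontology(ontology)
    await provider.save_ontology(ontology, "demo_ontology.json")

asyncio.run(main())
```
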
class CSVOntologyProvider(IOntologyProvider):
    """Provider for CSV-based ontologies."""

    async def load_ontology(
        self,
        source: Union[str, Dict[str, Any]],
        context: Optional[OntologyContext] = None
    ) -> OntologyGraph:
        """Load ontology from CSV files."""

        if isinstance(source, dict):
            nodes_file = source.get("nodes_file")
            edges_file = source.get("edges_file")
        else:
            # Assume single file or directory
            source_path = Path(source)
            if source_path.is_dir():
                nodes_file = source_path / "nodes.csv"
                edges_file = source_path / "edges.csv"
            else:
                nodes_file = source
                edges_file = None

        # Load nodes
        nodes = []
        if nodes_file and Path(nodes_file).exists():
            with open(nodes_file, 'r') as f:
                reader = csv.DictReader(f)
                for row in reader:
                    node = OntologyNode(
                        id=row["id"],
                        name=row.get("name", row["id"]),
                        type=row.get("type", "entity"),
                        description=row.get("description", ""),
                        category=row.get("category", "general"),
                        properties={k: v for k, v in row.items()
                                    if k not in ["id", "name", "type", "description", "category"]}
                    )
                    nodes.append(node)

        # Load edges
        edges = []
        if edges_file and Path(edges_file).exists():
            with open(edges_file, 'r') as f:
                reader = csv.DictReader(f)
                for row in reader:
                    edge = OntologyEdge(
                        id=row.get("id", f"{row['source']}_{row['target']}"),
                        source_id=row["source"],
                        target_id=row["target"],
                        relationship_type=row.get("relationship", "related_to"),
                        # Exclude "weight" here too, since it is parsed separately below
                        properties={k: v for k, v in row.items()
                                    if k not in ["id", "source", "target", "relationship", "weight"]},
                        weight=float(row["weight"]) if row.get("weight") else None
                    )
                    edges.append(edge)

        ontology = OntologyGraph(
            id=f"csv_{Path(nodes_file).stem}" if nodes_file else "csv_ontology",
            name="CSV Ontology",
            description="CSV-based ontology",
            format=OntologyFormat.CSV,
            scope=OntologyScope.DOMAIN,
            nodes=nodes,
            edges=edges,
            metadata={
                "nodes_file": str(nodes_file) if nodes_file else None,
                "edges_file": str(edges_file) if edges_file else None,
            }
        )

        logger.info(f"Loaded CSV ontology with {len(nodes)} nodes and {len(edges)} edges")
        return ontology

    async def save_ontology(
        self,
        ontology: OntologyGraph,
        destination: str,
        context: Optional[OntologyContext] = None
    ) -> bool:
        """Save ontology to CSV files."""

        dest_path = Path(destination)
        if dest_path.suffix == ".csv":
            # Single file - save nodes only (edges are dropped in this mode)
            nodes_file = destination
            edges_file = None
        else:
            # Directory - save separate node and edge files
            dest_path.mkdir(parents=True, exist_ok=True)
            nodes_file = dest_path / "nodes.csv"
            edges_file = dest_path / "edges.csv"

        try:
            # Save nodes
            if ontology.nodes:
                all_properties = set()
                for node in ontology.nodes:
                    all_properties.update(node.properties.keys())

                # Sort property columns for a deterministic header order
                fieldnames = ["id", "name", "type", "description", "category"] + sorted(all_properties)

                with open(nodes_file, 'w', newline='') as f:
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writeheader()

                    for node in ontology.nodes:
                        row = {
                            "id": node.id,
                            "name": node.name,
                            "type": node.type,
                            "description": node.description,
                            "category": node.category,
                            **node.properties
                        }
                        writer.writerow(row)

            # Save edges
            if edges_file and ontology.edges:
                all_properties = set()
                for edge in ontology.edges:
                    all_properties.update(edge.properties.keys())

                fieldnames = ["id", "source", "target", "relationship", "weight"] + sorted(all_properties)

                with open(edges_file, 'w', newline='') as f:
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writeheader()

                    for edge in ontology.edges:
                        row = {
                            "id": edge.id,
                            "source": edge.source_id,
                            "target": edge.target_id,
                            "relationship": edge.relationship_type,
                            "weight": edge.weight,
                            **edge.properties
                        }
                        writer.writerow(row)

            logger.info(f"Saved CSV ontology to {destination}")
            return True

        except Exception as e:
            logger.error(f"Failed to save CSV ontology: {e}")
            return False

    def supports_format(self, format: OntologyFormat) -> bool:
        """Check if provider supports given format."""
        return format == OntologyFormat.CSV

    async def validate_ontology(self, ontology: OntologyGraph) -> bool:
        """Validate CSV ontology structure."""
        node_ids = {node.id for node in ontology.nodes}

        for edge in ontology.edges:
            if edge.source_id not in node_ids or edge.target_id not in node_ids:
                return False

        return True
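
A usage sketch for the CSV provider, assuming a directory containing `nodes.csv` and `edges.csv` with the column headers shown above (import path assumed, as before):

```python
import asyncio

# Assumed import path for the provider classes.
from cognee.modules.ontology.providers import CSVOntologyProvider

async def main():
    provider = CSVOntologyProvider()

    # Expects ./my_ontology/nodes.csv and ./my_ontology/edges.csv;
    # any extra CSV columns become entries in node/edge `properties`.
    ontology = await provider.load_ontology("./my_ontology")

    # Round-trip: writes nodes.csv and edges.csv into the target directory.
    await provider.save_ontology(ontology, "./my_ontology_copy")

asyncio.run(main())
```
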
241
cognee/modules/ontology/registry.py
Normal file

@@ -0,0 +1,241 @@
"""Ontology registry implementation."""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.ontology.interfaces import (
|
||||
IOntologyRegistry,
|
||||
OntologyGraph,
|
||||
OntologyScope,
|
||||
OntologyContext,
|
||||
)
|
||||
|
||||
logger = get_logger("OntologyRegistry")
|
||||
|
||||
|
||||
class OntologyRegistry(IOntologyRegistry):
|
||||
"""In-memory implementation of ontology registry."""
|
||||
|
||||
def __init__(self):
|
||||
self.ontologies: Dict[str, OntologyGraph] = {}
|
||||
self.scope_index: Dict[OntologyScope, List[str]] = {
|
||||
scope: [] for scope in OntologyScope
|
||||
}
|
||||
self.domain_index: Dict[str, List[str]] = {}
|
||||
self.user_index: Dict[str, List[str]] = {}
|
||||
self.dataset_index: Dict[str, List[str]] = {}
|
||||
self.pipeline_index: Dict[str, List[str]] = {}
|
||||
|
||||
async def register_ontology(
|
||||
self,
|
||||
ontology: OntologyGraph,
|
||||
scope: OntologyScope,
|
||||
context: Optional[OntologyContext] = None
|
||||
) -> str:
|
||||
"""Register an ontology."""
|
||||
|
||||
ontology_id = ontology.id or str(uuid4())
|
||||
ontology.id = ontology_id
|
||||
ontology.scope = scope
|
||||
|
||||
self.ontologies[ontology_id] = ontology
|
||||
|
||||
# Update scope index
|
||||
if ontology_id not in self.scope_index[scope]:
|
||||
self.scope_index[scope].append(ontology_id)
|
||||
|
||||
# Update domain index if applicable
|
||||
if context and context.domain:
|
||||
if context.domain not in self.domain_index:
|
||||
self.domain_index[context.domain] = []
|
||||
if ontology_id not in self.domain_index[context.domain]:
|
||||
self.domain_index[context.domain].append(ontology_id)
|
||||
|
||||
# Update user index if applicable
|
||||
if context and context.user_id:
|
||||
if context.user_id not in self.user_index:
|
||||
self.user_index[context.user_id] = []
|
||||
if ontology_id not in self.user_index[context.user_id]:
|
||||
self.user_index[context.user_id].append(ontology_id)
|
||||
|
||||
# Update dataset index if applicable
|
||||
if context and context.dataset_id:
|
||||
if context.dataset_id not in self.dataset_index:
|
||||
self.dataset_index[context.dataset_id] = []
|
||||
if ontology_id not in self.dataset_index[context.dataset_id]:
|
||||
self.dataset_index[context.dataset_id].append(ontology_id)
|
||||
|
||||
# Update pipeline index if applicable
|
||||
if context and context.pipeline_name:
|
||||
if context.pipeline_name not in self.pipeline_index:
|
||||
self.pipeline_index[context.pipeline_name] = []
|
||||
if ontology_id not in self.pipeline_index[context.pipeline_name]:
|
||||
self.pipeline_index[context.pipeline_name].append(ontology_id)
|
||||
|
||||
logger.info(f"Registered ontology {ontology_id} with scope {scope}")
|
||||
return ontology_id
|
||||
|
||||
async def get_ontology(
|
||||
self,
|
||||
ontology_id: str,
|
||||
context: Optional[OntologyContext] = None
|
||||
) -> Optional[OntologyGraph]:
|
||||
"""Get ontology by ID."""
|
||||
return self.ontologies.get(ontology_id)
|
||||
|
||||
async def find_ontologies(
|
||||
self,
|
||||
scope: Optional[OntologyScope] = None,
|
||||
domain: Optional[str] = None,
|
||||
context: Optional[OntologyContext] = None
|
||||
) -> List[OntologyGraph]:
|
||||
"""Find ontologies matching criteria."""
|
||||
|
||||
candidate_ids = set()
|
||||
|
||||
# Filter by scope
|
||||
if scope:
|
||||
candidate_ids.update(self.scope_index.get(scope, []))
|
||||
else:
|
||||
# If no scope specified, get all
|
||||
for scope_ids in self.scope_index.values():
|
||||
candidate_ids.update(scope_ids)
|
||||
|
||||
# Filter by domain
|
||||
if domain:
|
||||
domain_ids = set(self.domain_index.get(domain, []))
|
||||
candidate_ids = candidate_ids.intersection(domain_ids)
|
||||
|
||||
# Filter by context
|
||||
if context:
|
||||
if context.user_id:
|
||||
user_ids = set(self.user_index.get(context.user_id, []))
|
||||
if scope == OntologyScope.USER:
|
||||
candidate_ids = candidate_ids.intersection(user_ids)
|
||||
else:
|
||||
candidate_ids.update(user_ids)
|
||||
|
||||
if context.dataset_id:
|
||||
dataset_ids = set(self.dataset_index.get(context.dataset_id, []))
|
||||
if scope == OntologyScope.DATASET:
|
||||
candidate_ids = candidate_ids.intersection(dataset_ids)
|
||||
else:
|
||||
candidate_ids.update(dataset_ids)
|
||||
|
||||
if context.pipeline_name:
|
||||
pipeline_ids = set(self.pipeline_index.get(context.pipeline_name, []))
|
||||
if scope == OntologyScope.PIPELINE:
|
||||
candidate_ids = candidate_ids.intersection(pipeline_ids)
|
||||
else:
|
||||
candidate_ids.update(pipeline_ids)
|
||||
|
||||
# Return matching ontologies
|
||||
result = []
|
||||
for ontology_id in candidate_ids:
|
||||
if ontology_id in self.ontologies:
|
||||
result.append(self.ontologies[ontology_id])
|
||||
|
||||
logger.debug(f"Found {len(result)} ontologies matching criteria")
|
||||
return result
|
||||
|
||||
async def unregister_ontology(
|
||||
self,
|
||||
ontology_id: str,
|
||||
context: Optional[OntologyContext] = None
|
||||
) -> bool:
|
||||
"""Unregister an ontology."""
|
||||
|
||||
if ontology_id not in self.ontologies:
|
||||
return False
|
||||
|
||||
ontology = self.ontologies[ontology_id]
|
||||
|
||||
# Remove from all indices
|
||||
for scope_ids in self.scope_index.values():
|
||||
if ontology_id in scope_ids:
|
||||
scope_ids.remove(ontology_id)
|
||||
|
||||
for domain_ids in self.domain_index.values():
|
||||
if ontology_id in domain_ids:
|
||||
domain_ids.remove(ontology_id)
|
||||
|
||||
for user_ids in self.user_index.values():
|
||||
if ontology_id in user_ids:
|
||||
user_ids.remove(ontology_id)
|
||||
|
||||
for dataset_ids in self.dataset_index.values():
|
||||
if ontology_id in dataset_ids:
|
||||
dataset_ids.remove(ontology_id)
|
||||
|
||||
for pipeline_ids in self.pipeline_index.values():
|
||||
if ontology_id in pipeline_ids:
|
||||
pipeline_ids.remove(ontology_id)
|
||||
|
||||
# Remove from main registry
|
||||
del self.ontologies[ontology_id]
|
||||
|
||||
logger.info(f"Unregistered ontology {ontology_id}")
|
||||
return True
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
"""Get registry statistics."""
|
||||
return {
|
||||
"total_ontologies": len(self.ontologies),
|
||||
"global_ontologies": len(self.scope_index[OntologyScope.GLOBAL]),
|
||||
"domain_ontologies": len(self.scope_index[OntologyScope.DOMAIN]),
|
||||
"pipeline_ontologies": len(self.scope_index[OntologyScope.PIPELINE]),
|
||||
"user_ontologies": len(self.scope_index[OntologyScope.USER]),
|
||||
"dataset_ontologies": len(self.scope_index[OntologyScope.DATASET]),
|
||||
"unique_domains": len(self.domain_index),
|
||||
"unique_users": len(self.user_index),
|
||||
"unique_datasets": len(self.dataset_index),
|
||||
"unique_pipelines": len(self.pipeline_index),
|
||||
}
|
||||
|
||||
|
||||
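
A quick registry sketch (illustrative; the `OntologyContext` keyword fields and the provider import path are assumptions based on the attribute names used above):

```python
import asyncio

from cognee.modules.ontology.interfaces import OntologyContext, OntologyScope
from cognee.modules.ontology.providers import JSONOntologyProvider  # assumed path

async def main():
    ontology = await JSONOntologyProvider().load_ontology(
        {"id": "demo", "nodes": [{"id": "a"}], "edges": []}
    )

    registry = OntologyRegistry()
    context = OntologyContext(domain="medical", pipeline_name="cognify")  # assumed kwargs

    ontology_id = await registry.register_ontology(ontology, OntologyScope.DOMAIN, context)

    # Scope and domain filters intersect; context indices can broaden the candidate set.
    matches = await registry.find_ontologies(scope=OntologyScope.DOMAIN, domain="medical")
    print(len(matches), registry.get_stats()["domain_ontologies"])

    await registry.unregister_ontology(ontology_id)

asyncio.run(main())
```
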
class DatabaseOntologyRegistry(IOntologyRegistry):
    """Database-backed ontology registry (placeholder implementation)."""

    def __init__(self, db_connection=None):
        self.db_connection = db_connection
        # This would use actual database operations in a real implementation
        self._memory_registry = OntologyRegistry()

    async def register_ontology(
        self,
        ontology: OntologyGraph,
        scope: OntologyScope,
        context: Optional[OntologyContext] = None
    ) -> str:
        """Register an ontology in the database."""
        # TODO: Implement database storage
        return await self._memory_registry.register_ontology(ontology, scope, context)

    async def get_ontology(
        self,
        ontology_id: str,
        context: Optional[OntologyContext] = None
    ) -> Optional[OntologyGraph]:
        """Get ontology from the database."""
        # TODO: Implement database retrieval
        return await self._memory_registry.get_ontology(ontology_id, context)

    async def find_ontologies(
        self,
        scope: Optional[OntologyScope] = None,
        domain: Optional[str] = None,
        context: Optional[OntologyContext] = None
    ) -> List[OntologyGraph]:
        """Find ontologies in the database."""
        # TODO: Implement database query
        return await self._memory_registry.find_ontologies(scope, domain, context)

    async def unregister_ontology(
        self,
        ontology_id: str,
        context: Optional[OntologyContext] = None
    ) -> bool:
        """Unregister ontology from the database."""
        # TODO: Implement database deletion
        return await self._memory_registry.unregister_ontology(ontology_id, context)
275
cognee/modules/ontology/resolvers.py
Normal file

@@ -0,0 +1,275 @@
"""DataPoint resolution implementations."""
|
||||
|
||||
import importlib
|
||||
from typing import Any, Dict, List, Optional, Callable, Type
|
||||
from uuid import uuid4
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.ontology.interfaces import (
|
||||
IDataPointResolver,
|
||||
OntologyNode,
|
||||
DataPointMapping,
|
||||
OntologyContext,
|
||||
)
|
||||
|
||||
logger = get_logger("DataPointResolver")
|
||||
|
||||
|
||||
class DefaultDataPointResolver(IDataPointResolver):
|
||||
"""Default implementation for DataPoint resolution."""
|
||||
|
||||
def __init__(self):
|
||||
self.custom_resolvers: Dict[str, Callable] = {}
|
||||
|
||||
async def resolve_to_datapoint(
|
||||
self,
|
||||
ontology_node: OntologyNode,
|
||||
mapping_config: DataPointMapping,
|
||||
context: Optional[OntologyContext] = None
|
||||
) -> Any: # DataPoint
|
||||
"""Resolve ontology node to DataPoint instance."""
|
||||
|
||||
# Use custom resolver if specified
|
||||
if mapping_config.custom_resolver and mapping_config.custom_resolver in self.custom_resolvers:
|
||||
return await self._apply_custom_resolver(
|
||||
ontology_node, mapping_config, context
|
||||
)
|
||||
|
||||
# Use default resolution logic
|
||||
return await self._default_resolution(ontology_node, mapping_config, context)
|
||||
|
||||
async def resolve_from_datapoint(
|
||||
self,
|
||||
datapoint: Any, # DataPoint
|
||||
mapping_config: DataPointMapping,
|
||||
context: Optional[OntologyContext] = None
|
||||
) -> OntologyNode:
|
||||
"""Resolve DataPoint instance to ontology node."""
|
||||
|
||||
# Extract properties from DataPoint
|
||||
datapoint_dict = datapoint.to_dict() if hasattr(datapoint, 'to_dict') else datapoint.__dict__
|
||||
|
||||
# Map DataPoint fields to ontology properties
|
||||
ontology_properties = {}
|
||||
reverse_mappings = {v: k for k, v in mapping_config.field_mappings.items()}
|
||||
|
||||
for datapoint_field, ontology_field in reverse_mappings.items():
|
||||
if datapoint_field in datapoint_dict:
|
||||
ontology_properties[ontology_field] = datapoint_dict[datapoint_field]
|
||||
|
||||
# Create ontology node
|
||||
node = OntologyNode(
|
||||
id=str(datapoint.id) if hasattr(datapoint, 'id') else str(uuid4()),
|
||||
name=ontology_properties.get('name', str(datapoint.id)),
|
||||
type=mapping_config.ontology_node_type,
|
||||
description=ontology_properties.get('description', ''),
|
||||
category=ontology_properties.get('category', 'entity'),
|
||||
properties=ontology_properties
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
async def validate_mapping(
|
||||
self,
|
||||
mapping_config: DataPointMapping
|
||||
) -> bool:
|
||||
"""Validate mapping configuration."""
|
||||
try:
|
||||
# Check if DataPoint class exists
|
||||
module_path, class_name = mapping_config.datapoint_class.rsplit('.', 1)
|
||||
module = importlib.import_module(module_path)
|
||||
datapoint_class = getattr(module, class_name)
|
||||
|
||||
# Validate field mappings
|
||||
if hasattr(datapoint_class, '__annotations__'):
|
||||
valid_fields = set(datapoint_class.__annotations__.keys())
|
||||
mapped_fields = set(mapping_config.field_mappings.values())
|
||||
|
||||
invalid_fields = mapped_fields - valid_fields
|
||||
if invalid_fields:
|
||||
logger.warning(f"Invalid field mappings: {invalid_fields}")
|
||||
return False
|
||||
|
||||
# Validate custom resolver if specified
|
||||
if mapping_config.custom_resolver:
|
||||
if mapping_config.custom_resolver not in self.custom_resolvers:
|
||||
logger.warning(f"Custom resolver not found: {mapping_config.custom_resolver}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Mapping validation failed: {e}")
|
||||
return False
|
||||
|
||||
def register_custom_resolver(
|
||||
self,
|
||||
resolver_name: str,
|
||||
resolver_func: Callable[[OntologyNode, DataPointMapping], Any]
|
||||
) -> None:
|
||||
"""Register a custom resolver function."""
|
||||
self.custom_resolvers[resolver_name] = resolver_func
|
||||
logger.info(f"Registered custom resolver: {resolver_name}")
|
||||
|
||||
async def _apply_custom_resolver(
|
||||
self,
|
||||
ontology_node: OntologyNode,
|
||||
mapping_config: DataPointMapping,
|
||||
context: Optional[OntologyContext] = None
|
||||
) -> Any:
|
||||
"""Apply custom resolver function."""
|
||||
resolver_func = self.custom_resolvers[mapping_config.custom_resolver]
|
||||
|
||||
if callable(resolver_func):
|
||||
try:
|
||||
# Call resolver function
|
||||
if context:
|
||||
result = await resolver_func(ontology_node, mapping_config, context)
|
||||
else:
|
||||
result = await resolver_func(ontology_node, mapping_config)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Custom resolver failed: {e}")
|
||||
# Fallback to default resolution
|
||||
return await self._default_resolution(ontology_node, mapping_config, context)
|
||||
else:
|
||||
logger.error(f"Invalid custom resolver: {mapping_config.custom_resolver}")
|
||||
return await self._default_resolution(ontology_node, mapping_config, context)
|
||||
|
||||
async def _default_resolution(
|
||||
self,
|
||||
ontology_node: OntologyNode,
|
||||
mapping_config: DataPointMapping,
|
||||
context: Optional[OntologyContext] = None
|
||||
) -> Any:
|
||||
"""Default resolution logic."""
|
||||
try:
|
||||
# Import the DataPoint class
|
||||
module_path, class_name = mapping_config.datapoint_class.rsplit('.', 1)
|
||||
module = importlib.import_module(module_path)
|
||||
datapoint_class = getattr(module, class_name)
|
||||
|
||||
# Map ontology properties to DataPoint fields
|
||||
datapoint_data = {}
|
||||
|
||||
# Apply field mappings
|
||||
for ontology_field, datapoint_field in mapping_config.field_mappings.items():
|
||||
if ontology_field in ontology_node.properties:
|
||||
datapoint_data[datapoint_field] = ontology_node.properties[ontology_field]
|
||||
elif hasattr(ontology_node, ontology_field):
|
||||
datapoint_data[datapoint_field] = getattr(ontology_node, ontology_field)
|
||||
|
||||
# Set default mappings if not provided
|
||||
if 'id' not in datapoint_data:
|
||||
datapoint_data['id'] = ontology_node.id
|
||||
if 'type' not in datapoint_data:
|
||||
datapoint_data['type'] = ontology_node.type
|
||||
|
||||
# Add ontology metadata
|
||||
if hasattr(datapoint_class, 'metadata'):
|
||||
datapoint_data['metadata'] = {
|
||||
'type': ontology_node.type,
|
||||
'index_fields': list(mapping_config.field_mappings.values()),
|
||||
'ontology_source': True,
|
||||
'ontology_node_id': ontology_node.id,
|
||||
}
|
||||
|
||||
# Set ontology_valid flag
|
||||
datapoint_data['ontology_valid'] = True
|
||||
|
||||
# Create DataPoint instance
|
||||
datapoint = datapoint_class(**datapoint_data)
|
||||
|
||||
# Apply validation rules if specified
|
||||
if mapping_config.validation_rules:
|
||||
await self._apply_validation_rules(datapoint, mapping_config.validation_rules)
|
||||
|
||||
return datapoint
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Default resolution failed for node {ontology_node.id}: {e}")
|
||||
return None
|
||||
|
||||
async def _apply_validation_rules(
|
||||
self,
|
||||
datapoint: Any,
|
||||
validation_rules: List[str]
|
||||
) -> None:
|
||||
"""Apply validation rules to DataPoint."""
|
||||
for rule in validation_rules:
|
||||
try:
|
||||
# This is a simple implementation - in practice, you'd want
|
||||
# a more sophisticated rule engine
|
||||
if rule.startswith("required:"):
|
||||
field_name = rule.split(":", 1)[1]
|
||||
if not hasattr(datapoint, field_name) or getattr(datapoint, field_name) is None:
|
||||
raise ValueError(f"Required field {field_name} is missing")
|
||||
|
||||
elif rule.startswith("type:"):
|
||||
field_name, expected_type = rule.split(":", 2)[1:]
|
||||
if hasattr(datapoint, field_name):
|
||||
field_value = getattr(datapoint, field_name)
|
||||
if field_value is not None and not isinstance(field_value, eval(expected_type)):
|
||||
raise ValueError(f"Field {field_name} has wrong type")
|
||||
|
||||
# Add more validation rules as needed
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Validation rule '{rule}' failed: {e}")
|
||||
|
||||
|
||||
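
A sketch of registering a custom resolver (illustrative; `my_project.models.PersonDataPoint` and the mapping values are hypothetical, and `DataPointMapping` construction assumes its fields match the attribute names used above):

```python
from cognee.modules.ontology.interfaces import DataPointMapping, OntologyNode

async def person_resolver(node: OntologyNode, mapping: DataPointMapping):
    # Hypothetical resolver: enrich the node, then let a real
    # implementation return a DataPoint instance.
    node.properties["display_name"] = node.name.upper()
    return node

resolver = DefaultDataPointResolver()
resolver.register_custom_resolver("person_resolver", person_resolver)

mapping = DataPointMapping(
    datapoint_class="my_project.models.PersonDataPoint",  # hypothetical class path
    ontology_node_type="Person",
    field_mappings={"name": "name", "description": "summary"},
    custom_resolver="person_resolver",
)
```
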
class DomainSpecificResolver(DefaultDataPointResolver):
    """Domain-specific resolver with specialized logic."""

    def __init__(self, domain: str):
        super().__init__()
        self.domain = domain

    async def _default_resolution(
        self,
        ontology_node: OntologyNode,
        mapping_config: DataPointMapping,
        context: Optional[OntologyContext] = None
    ) -> Any:
        """Domain-specific resolution logic."""

        # Apply domain-specific preprocessing
        if self.domain == "medical":
            ontology_node = await self._preprocess_medical_node(ontology_node)
        elif self.domain == "legal":
            ontology_node = await self._preprocess_legal_node(ontology_node)
        elif self.domain == "code":
            ontology_node = await self._preprocess_code_node(ontology_node)

        # Use parent's default resolution
        return await super()._default_resolution(ontology_node, mapping_config, context)

    async def _preprocess_medical_node(self, node: OntologyNode) -> OntologyNode:
        """Preprocess medical domain nodes."""
        # Add medical-specific property transformations
        if node.type == "Disease":
            node.properties["medical_category"] = "disease"
        elif node.type == "Symptom":
            node.properties["medical_category"] = "symptom"

        return node

    async def _preprocess_legal_node(self, node: OntologyNode) -> OntologyNode:
        """Preprocess legal domain nodes."""
        # Add legal-specific property transformations
        if node.type == "Law":
            node.properties["legal_category"] = "legislation"
        elif node.type == "Case":
            node.properties["legal_category"] = "precedent"

        return node

    async def _preprocess_code_node(self, node: OntologyNode) -> OntologyNode:
        """Preprocess code domain nodes."""
        # Add code-specific property transformations
        if node.type == "Function":
            node.properties["code_category"] = "function"
        elif node.type == "Class":
            node.properties["code_category"] = "class"

        return node
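
Instantiating the domain-specific resolver is a one-liner; the preprocessing hooks fire automatically inside default resolution (sketch, to be run inside an async context; `disease_node` and `disease_mapping` are hypothetical):

```python
medical_resolver = DomainSpecificResolver(domain="medical")

# resolve_to_datapoint routes through _default_resolution, so a node with
# type "Disease" gains properties["medical_category"] = "disease" before
# field mapping runs.
datapoint = await medical_resolver.resolve_to_datapoint(disease_node, disease_mapping)
```
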
410
cognee/tasks/graph/extract_graph_from_data_ontology_aware.py
Normal file

@@ -0,0 +1,410 @@
"""
|
||||
Enhanced graph extraction task with ontology awareness.
|
||||
|
||||
This demonstrates how to update existing tasks to use the new ontology system.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Type, List, Optional, Any, Dict
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.tasks.storage.add_data_points import add_data_points
|
||||
from cognee.modules.chunking.models.DocumentChunk import DocumentChunk
|
||||
from cognee.modules.graph.utils import (
|
||||
expand_with_nodes_and_edges,
|
||||
retrieve_existing_edges,
|
||||
)
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
# New ontology imports
|
||||
from cognee.modules.ontology.interfaces import (
|
||||
IOntologyManager,
|
||||
OntologyContext,
|
||||
DataPointMapping,
|
||||
GraphBindingConfig,
|
||||
)
|
||||
|
||||
logger = get_logger("extract_graph_from_data_ontology_aware")
|
||||
|
||||
|
||||
async def extract_graph_from_data_ontology_aware(
|
||||
data_chunks: List[DocumentChunk],
|
||||
graph_model: Type[Any] = KnowledgeGraph,
|
||||
ontology_manager: Optional[IOntologyManager] = None,
|
||||
ontology_context: Optional[OntologyContext] = None,
|
||||
datapoint_mappings: Optional[List[DataPointMapping]] = None,
|
||||
graph_binding_config: Optional[GraphBindingConfig] = None,
|
||||
entity_extraction_enabled: bool = False,
|
||||
target_entity_types: Optional[List[str]] = None,
|
||||
enhanced_content: Optional[Dict[str, Any]] = None,
|
||||
**kwargs
|
||||
) -> List[DocumentChunk]:
|
||||
"""
|
||||
Enhanced graph extraction with ontology awareness.
|
||||
|
||||
Args:
|
||||
data_chunks: Document chunks to process
|
||||
graph_model: Graph model type (KnowledgeGraph or custom)
|
||||
ontology_manager: Ontology manager instance (injected by pipeline)
|
||||
ontology_context: Ontology context (injected by pipeline)
|
||||
datapoint_mappings: DataPoint mappings (injected by pipeline)
|
||||
graph_binding_config: Graph binding configuration (injected by pipeline)
|
||||
entity_extraction_enabled: Whether to use ontology for entity extraction
|
||||
target_entity_types: Specific entity types to extract
|
||||
enhanced_content: Pre-enhanced content with ontological information
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
Updated document chunks with extracted graph data
|
||||
"""
|
||||
|
||||
logger.info(f"Processing {len(data_chunks)} chunks with ontology awareness")
|
||||
|
||||
# Check if ontology integration is available
|
||||
if ontology_manager and ontology_context:
|
||||
logger.info(f"Ontology integration enabled for domain: {ontology_context.domain}")
|
||||
return await _extract_with_ontology(
|
||||
data_chunks=data_chunks,
|
||||
graph_model=graph_model,
|
||||
ontology_manager=ontology_manager,
|
||||
ontology_context=ontology_context,
|
||||
datapoint_mappings=datapoint_mappings,
|
||||
graph_binding_config=graph_binding_config,
|
||||
entity_extraction_enabled=entity_extraction_enabled,
|
||||
target_entity_types=target_entity_types,
|
||||
enhanced_content=enhanced_content,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
logger.info("No ontology integration, using standard extraction")
|
||||
return await _extract_standard(data_chunks, graph_model, **kwargs)
|
||||
|
||||
|
||||
async def _extract_with_ontology(
    data_chunks: List[DocumentChunk],
    graph_model: Type[Any],
    ontology_manager: IOntologyManager,
    ontology_context: OntologyContext,
    datapoint_mappings: Optional[List[DataPointMapping]] = None,
    graph_binding_config: Optional[GraphBindingConfig] = None,
    entity_extraction_enabled: bool = False,
    target_entity_types: Optional[List[str]] = None,
    enhanced_content: Optional[Dict[str, Any]] = None,
    **kwargs
) -> List[DocumentChunk]:
    """Extract graph data with ontology enhancement."""

    # Step 1: Get applicable ontologies
    ontologies = await ontology_manager.get_applicable_ontologies(ontology_context)
    logger.info(f"Found {len(ontologies)} applicable ontologies")

    if not ontologies:
        logger.warning("No applicable ontologies found, falling back to standard extraction")
        return await _extract_standard(data_chunks, graph_model, **kwargs)

    # Step 2: Enhance content with ontological information
    chunk_graphs = []
    enhanced_datapoints = []

    for chunk in data_chunks:
        # Enhance chunk content if not already done
        if enhanced_content:
            chunk_enhanced = enhanced_content
        else:
            chunk_enhanced = await ontology_manager.enhance_with_ontology(
                chunk.text, ontology_context
            )

        # Extract graph using enhanced information
        chunk_graph = await _extract_chunk_graph_with_ontology(
            chunk=chunk,
            enhanced_content=chunk_enhanced,
            graph_model=graph_model,
            ontology_manager=ontology_manager,
            ontology_context=ontology_context,
            target_entity_types=target_entity_types,
        )

        chunk_graphs.append(chunk_graph)

        # Convert ontological entities to DataPoints if mappings available
        if datapoint_mappings and chunk_enhanced.get('extracted_entities'):
            ontology_nodes = await _convert_entities_to_ontology_nodes(
                chunk_enhanced['extracted_entities'], ontologies[0]
            )

            if ontology_nodes:
                datapoints = await ontology_manager.resolve_to_datapoints(
                    ontology_nodes, ontology_context
                )
                enhanced_datapoints.extend(datapoints)

    # Step 3: Integrate with graph database using custom binding
    if graph_model is KnowledgeGraph:
        # Use standard integration with ontology-enhanced nodes/edges
        await _integrate_ontology_enhanced_graphs(
            data_chunks=data_chunks,
            chunk_graphs=chunk_graphs,
            enhanced_datapoints=enhanced_datapoints,
            ontology_manager=ontology_manager,
            ontology_context=ontology_context,
            graph_binding_config=graph_binding_config,
        )
    else:
        # Custom graph model - just attach graphs to chunks
        for chunk_index, chunk_graph in enumerate(chunk_graphs):
            data_chunks[chunk_index].contains = chunk_graph

    logger.info(f"Completed ontology-aware graph extraction for {len(data_chunks)} chunks")
    return data_chunks
async def _extract_chunk_graph_with_ontology(
    chunk: DocumentChunk,
    enhanced_content: Dict[str, Any],
    graph_model: Type[Any],
    ontology_manager: IOntologyManager,
    ontology_context: OntologyContext,
    target_entity_types: Optional[List[str]] = None,
) -> Any:
    """Extract graph for a single chunk using ontological enhancement."""

    # Build entity-aware prompt for LLM
    extracted_entities = enhanced_content.get('extracted_entities', [])
    semantic_relationships = enhanced_content.get('semantic_relationships', [])

    # Filter entities by target types if specified
    if target_entity_types:
        extracted_entities = [
            entity for entity in extracted_entities
            if entity.get('type') in target_entity_types
        ]

    # Create enhanced prompt with ontological context
    ontology_context_prompt = _build_ontology_context_prompt(
        extracted_entities, semantic_relationships
    )

    # Use LLM to extract graph with ontological guidance
    full_prompt = f"""
{ontology_context_prompt}

Text to analyze:
{chunk.text}

Extract entities and relationships, giving preference to the ontological entities
and relationships mentioned above when they appear in the text.
"""

    # Extract using LLM with ontological context
    chunk_graph = await LLMGateway.acreate_structured_output(
        full_prompt,
        "You are a knowledge graph extractor. Use the ontological context to guide your extraction.",
        graph_model
    )

    # Enhance extracted graph with ontological metadata
    if hasattr(chunk_graph, 'nodes'):
        for node in chunk_graph.nodes:
            # Add ontological metadata to nodes
            matching_entity = _find_matching_ontological_entity(node, extracted_entities)
            if matching_entity:
                node.ontology_source = matching_entity.get('ontology_id')
                node.ontology_confidence = matching_entity.get('confidence', 0.0)
                if hasattr(node, 'type'):
                    node.type = matching_entity.get('type', node.type)

    return chunk_graph
async def _integrate_ontology_enhanced_graphs(
    data_chunks: List[DocumentChunk],
    chunk_graphs: List[Any],
    enhanced_datapoints: List[Any],
    ontology_manager: IOntologyManager,
    ontology_context: OntologyContext,
    graph_binding_config: Optional[GraphBindingConfig] = None,
):
    """Integrate ontology-enhanced graphs into the graph database."""

    from cognee.modules.ontology.interfaces import OntologyGraph, OntologyNode, OntologyEdge

    graph_engine = await get_graph_engine()

    # Get existing edges to avoid duplicates
    existing_edges_map = await retrieve_existing_edges(data_chunks, chunk_graphs)

    # Apply ontology-aware graph binding if available
    if graph_binding_config:
        # Transform graphs using custom binding
        all_nodes = []
        all_edges = []

        for chunk_graph in chunk_graphs:
            if hasattr(chunk_graph, 'nodes') and hasattr(chunk_graph, 'edges'):
                # Convert chunk graph to ontology format for binding
                ontology_nodes = []
                ontology_edges = []

                for node in chunk_graph.nodes:
                    ont_node = OntologyNode(
                        id=node.id,
                        name=node.name,
                        type=node.type,
                        description=getattr(node, 'description', ''),
                        properties=getattr(node, '__dict__', {})
                    )
                    ontology_nodes.append(ont_node)

                for edge in chunk_graph.edges:
                    ont_edge = OntologyEdge(
                        id=f"{edge.source_node_id}_{edge.target_node_id}_{edge.relationship_name}",
                        source_id=edge.source_node_id,
                        target_id=edge.target_node_id,
                        relationship_type=edge.relationship_name,
                        properties=getattr(edge, '__dict__', {})
                    )
                    ontology_edges.append(ont_edge)

                # Create temporary ontology for binding
                temp_ontology = OntologyGraph(
                    id="temp_chunk_ontology",
                    name="Chunk Ontology",
                    description="Temporary ontology for chunk graph",
                    format="llm_generated",
                    scope="dataset",
                    nodes=ontology_nodes,
                    edges=ontology_edges
                )

                # Apply custom binding
                bound_nodes, bound_edges = await ontology_manager.bind_to_graph(
                    temp_ontology, ontology_context
                )

                all_nodes.extend(bound_nodes)
                all_edges.extend(bound_edges)

        # Add bound nodes and edges to graph
        if all_nodes:
            await add_data_points(all_nodes)
        if all_edges:
            await graph_engine.add_edges(all_edges)

    else:
        # Use standard integration
        graph_nodes, graph_edges = expand_with_nodes_and_edges(
            data_chunks, chunk_graphs, None, existing_edges_map
        )

        if graph_nodes:
            await add_data_points(graph_nodes)
        if graph_edges:
            await graph_engine.add_edges(graph_edges)

    # Add enhanced DataPoints from ontology resolution
    if enhanced_datapoints:
        await add_data_points(enhanced_datapoints)
        logger.info(f"Added {len(enhanced_datapoints)} ontology-resolved DataPoints")
async def _extract_standard(
    data_chunks: List[DocumentChunk],
    graph_model: Type[Any],
    **kwargs
) -> List[DocumentChunk]:
    """Standard graph extraction without ontology (fallback)."""

    # Import the original function to maintain compatibility
    from cognee.tasks.graph.extract_graph_from_data import integrate_chunk_graphs

    # Generate standard graphs using LLM; the prompt is read once, outside the loop
    system_prompt = LLMGateway.read_query_prompt("generate_graph_prompt_oneshot.txt")

    chunk_graphs = []
    for chunk in data_chunks:
        chunk_graph = await LLMGateway.acreate_structured_output(
            chunk.text, system_prompt, graph_model
        )
        chunk_graphs.append(chunk_graph)

    # Use standard integration
    return await integrate_chunk_graphs(
        data_chunks=data_chunks,
        chunk_graphs=chunk_graphs,
        graph_model=graph_model,
        ontology_adapter=None,  # No ontology adapter
    )
def _build_ontology_context_prompt(
    extracted_entities: List[Dict[str, Any]],
    semantic_relationships: List[Dict[str, Any]]
) -> str:
    """Build ontological context prompt for LLM."""

    if not extracted_entities and not semantic_relationships:
        return ""

    prompt = "ONTOLOGICAL CONTEXT:\n"

    if extracted_entities:
        prompt += "Known entities in this domain:\n"
        for entity in extracted_entities[:10]:  # Limit to top 10
            prompt += f"- {entity['name']} (type: {entity['type']})\n"
        prompt += "\n"

    if semantic_relationships:
        prompt += "Known relationships:\n"
        for rel in semantic_relationships[:10]:  # Limit to top 10
            prompt += f"- {rel['source']} {rel['relationship']} {rel['target']}\n"
        prompt += "\n"

    prompt += "When extracting entities and relationships, prefer these ontological concepts when they appear in the text.\n\n"

    return prompt
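
For illustration, with one known entity and one relationship the helper above produces a prompt like this:

```python
entities = [{"name": "Aspirin", "type": "Drug"}]
relationships = [{"source": "Aspirin", "relationship": "treats", "target": "Headache"}]

print(_build_ontology_context_prompt(entities, relationships))
# ONTOLOGICAL CONTEXT:
# Known entities in this domain:
# - Aspirin (type: Drug)
#
# Known relationships:
# - Aspirin treats Headache
#
# When extracting entities and relationships, prefer these ontological
# concepts when they appear in the text.
```
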
def _find_matching_ontological_entity(
    extracted_node: Any,
    ontological_entities: List[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Find matching ontological entity for an extracted node."""

    node_name = getattr(extracted_node, 'name', '').lower()
    if not node_name:
        return None

    for entity in ontological_entities:
        entity_name = entity.get('name', '').lower()
        # Skip empty names so substring matching cannot match everything
        if not entity_name:
            continue
        if node_name == entity_name or node_name in entity_name or entity_name in node_name:
            return entity

    return None
async def _convert_entities_to_ontology_nodes(
    extracted_entities: List[Dict[str, Any]],
    ontology: Any
) -> List[Any]:
    """Convert extracted entities to ontology nodes for DataPoint resolution."""

    from cognee.modules.ontology.interfaces import OntologyNode

    ontology_nodes = []

    for entity in extracted_entities:
        node = OntologyNode(
            id=entity.get('node_id', entity['name']),
            name=entity['name'],
            type=entity['type'],
            description=entity.get('description', ''),
            category=entity.get('category', 'entity'),
            properties={
                'confidence': entity.get('confidence', 0.0),
                'ontology_id': entity.get('ontology_id'),
                'source': 'llm_extraction'
            }
        )
        ontology_nodes.append(node)

    return ontology_nodes
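
A sketch of wiring the task into a pipeline (illustrative only; the `Task` import path and the manager injection reflect the intended design described in this ticket, not a confirmed API):

```python
# Import path for Task is assumed; ontology_manager is assumed to be
# constructed elsewhere and injected by the pipeline, per the docstring above.
from cognee.modules.pipelines.tasks.Task import Task
from cognee.modules.ontology.interfaces import OntologyContext

ontology_context = OntologyContext(domain="medical", pipeline_name="cognify")

tasks = [
    # ... chunking tasks ...
    Task(
        extract_graph_from_data_ontology_aware,
        ontology_manager=ontology_manager,
        ontology_context=ontology_context,
    ),
    # ... storage tasks ...
]
```
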