feat: add regex entity extractor (#605)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->
- Created a new RegexEntityExtractor that uses regex patterns to
identify entities like emails, URLs, and dates in text
- Implemented a JSON-based configuration system to add or modify entity
types without changing code
- Built a separate RegexEntityConfig class to handle loading and
processing of entity configurations
- Added test suite covering all entity types and edge cases
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Introduced a new regex-based extraction capability that uses
    configurable patterns and description templates to identify common
    entities such as emails, phone numbers, URLs, dates, and more.
- **Tests**
  - Added comprehensive tests to validate the extraction functionality
    across standard scenarios and edge cases for reliable text analysis.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
lxobr 2025-03-06 12:13:59 +01:00 committed by GitHub
parent 9d783675e0
commit ea5b11a3b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 506 additions and 0 deletions

View file

@ -0,0 +1,62 @@
[
{
"entity_name": "EMAIL",
"entity_description": "Entity type for email entities",
"regex": "[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}",
"description_template": "Email address: {}"
},
{
"entity_name": "PHONE",
"entity_description": "Entity type for phone entities",
"regex": "\\+?\\d{1,4}[\\s-]?\\(?\\d{2,4}\\)?[\\s-]?\\d{3,4}[\\s-]?\\d{3,4}",
"description_template": "Phone number: {}"
},
{
"entity_name": "URL",
"entity_description": "Entity type for url entities",
"regex": "https?:\\/\\/(www\\.)?[a-zA-Z0-9-]+(\\.[a-zA-Z]{2,})+(\\/\\S*)?",
"description_template": "URL: {}"
},
{
"entity_name": "DATE",
"entity_description": "Entity type for date entities",
"regex": "(\\d{4}[-/]\\d{2}[-/]\\d{2})|(\\d{2}[-/]\\d{2}[-/]\\d{4})",
"description_template": "Date: {}"
},
{
"entity_name": "TIME",
"entity_description": "Entity type for time entities",
"regex": "(1[0-2]|0?[1-9]):[0-5][0-9](\\s?[APap][Mm])?|([01]?[0-9]|2[0-3]):[0-5][0-9]",
"description_template": "Time: {}"
},
{
"entity_name": "MONEY",
"entity_description": "Entity type for money entities",
"regex": "\\$?\\d{1,3}(,\\d{3})*(\\.[0-9]{2})?|\\€?\\d{1,3}(\\.\\d{3})*(,[0-9]{2})?",
"description_template": "Monetary amount: {}"
},
{
"entity_name": "PERSON",
"entity_description": "Entity type for person entities",
"regex": "\\b(?:(?:Dr|Prof|Mr|Mrs|Ms)\\.?\\s+)?[A-Z][a-z]+(?:\\s+(?:[A-Z][a-z]+|[A-Z]\\.?|(?:van|de|la|del|von|der|le)))+\\b",
"description_template": "Person name: {}"
},
{
"entity_name": "HASHTAG",
"entity_description": "Entity type for hashtag entities",
"regex": "\\#[A-Za-z0-9_]+",
"description_template": "Hashtag: {}"
},
{
"entity_name": "MENTION",
"entity_description": "Entity type for mention entities",
"regex": "\\@[A-Za-z0-9_]+",
"description_template": "Mention: {}"
},
{
"entity_name": "IP_ADDRESS",
"entity_description": "Entity type for ip_address entities",
"regex": "(?<!\\d\\.)(?:(?:25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]\\d?|0)\\.){3}(?:25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]\\d?|0)(?!\\.\\d)",
"description_template": "IP address: {}"
}
]

View file

@ -0,0 +1,91 @@
import json
import logging
import os
import re
from typing import Dict, List, Pattern, Any
from cognee.modules.engine.models.EntityType import EntityType
from cognee.root_dir import get_absolute_path
# Module-level logger, registered under the fixed name "regex_entity_config"
# (a literal string, not __name__).
logger = logging.getLogger("regex_entity_config")
class RegexEntityConfig:
    """Load regex entity definitions from a JSON file and expose them by name.

    The JSON document must be a list of objects, each providing the keys
    ``entity_name``, ``entity_description``, ``regex`` and
    ``description_template``.  Each entry is validated, its pattern is
    compiled once at load time, and the result is stored in
    ``entity_configs`` keyed by entity name.
    """

    def __init__(self, config_path: str):
        """Read and process the configuration stored at ``config_path``.

        Args:
            config_path: Path of the JSON configuration file.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            ValueError: If the file is not valid JSON, is not a list of
                objects, lacks a required field, or contains a regex that
                does not compile.
        """
        self.config_path = config_path
        self.entity_configs: Dict[str, Dict[str, Any]] = {}
        self._load_config()

    def _validate_config_fields(self, config: Dict[str, Any]) -> None:
        """Check that one entity definition is an object with every required key.

        Raises:
            ValueError: If ``config`` is not a dict or a required key is absent.
        """
        # A bare string would make the ``in`` checks below substring tests and
        # a number would raise TypeError, so reject non-objects explicitly.
        if not isinstance(config, dict):
            raise ValueError(
                f"Entity configuration must be a JSON object, got {type(config).__name__}"
            )

        required_fields = ["entity_name", "entity_description", "regex", "description_template"]
        missing_fields = [field for field in required_fields if field not in config]
        if missing_fields:
            raise ValueError(
                f"Missing required fields in entity configuration: {', '.join(missing_fields)}"
            )

    def _compile_regex(self, pattern: str, entity_name: str) -> Pattern:
        """Compile ``pattern`` and return the compiled object.

        Raises:
            ValueError: If ``pattern`` is not a valid regular expression; the
                original ``re.error`` is attached as the cause.
        """
        try:
            return re.compile(pattern)
        except re.error as e:
            logger.error("Invalid regex pattern for entity '%s': %s", entity_name, e)
            raise ValueError(f"Invalid regex pattern for entity '{entity_name}': {e}") from e

    def _load_config(self) -> None:
        """Parse the JSON file and populate ``entity_configs``.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the content is not valid JSON, the top-level value
                is not a list, or any entry fails validation or compilation.
        """
        try:
            # Decode explicitly as UTF-8: the patterns may contain non-ASCII
            # characters (the bundled MONEY pattern uses the euro sign) and
            # the platform default encoding is not guaranteed to be UTF-8.
            with open(self.config_path, "r", encoding="utf-8") as f:
                config_list = json.load(f)
        except FileNotFoundError:
            logger.error("Config file not found: %s", self.config_path)
            raise
        except json.JSONDecodeError as e:
            logger.error("Invalid JSON in config file %s: %s", self.config_path, e)
            raise ValueError(f"Invalid JSON in config file: {e}") from e

        if not isinstance(config_list, list):
            raise ValueError(
                "Config file must contain a JSON list of entity configurations, "
                f"got {type(config_list).__name__}"
            )

        for config in config_list:
            self._validate_config_fields(config)
            entity_name = config["entity_name"]

            # The later definition still wins, but no longer silently.
            if entity_name in self.entity_configs:
                logger.warning(
                    "Duplicate entity configuration for '%s'; the later entry overrides the earlier one",
                    entity_name,
                )

            entity_type = EntityType(name=entity_name, description=config["entity_description"])
            compiled_pattern = self._compile_regex(config["regex"], entity_name)
            self.entity_configs[entity_name] = {
                "entity_type": entity_type,
                "regex": config["regex"],
                "compiled_pattern": compiled_pattern,
                "description_template": config["description_template"],
            }

        logger.info(
            "Loaded %d entity configurations from %s",
            len(self.entity_configs),
            self.config_path,
        )

    def get_entity_names(self) -> List[str]:
        """Return the names of all configured entity types."""
        return list(self.entity_configs.keys())

    def get_entity_config(self, entity_name: str) -> Dict[str, Any]:
        """Return the stored configuration dict for ``entity_name``.

        Raises:
            KeyError: If ``entity_name`` is not configured.
        """
        if entity_name not in self.entity_configs:
            raise KeyError(f"Unknown entity type: {entity_name}")
        return self.entity_configs[entity_name]

    def get_entity_type(self, entity_name: str) -> EntityType:
        """Return the ``EntityType`` object built for ``entity_name``.

        Raises:
            KeyError: If ``entity_name`` is not configured.
        """
        return self.get_entity_config(entity_name)["entity_type"]

    def get_compiled_pattern(self, entity_name: str) -> Pattern:
        """Return the compiled regex pattern for ``entity_name``.

        Raises:
            KeyError: If ``entity_name`` is not configured.
        """
        return self.get_entity_config(entity_name)["compiled_pattern"]

    def get_description_template(self, entity_name: str) -> str:
        """Return the description template string for ``entity_name``.

        Raises:
            KeyError: If ``entity_name`` is not configured.
        """
        return self.get_entity_config(entity_name)["description_template"]

View file

@ -0,0 +1,72 @@
import logging
from typing import List, Optional
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
from cognee.modules.engine.models import Entity
from cognee.root_dir import get_absolute_path
from cognee.tasks.entity_completion.entity_extractors.regex_entity_config import RegexEntityConfig
# Module-level logger, registered under the fixed name "regex_entity_extractor"
# (a literal string, not __name__).
logger = logging.getLogger("regex_entity_extractor")
class RegexEntityExtractor(BaseEntityExtractor):
    """Entity extractor that uses regular expressions to identify entities in text.

    Patterns, entity types and description templates come from a
    ``RegexEntityConfig`` loaded at construction time.  Every configured
    pattern is run over the input and each match becomes one ``Entity``;
    matches are not de-duplicated.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Load the regex configuration.

        Args:
            config_path: Path to a JSON configuration file.  When ``None``,
                the ``regex_entity_config.json`` file resolved through
                ``get_absolute_path`` is used.
        """
        if config_path is None:
            config_path = get_absolute_path(
                "tasks/entity_completion/entity_extractors/regex_entity_config.json"
            )

        self.config = RegexEntityConfig(config_path)
        logger.info(
            "Initialized RegexEntityExtractor with %d entity types",
            len(self.config.get_entity_names()),
        )

    def _create_entity(self, match_text: str, entity_type_obj, description_template: str) -> Entity:
        """Build an ``Entity`` named after the matched text.

        The description is ``description_template`` formatted with
        ``match_text`` as its single positional argument.
        """
        return Entity(
            name=match_text,
            is_a=entity_type_obj,
            description=description_template.format(match_text),
        )

    def _extract_entities_by_type(self, entity_type: str, text: str) -> List[Entity]:
        """Return one ``Entity`` per match of the ``entity_type`` pattern in ``text``.

        An unknown ``entity_type`` is logged as a warning and yields an empty
        list.
        """
        # Only the config look-ups may legitimately raise KeyError for an
        # unknown type.  Keeping entity creation outside the ``try`` stops a
        # KeyError raised by ``str.format`` (a template with a named
        # placeholder) from being misreported as "Unknown entity type".
        try:
            pattern = self.config.get_compiled_pattern(entity_type)
            description_template = self.config.get_description_template(entity_type)
            entity_type_obj = self.config.get_entity_type(entity_type)
        except KeyError:
            logger.warning("Unknown entity type: %s", entity_type)
            return []

        return [
            self._create_entity(match.group(0), entity_type_obj, description_template)
            for match in pattern.finditer(text)
        ]

    def _text_to_entities(self, text: str) -> List[Entity]:
        """Run every configured pattern over ``text`` and collect the entities.

        Results are grouped by entity type in configuration order.
        """
        all_entities: List[Entity] = []
        for entity_type in self.config.get_entity_names():
            all_entities.extend(self._extract_entities_by_type(entity_type, text))

        logger.info("Extracted %d entities", len(all_entities))
        return all_entities

    async def extract_entities(self, text: str) -> List[Entity]:
        """Extract all configured entity types from ``text``.

        Returns an empty list, after logging, when ``text`` is empty or not a
        string, or when extraction raises.
        """
        # Check the type first so the truthiness test only ever runs on a
        # ``str``; objects with unusual ``__bool__`` behaviour cannot raise
        # here, outside the ``try`` below.
        if not isinstance(text, str) or not text:
            logger.warning("Invalid input text for entity extraction")
            return []

        try:
            # Only the first 100 characters of the input are logged.
            logger.info("Extracting entities from text: %s...", text[:100])
            return self._text_to_entities(text)
        except Exception as e:
            # Deliberate best-effort boundary: never propagate to the caller,
            # but keep the traceback in the log for diagnosis.
            logger.exception("Entity extraction failed: %s", e)
            return []

View file

@ -0,0 +1,281 @@
import pytest
from cognee.tasks.entity_completion.entity_extractors.regex_entity_extractor import (
RegexEntityExtractor,
)
@pytest.fixture
def regex_extractor():
    """Provide a RegexEntityExtractor built with its default configuration."""
    extractor = RegexEntityExtractor()
    return extractor
@pytest.mark.asyncio
async def test_extract_emails(regex_extractor):
    """Both email addresses in the sample sentence are returned as EMAIL entities."""
    sample = "Contact us at support@example.com or sales@company.co.uk for more information."
    found = await regex_extractor.extract_entities(sample)

    # Names of the entities typed as EMAIL.
    email_names = [entity.name for entity in found if entity.is_a.name == "EMAIL"]

    assert len(email_names) == 2
    assert "support@example.com" in email_names
    assert "sales@company.co.uk" in email_names
@pytest.mark.asyncio
async def test_extract_phone_numbers(regex_extractor):
    """Both phone numbers in the sample sentence are returned as PHONE entities."""
    sample = "Call us at +1-555-123-4567 or 020 7946 0958 for support."
    found = await regex_extractor.extract_entities(sample)

    # Names of the entities typed as PHONE.
    phone_names = [entity.name for entity in found if entity.is_a.name == "PHONE"]

    assert len(phone_names) == 2
    assert "+1-555-123-4567" in phone_names
    assert "020 7946 0958" in phone_names
@pytest.mark.asyncio
async def test_extract_urls(regex_extractor):
    """Both URLs in the sample sentence are returned as URL entities."""
    sample = "Visit our website at https://www.example.com or http://docs.example.org/help for more information."
    found = await regex_extractor.extract_entities(sample)

    # Names of the entities typed as URL.
    url_names = [entity.name for entity in found if entity.is_a.name == "URL"]

    assert len(url_names) == 2
    assert "https://www.example.com" in url_names
    assert "http://docs.example.org/help" in url_names
@pytest.mark.asyncio
async def test_extract_dates(regex_extractor):
    """A year-first and a year-last date are both returned as DATE entities."""
    sample = "The event is scheduled for 2023-05-15 and ends on 06/30/2023."
    found = await regex_extractor.extract_entities(sample)

    # Names of the entities typed as DATE.
    date_names = [entity.name for entity in found if entity.is_a.name == "DATE"]

    assert len(date_names) == 2
    assert "2023-05-15" in date_names
    assert "06/30/2023" in date_names
@pytest.mark.asyncio
async def test_extract_times(regex_extractor):
    """A 12-hour time with AM suffix and a 24-hour time are returned as TIME entities."""
    sample = "The meeting starts at 09:30 AM and ends at 14:45."
    found = await regex_extractor.extract_entities(sample)

    # Names of the entities typed as TIME.
    time_names = [entity.name for entity in found if entity.is_a.name == "TIME"]

    assert len(time_names) == 2
    assert "09:30 AM" in time_names
    assert "14:45" in time_names
@pytest.mark.asyncio
async def test_extract_money(regex_extractor):
    """A dollar amount and a euro amount are both returned as MONEY entities."""
    sample = "The product costs $1,299.99 or €1.045,00 depending on your region."
    found = await regex_extractor.extract_entities(sample)

    # Names of the entities typed as MONEY.
    money_names = [entity.name for entity in found if entity.is_a.name == "MONEY"]

    assert len(money_names) == 2
    assert "$1,299.99" in money_names
    assert "€1.045,00" in money_names
@pytest.mark.asyncio
async def test_extract_person_names(regex_extractor):
    """PERSON extraction covers plain, titled, initialled and prefixed names only."""
    text = """
Standard names: John Smith and Sarah Johnson will be attending.
Names with titles: Dr. Jane Wilson and Prof Michael Brown will present.
Names with middle initials: James T. Kirk and William H Gates are invited.
Names with prefixes: Jean de la Fontaine and Ludwig van Beethoven are famous.
Single names like Mary or Robert should not be extracted as they could be
confused with regular capitalized words at the beginning of sentences.
"""
    found = await regex_extractor.extract_entities(text)

    # Names of the entities typed as PERSON.
    person_names = [entity.name for entity in found if entity.is_a.name == "PERSON"]

    expected_names = (
        # Standard two-part names
        "John Smith",
        "Sarah Johnson",
        # Names with titles
        "Dr. Jane Wilson",
        "Prof Michael Brown",
        # Names with middle initials
        "James T. Kirk",
        "William H Gates",
        # Names with prefixes
        "Jean de la Fontaine",
        "Ludwig van Beethoven",
    )
    for expected in expected_names:
        assert expected in person_names

    # Single capitalised words must not be reported as people.
    for single_name in ("Mary", "Robert"):
        assert single_name not in person_names

    # Nothing beyond the eight expected names was matched.
    assert len(person_names) == 8
@pytest.mark.asyncio
async def test_extract_hashtags(regex_extractor):
    """Both hashtags in the sample sentence are returned as HASHTAG entities."""
    sample = "Check out our latest post #Python #MachineLearning"
    found = await regex_extractor.extract_entities(sample)

    # Names of the entities typed as HASHTAG.
    hashtag_names = [entity.name for entity in found if entity.is_a.name == "HASHTAG"]

    assert len(hashtag_names) == 2
    assert "#Python" in hashtag_names
    assert "#MachineLearning" in hashtag_names
@pytest.mark.asyncio
async def test_extract_mentions(regex_extractor):
    """Both @-mentions in the sample sentence are returned as MENTION entities."""
    sample = "Thanks to @johndoe and @jane_smith for their contributions."
    found = await regex_extractor.extract_entities(sample)

    # Names of the entities typed as MENTION.
    mention_names = [entity.name for entity in found if entity.is_a.name == "MENTION"]

    assert len(mention_names) == 2
    assert "@johndoe" in mention_names
    assert "@jane_smith" in mention_names
@pytest.mark.asyncio
async def test_extract_ip_addresses(regex_extractor):
    """Four well-formed IPv4 addresses, including the boundary values, are extracted."""
    sample = "The server IPs are 192.168.1.1, 10.0.0.1, 255.255.255.255, and 0.0.0.0."
    found = await regex_extractor.extract_entities(sample)

    # Names of the entities typed as IP_ADDRESS.
    ip_names = [entity.name for entity in found if entity.is_a.name == "IP_ADDRESS"]

    assert len(ip_names) == 4
    for expected in ("192.168.1.1", "10.0.0.1", "255.255.255.255", "0.0.0.0"):
        assert expected in ip_names
@pytest.mark.asyncio
async def test_invalid_ip_addresses(regex_extractor):
    """Malformed IP-like strings are rejected, apart from one partial match.

    The out-of-range and five-part inputs yield nothing.  The leading-zero
    input "01.102.103.104" is not returned as written, but the asserted
    result is the well-formed remainder "1.102.103.104" obtained by
    skipping the leading "0".
    """
    text = "Invalid IPs: 999.999.999.999, 256.256.256.256, 1.2.3.4.5, 01.102.103.104"
    entities = await regex_extractor.extract_entities(text)

    # Names of the entities typed as IP_ADDRESS.
    ip_names = [e.name for e in entities if e.is_a.name == "IP_ADDRESS"]

    # Exactly one match is expected: the tail of the leading-zero input.
    assert len(ip_names) == 1

    # None of the malformed inputs may be reported verbatim.
    assert "999.999.999.999" not in ip_names
    assert "256.256.256.256" not in ip_names
    assert "1.2.3.4.5" not in ip_names
    assert "01.102.103.104" not in ip_names

    # Current behaviour: the valid suffix of "01.102.103.104" is extracted.
    assert "1.102.103.104" in ip_names
@pytest.mark.asyncio
async def test_extract_multiple_entity_types(regex_extractor):
    """A mixed passage yields at least one entity of every default type."""
    text = """
Contact John Doe at john.doe@example.com or +1-555-123-4567.
Visit our website at https://www.example.com.
The meeting is scheduled for 2023-05-15 at 09:30 AM.
The project budget is $10,000.00.
Follow us on social media with #Python and mention @pythonorg.
Our server IP is 192.168.1.1.
"""
    found = await regex_extractor.extract_entities(text)

    # Type names that occur among the extracted entities.
    seen_types = {entity.is_a.name for entity in found}

    expected_types = (
        "EMAIL",
        "PHONE",
        "URL",
        "DATE",
        "TIME",
        "MONEY",
        "PERSON",
        "HASHTAG",
        "MENTION",
        "IP_ADDRESS",
    )
    for type_name in expected_types:
        assert type_name in seen_types
@pytest.mark.asyncio
async def test_empty_text(regex_extractor):
    """An empty string produces no entities."""
    result = await regex_extractor.extract_entities("")
    assert len(result) == 0
@pytest.mark.asyncio
async def test_none_text(regex_extractor):
    """Passing None instead of a string produces no entities."""
    result = await regex_extractor.extract_entities(None)
    assert len(result) == 0
@pytest.mark.asyncio
async def test_text_without_entities(regex_extractor):
    """A sentence matching none of the configured patterns produces no entities."""
    sample = "This text does not contain any extractable entities."
    result = await regex_extractor.extract_entities(sample)
    assert len(result) == 0
@pytest.mark.asyncio
async def test_custom_config_path(tmp_path):
    """An extractor built from a custom config file uses only that file's patterns."""
    # Minimal single-entity configuration; the doubled backslashes become
    # the JSON escape for one backslash once Python has parsed the literal.
    config_content = """[
{
"entity_name": "TEST_ENTITY",
"entity_description": "Test entity type",
"regex": "TEST\\\\d+",
"description_template": "Test entity: {}"
}
]"""
    config_path = tmp_path / "test_config.json"
    config_path.write_text(config_content)

    # Build the extractor from the file just written.
    extractor = RegexEntityExtractor(str(config_path))

    found = await extractor.extract_entities("This contains TEST123 and TEST456.")
    found_names = [entity.name for entity in found]

    assert len(found) == 2
    assert all(entity.is_a.name == "TEST_ENTITY" for entity in found)
    assert "TEST123" in found_names
    assert "TEST456" in found_names