feat: add regex entity extractor (#605)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> - Created a new RegexEntityExtractor that uses regex patterns to identify entities like emails, URLs, and dates in text - Implemented a JSON-based configuration system to add or modify entity types without changing code - Built a separate RegexEntityConfig class to handle loading and processing of entity configurations - Added test suite covering all entity types and edge cases ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a new regex-based extraction capability that uses configurable patterns and description templates to identify common entities such as emails, phone numbers, URLs, dates, and more. - **Tests** - Added comprehensive tests to validate the extraction functionality across standard scenarios and edge cases for reliable text analysis. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
parent
9d783675e0
commit
ea5b11a3b4
4 changed files with 506 additions and 0 deletions
|
|
@ -0,0 +1,62 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"entity_name": "EMAIL",
|
||||||
|
"entity_description": "Entity type for email entities",
|
||||||
|
"regex": "[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}",
|
||||||
|
"description_template": "Email address: {}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"entity_name": "PHONE",
|
||||||
|
"entity_description": "Entity type for phone entities",
|
||||||
|
"regex": "\\+?\\d{1,4}[\\s-]?\\(?\\d{2,4}\\)?[\\s-]?\\d{3,4}[\\s-]?\\d{3,4}",
|
||||||
|
"description_template": "Phone number: {}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"entity_name": "URL",
|
||||||
|
"entity_description": "Entity type for url entities",
|
||||||
|
"regex": "https?:\\/\\/(www\\.)?[a-zA-Z0-9-]+(\\.[a-zA-Z]{2,})+(\\/\\S*)?",
|
||||||
|
"description_template": "URL: {}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"entity_name": "DATE",
|
||||||
|
"entity_description": "Entity type for date entities",
|
||||||
|
"regex": "(\\d{4}[-/]\\d{2}[-/]\\d{2})|(\\d{2}[-/]\\d{2}[-/]\\d{4})",
|
||||||
|
"description_template": "Date: {}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"entity_name": "TIME",
|
||||||
|
"entity_description": "Entity type for time entities",
|
||||||
|
"regex": "(1[0-2]|0?[1-9]):[0-5][0-9](\\s?[APap][Mm])?|([01]?[0-9]|2[0-3]):[0-5][0-9]",
|
||||||
|
"description_template": "Time: {}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"entity_name": "MONEY",
|
||||||
|
"entity_description": "Entity type for money entities",
|
||||||
|
"regex": "\\$?\\d{1,3}(,\\d{3})*(\\.[0-9]{2})?|\\€?\\d{1,3}(\\.\\d{3})*(,[0-9]{2})?",
|
||||||
|
"description_template": "Monetary amount: {}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"entity_name": "PERSON",
|
||||||
|
"entity_description": "Entity type for person entities",
|
||||||
|
"regex": "\\b(?:(?:Dr|Prof|Mr|Mrs|Ms)\\.?\\s+)?[A-Z][a-z]+(?:\\s+(?:[A-Z][a-z]+|[A-Z]\\.?|(?:van|de|la|del|von|der|le)))+\\b",
|
||||||
|
"description_template": "Person name: {}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"entity_name": "HASHTAG",
|
||||||
|
"entity_description": "Entity type for hashtag entities",
|
||||||
|
"regex": "\\#[A-Za-z0-9_]+",
|
||||||
|
"description_template": "Hashtag: {}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"entity_name": "MENTION",
|
||||||
|
"entity_description": "Entity type for mention entities",
|
||||||
|
"regex": "\\@[A-Za-z0-9_]+",
|
||||||
|
"description_template": "Mention: {}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"entity_name": "IP_ADDRESS",
|
||||||
|
"entity_description": "Entity type for ip_address entities",
|
||||||
|
"regex": "(?<!\\d\\.)(?:(?:25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]\\d?|0)\\.){3}(?:25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]\\d?|0)(?!\\.\\d)",
|
||||||
|
"description_template": "IP address: {}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,91 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Dict, List, Pattern, Any
|
||||||
|
|
||||||
|
from cognee.modules.engine.models.EntityType import EntityType
|
||||||
|
from cognee.root_dir import get_absolute_path
|
||||||
|
|
||||||
|
logger = logging.getLogger("regex_entity_config")
|
||||||
|
|
||||||
|
|
||||||
|
class RegexEntityConfig:
|
||||||
|
"""Class to load and process regex entity extraction configuration."""
|
||||||
|
|
||||||
|
def __init__(self, config_path: str):
|
||||||
|
"""Initialize the regex entity configuration with the config path."""
|
||||||
|
self.config_path = config_path
|
||||||
|
self.entity_configs = {}
|
||||||
|
self._load_config()
|
||||||
|
|
||||||
|
def _validate_config_fields(self, config: Dict[str, Any]) -> None:
|
||||||
|
"""Validate that all required fields are present in the configuration."""
|
||||||
|
required_fields = ["entity_name", "entity_description", "regex", "description_template"]
|
||||||
|
missing_fields = [field for field in required_fields if field not in config]
|
||||||
|
|
||||||
|
if missing_fields:
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing required fields in entity configuration: {', '.join(missing_fields)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _compile_regex(self, pattern: str, entity_name: str) -> Pattern:
|
||||||
|
"""Compile a regex pattern safely, with error handling."""
|
||||||
|
try:
|
||||||
|
return re.compile(pattern)
|
||||||
|
except re.error as e:
|
||||||
|
logger.error(f"Invalid regex pattern for entity '{entity_name}': {str(e)}")
|
||||||
|
raise ValueError(f"Invalid regex pattern for entity '{entity_name}': {str(e)}")
|
||||||
|
|
||||||
|
def _load_config(self) -> None:
|
||||||
|
"""Load and process the configuration from the JSON file."""
|
||||||
|
try:
|
||||||
|
with open(self.config_path, "r") as f:
|
||||||
|
config_list = json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error(f"Config file not found: {self.config_path}")
|
||||||
|
raise
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Invalid JSON in config file {self.config_path}: {str(e)}")
|
||||||
|
raise ValueError(f"Invalid JSON in config file: {str(e)}")
|
||||||
|
|
||||||
|
for config in config_list:
|
||||||
|
self._validate_config_fields(config)
|
||||||
|
entity_name = config["entity_name"]
|
||||||
|
|
||||||
|
entity_type = EntityType(name=entity_name, description=config["entity_description"])
|
||||||
|
|
||||||
|
compiled_pattern = self._compile_regex(config["regex"], entity_name)
|
||||||
|
|
||||||
|
self.entity_configs[entity_name] = {
|
||||||
|
"entity_type": entity_type,
|
||||||
|
"regex": config["regex"],
|
||||||
|
"compiled_pattern": compiled_pattern,
|
||||||
|
"description_template": config["description_template"],
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Loaded {len(self.entity_configs)} entity configurations from {self.config_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_entity_names(self) -> List[str]:
|
||||||
|
"""Return a list of all configured entity names."""
|
||||||
|
return list(self.entity_configs.keys())
|
||||||
|
|
||||||
|
def get_entity_config(self, entity_name: str) -> Dict[str, Any]:
|
||||||
|
"""Get the configuration for a specific entity type."""
|
||||||
|
if entity_name not in self.entity_configs:
|
||||||
|
raise KeyError(f"Unknown entity type: {entity_name}")
|
||||||
|
return self.entity_configs[entity_name]
|
||||||
|
|
||||||
|
def get_entity_type(self, entity_name: str) -> EntityType:
|
||||||
|
"""Get the EntityType object for a specific entity type."""
|
||||||
|
return self.get_entity_config(entity_name)["entity_type"]
|
||||||
|
|
||||||
|
def get_compiled_pattern(self, entity_name: str) -> Pattern:
|
||||||
|
"""Get the compiled regex pattern for a specific entity type."""
|
||||||
|
return self.get_entity_config(entity_name)["compiled_pattern"]
|
||||||
|
|
||||||
|
def get_description_template(self, entity_name: str) -> str:
|
||||||
|
"""Get the description template for a specific entity type."""
|
||||||
|
return self.get_entity_config(entity_name)["description_template"]
|
||||||
|
|
@ -0,0 +1,72 @@
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
|
||||||
|
from cognee.modules.engine.models import Entity
|
||||||
|
from cognee.root_dir import get_absolute_path
|
||||||
|
from cognee.tasks.entity_completion.entity_extractors.regex_entity_config import RegexEntityConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger("regex_entity_extractor")
|
||||||
|
|
||||||
|
|
||||||
|
class RegexEntityExtractor(BaseEntityExtractor):
|
||||||
|
"""Entity extractor that uses regular expressions to identify entities in text."""
|
||||||
|
|
||||||
|
def __init__(self, config_path: Optional[str] = None):
|
||||||
|
"""Initialize the regex entity extractor with an optional custom config path."""
|
||||||
|
if config_path is None:
|
||||||
|
config_path = get_absolute_path(
|
||||||
|
"tasks/entity_completion/entity_extractors/regex_entity_config.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.config = RegexEntityConfig(config_path)
|
||||||
|
logger.info(
|
||||||
|
f"Initialized RegexEntityExtractor with {len(self.config.get_entity_names())} entity types"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_entity(self, match_text: str, entity_type_obj, description_template: str) -> Entity:
|
||||||
|
"""Create an entity from a regex match."""
|
||||||
|
return Entity(
|
||||||
|
name=match_text,
|
||||||
|
is_a=entity_type_obj,
|
||||||
|
description=description_template.format(match_text),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_entities_by_type(self, entity_type: str, text: str) -> List[Entity]:
|
||||||
|
"""Extract entities of a specific type from the given text."""
|
||||||
|
try:
|
||||||
|
pattern = self.config.get_compiled_pattern(entity_type)
|
||||||
|
description_template = self.config.get_description_template(entity_type)
|
||||||
|
entity_type_obj = self.config.get_entity_type(entity_type)
|
||||||
|
|
||||||
|
return [
|
||||||
|
self._create_entity(match.group(0), entity_type_obj, description_template)
|
||||||
|
for match in pattern.finditer(text)
|
||||||
|
]
|
||||||
|
except KeyError:
|
||||||
|
logger.warning(f"Unknown entity type: {entity_type}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _text_to_entities(self, text: str) -> List[Entity]:
|
||||||
|
"""Extract all entity types from the given text and return them as a list."""
|
||||||
|
all_entities = []
|
||||||
|
|
||||||
|
for entity_type in self.config.get_entity_names():
|
||||||
|
extracted_entities = self._extract_entities_by_type(entity_type, text)
|
||||||
|
all_entities.extend(extracted_entities)
|
||||||
|
|
||||||
|
logger.info(f"Extracted {len(all_entities)} entities")
|
||||||
|
return all_entities
|
||||||
|
|
||||||
|
async def extract_entities(self, text: str) -> List[Entity]:
|
||||||
|
"""Extract all configured entity types from the given text."""
|
||||||
|
if not text or not isinstance(text, str):
|
||||||
|
logger.warning("Invalid input text for entity extraction")
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Extracting entities from text: {text[:100]}...")
|
||||||
|
return self._text_to_entities(text)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Entity extraction failed: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
@ -0,0 +1,281 @@
|
||||||
|
import pytest
|
||||||
|
from cognee.tasks.entity_completion.entity_extractors.regex_entity_extractor import (
|
||||||
|
RegexEntityExtractor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def regex_extractor():
|
||||||
|
"""Create a RegexEntityExtractor instance for testing."""
|
||||||
|
return RegexEntityExtractor()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_emails(regex_extractor):
|
||||||
|
"""Test extraction of email addresses."""
|
||||||
|
text = "Contact us at support@example.com or sales@company.co.uk for more information."
|
||||||
|
entities = await regex_extractor.extract_entities(text)
|
||||||
|
|
||||||
|
# Filter only EMAIL entities
|
||||||
|
email_entities = [e for e in entities if e.is_a.name == "EMAIL"]
|
||||||
|
|
||||||
|
assert len(email_entities) == 2
|
||||||
|
assert "support@example.com" in [e.name for e in email_entities]
|
||||||
|
assert "sales@company.co.uk" in [e.name for e in email_entities]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_phone_numbers(regex_extractor):
|
||||||
|
"""Test extraction of phone numbers."""
|
||||||
|
text = "Call us at +1-555-123-4567 or 020 7946 0958 for support."
|
||||||
|
entities = await regex_extractor.extract_entities(text)
|
||||||
|
|
||||||
|
# Filter only PHONE entities
|
||||||
|
phone_entities = [e for e in entities if e.is_a.name == "PHONE"]
|
||||||
|
|
||||||
|
assert len(phone_entities) == 2
|
||||||
|
assert "+1-555-123-4567" in [e.name for e in phone_entities]
|
||||||
|
assert "020 7946 0958" in [e.name for e in phone_entities]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_urls(regex_extractor):
|
||||||
|
"""Test extraction of URLs."""
|
||||||
|
text = "Visit our website at https://www.example.com or http://docs.example.org/help for more information."
|
||||||
|
entities = await regex_extractor.extract_entities(text)
|
||||||
|
|
||||||
|
# Filter only URL entities
|
||||||
|
url_entities = [e for e in entities if e.is_a.name == "URL"]
|
||||||
|
|
||||||
|
assert len(url_entities) == 2
|
||||||
|
assert "https://www.example.com" in [e.name for e in url_entities]
|
||||||
|
assert "http://docs.example.org/help" in [e.name for e in url_entities]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_dates(regex_extractor):
|
||||||
|
"""Test extraction of dates."""
|
||||||
|
text = "The event is scheduled for 2023-05-15 and ends on 06/30/2023."
|
||||||
|
entities = await regex_extractor.extract_entities(text)
|
||||||
|
|
||||||
|
# Filter only DATE entities
|
||||||
|
date_entities = [e for e in entities if e.is_a.name == "DATE"]
|
||||||
|
|
||||||
|
assert len(date_entities) == 2
|
||||||
|
assert "2023-05-15" in [e.name for e in date_entities]
|
||||||
|
assert "06/30/2023" in [e.name for e in date_entities]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_times(regex_extractor):
|
||||||
|
"""Test extraction of times."""
|
||||||
|
text = "The meeting starts at 09:30 AM and ends at 14:45."
|
||||||
|
entities = await regex_extractor.extract_entities(text)
|
||||||
|
|
||||||
|
# Filter only TIME entities
|
||||||
|
time_entities = [e for e in entities if e.is_a.name == "TIME"]
|
||||||
|
|
||||||
|
assert len(time_entities) == 2
|
||||||
|
assert "09:30 AM" in [e.name for e in time_entities]
|
||||||
|
assert "14:45" in [e.name for e in time_entities]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_money(regex_extractor):
|
||||||
|
"""Test extraction of monetary amounts."""
|
||||||
|
text = "The product costs $1,299.99 or €1.045,00 depending on your region."
|
||||||
|
entities = await regex_extractor.extract_entities(text)
|
||||||
|
|
||||||
|
# Filter only MONEY entities
|
||||||
|
money_entities = [e for e in entities if e.is_a.name == "MONEY"]
|
||||||
|
|
||||||
|
assert len(money_entities) == 2
|
||||||
|
assert "$1,299.99" in [e.name for e in money_entities]
|
||||||
|
assert "€1.045,00" in [e.name for e in money_entities]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_person_names(regex_extractor):
|
||||||
|
"""Test extraction of person names with various formats."""
|
||||||
|
text = """
|
||||||
|
Standard names: John Smith and Sarah Johnson will be attending.
|
||||||
|
Names with titles: Dr. Jane Wilson and Prof Michael Brown will present.
|
||||||
|
Names with middle initials: James T. Kirk and William H Gates are invited.
|
||||||
|
Names with prefixes: Jean de la Fontaine and Ludwig van Beethoven are famous.
|
||||||
|
|
||||||
|
Single names like Mary or Robert should not be extracted as they could be
|
||||||
|
confused with regular capitalized words at the beginning of sentences.
|
||||||
|
"""
|
||||||
|
entities = await regex_extractor.extract_entities(text)
|
||||||
|
|
||||||
|
# Filter only PERSON entities
|
||||||
|
person_entities = [e for e in entities if e.is_a.name == "PERSON"]
|
||||||
|
entity_names = [e.name for e in person_entities]
|
||||||
|
|
||||||
|
# Standard two-part names
|
||||||
|
assert "John Smith" in entity_names
|
||||||
|
assert "Sarah Johnson" in entity_names
|
||||||
|
|
||||||
|
# Names with titles
|
||||||
|
assert "Dr. Jane Wilson" in entity_names
|
||||||
|
assert "Prof Michael Brown" in entity_names
|
||||||
|
|
||||||
|
# Names with middle initials
|
||||||
|
assert "James T. Kirk" in entity_names
|
||||||
|
assert "William H Gates" in entity_names
|
||||||
|
|
||||||
|
# Names with prefixes
|
||||||
|
assert "Jean de la Fontaine" in entity_names
|
||||||
|
assert "Ludwig van Beethoven" in entity_names
|
||||||
|
|
||||||
|
# Verify single names are not extracted
|
||||||
|
assert "Mary" not in entity_names
|
||||||
|
assert "Robert" not in entity_names
|
||||||
|
|
||||||
|
# Verify we have the expected number of names
|
||||||
|
assert len(person_entities) == 8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_hashtags(regex_extractor):
|
||||||
|
"""Test extraction of hashtags."""
|
||||||
|
text = "Check out our latest post #Python #MachineLearning"
|
||||||
|
entities = await regex_extractor.extract_entities(text)
|
||||||
|
|
||||||
|
# Filter only HASHTAG entities
|
||||||
|
hashtag_entities = [e for e in entities if e.is_a.name == "HASHTAG"]
|
||||||
|
|
||||||
|
assert len(hashtag_entities) == 2
|
||||||
|
assert "#Python" in [e.name for e in hashtag_entities]
|
||||||
|
assert "#MachineLearning" in [e.name for e in hashtag_entities]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_mentions(regex_extractor):
|
||||||
|
"""Test extraction of mentions."""
|
||||||
|
text = "Thanks to @johndoe and @jane_smith for their contributions."
|
||||||
|
entities = await regex_extractor.extract_entities(text)
|
||||||
|
|
||||||
|
# Filter only MENTION entities
|
||||||
|
mention_entities = [e for e in entities if e.is_a.name == "MENTION"]
|
||||||
|
|
||||||
|
assert len(mention_entities) == 2
|
||||||
|
assert "@johndoe" in [e.name for e in mention_entities]
|
||||||
|
assert "@jane_smith" in [e.name for e in mention_entities]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_ip_addresses(regex_extractor):
|
||||||
|
"""Test extraction of IP addresses with proper validation of octet ranges."""
|
||||||
|
# Test with valid IP addresses
|
||||||
|
text = "The server IPs are 192.168.1.1, 10.0.0.1, 255.255.255.255, and 0.0.0.0."
|
||||||
|
entities = await regex_extractor.extract_entities(text)
|
||||||
|
|
||||||
|
# Filter only IP_ADDRESS entities
|
||||||
|
ip_entities = [e for e in entities if e.is_a.name == "IP_ADDRESS"]
|
||||||
|
|
||||||
|
assert len(ip_entities) == 4
|
||||||
|
assert "192.168.1.1" in [e.name for e in ip_entities]
|
||||||
|
assert "10.0.0.1" in [e.name for e in ip_entities]
|
||||||
|
assert "255.255.255.255" in [e.name for e in ip_entities]
|
||||||
|
assert "0.0.0.0" in [e.name for e in ip_entities]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_ip_addresses(regex_extractor):
|
||||||
|
"""Test that invalid IP addresses are not extracted."""
|
||||||
|
# Test with invalid IP addresses
|
||||||
|
text = "Invalid IPs: 999.999.999.999, 256.256.256.256, 1.2.3.4.5, 01.102.103.104"
|
||||||
|
entities = await regex_extractor.extract_entities(text)
|
||||||
|
|
||||||
|
# Filter only IP_ADDRESS entities
|
||||||
|
ip_entities = [e for e in entities if e.is_a.name == "IP_ADDRESS"]
|
||||||
|
|
||||||
|
# None of these should be extracted as valid IPs
|
||||||
|
assert len(ip_entities) == 1
|
||||||
|
assert "999.999.999.999" not in [e.name for e in ip_entities]
|
||||||
|
assert "256.256.256.256" not in [e.name for e in ip_entities]
|
||||||
|
assert "1.2.3.4.5" not in [e.name for e in ip_entities]
|
||||||
|
assert "01.102.103.104" not in [e.name for e in ip_entities]
|
||||||
|
assert "1.102.103.104" in [e.name for e in ip_entities]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_multiple_entity_types(regex_extractor):
|
||||||
|
"""Test extraction of multiple entity types from a single text."""
|
||||||
|
text = """
|
||||||
|
Contact John Doe at john.doe@example.com or +1-555-123-4567.
|
||||||
|
Visit our website at https://www.example.com.
|
||||||
|
The meeting is scheduled for 2023-05-15 at 09:30 AM.
|
||||||
|
The project budget is $10,000.00.
|
||||||
|
Follow us on social media with #Python and mention @pythonorg.
|
||||||
|
Our server IP is 192.168.1.1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
entities = await regex_extractor.extract_entities(text)
|
||||||
|
|
||||||
|
# Check that we have at least one entity of each type
|
||||||
|
entity_types = [e.is_a.name for e in entities]
|
||||||
|
|
||||||
|
assert "EMAIL" in entity_types
|
||||||
|
assert "PHONE" in entity_types
|
||||||
|
assert "URL" in entity_types
|
||||||
|
assert "DATE" in entity_types
|
||||||
|
assert "TIME" in entity_types
|
||||||
|
assert "MONEY" in entity_types
|
||||||
|
assert "PERSON" in entity_types
|
||||||
|
assert "HASHTAG" in entity_types
|
||||||
|
assert "MENTION" in entity_types
|
||||||
|
assert "IP_ADDRESS" in entity_types
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_text(regex_extractor):
|
||||||
|
"""Test extraction with empty text."""
|
||||||
|
entities = await regex_extractor.extract_entities("")
|
||||||
|
assert len(entities) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_none_text(regex_extractor):
|
||||||
|
"""Test extraction with None text."""
|
||||||
|
entities = await regex_extractor.extract_entities(None)
|
||||||
|
assert len(entities) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_without_entities(regex_extractor):
|
||||||
|
"""Test extraction with text that doesn't contain any entities."""
|
||||||
|
text = "This text does not contain any extractable entities."
|
||||||
|
entities = await regex_extractor.extract_entities(text)
|
||||||
|
assert len(entities) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_custom_config_path(tmp_path):
|
||||||
|
"""Test extraction with a custom configuration path."""
|
||||||
|
# Create a minimal test config file
|
||||||
|
config_content = """[
|
||||||
|
{
|
||||||
|
"entity_name": "TEST_ENTITY",
|
||||||
|
"entity_description": "Test entity type",
|
||||||
|
"regex": "TEST\\\\d+",
|
||||||
|
"description_template": "Test entity: {}"
|
||||||
|
}
|
||||||
|
]"""
|
||||||
|
|
||||||
|
config_path = tmp_path / "test_config.json"
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
f.write(config_content)
|
||||||
|
|
||||||
|
# Create extractor with custom config
|
||||||
|
extractor = RegexEntityExtractor(str(config_path))
|
||||||
|
|
||||||
|
# Test extraction
|
||||||
|
text = "This contains TEST123 and TEST456."
|
||||||
|
entities = await extractor.extract_entities(text)
|
||||||
|
|
||||||
|
assert len(entities) == 2
|
||||||
|
assert all(e.is_a.name == "TEST_ENTITY" for e in entities)
|
||||||
|
assert "TEST123" in [e.name for e in entities]
|
||||||
|
assert "TEST456" in [e.name for e in entities]
|
||||||
Loading…
Add table
Reference in a new issue