diff --git a/cognee/infrastructure/llm/prompts/extract_entities_system.txt b/cognee/infrastructure/llm/prompts/extract_entities_system.txt new file mode 100644 index 000000000..6f9fc1484 --- /dev/null +++ b/cognee/infrastructure/llm/prompts/extract_entities_system.txt @@ -0,0 +1,42 @@ +You are an expert entity extraction system. Your task is to identify and extract important named entities from the provided text. + +Extract only distinct, meaningful entities that are central to understanding the text. Avoid extracting common nouns, pronouns, or generic terms. + +For each entity, provide: +1. name: The entity name +2. is_a: An EntityType object with: + - name: The type name (in uppercase) + - description: A brief description of the type +3. description: A brief description of the entity (1-2 sentences) + +Your response MUST be a valid JSON object with a single field "entities" containing an array of entity objects. Do not include any explanatory text, markdown formatting, or code blocks outside of the JSON. + +Example response format: +{ + "entities": [ + { + "name": "Albert Einstein", + "is_a": { + "name": "PERSON", + "description": "Entity type for person entities" + }, + "description": "A theoretical physicist who developed the theory of relativity." + }, + { + "name": "Theory of Relativity", + "is_a": { + "name": "CONCEPT", + "description": "Entity type for concept entities" + }, + "description": "A physics theory describing the relationship between space and time." + }, + { + "name": "Princeton University", + "is_a": { + "name": "ORGANIZATION", + "description": "Entity type for organization entities" + }, + "description": "An Ivy League research university in Princeton, New Jersey." + } + ] +} diff --git a/cognee/infrastructure/llm/prompts/extract_entities_user.txt b/cognee/infrastructure/llm/prompts/extract_entities_user.txt new file mode 100644 index 000000000..ed2c0fb74 --- /dev/null +++ b/cognee/infrastructure/llm/prompts/extract_entities_user.txt @@ -0,0 +1,3 @@ +Extract key entities from this text: + +{{ text }} diff --git a/cognee/tasks/entity_completion/entity_extractors/llm_entity_extractor.py b/cognee/tasks/entity_completion/entity_extractors/llm_entity_extractor.py new file mode 100644 index 000000000..3e8acd59c --- /dev/null +++ b/cognee/tasks/entity_completion/entity_extractors/llm_entity_extractor.py @@ -0,0 +1,73 @@ +import logging +from typing import List + +from pydantic import BaseModel + +from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor +from cognee.modules.engine.models import Entity +from cognee.modules.engine.models.EntityType import EntityType +from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt +from cognee.infrastructure.llm.get_llm_client import get_llm_client + +logger = logging.getLogger("llm_entity_extractor") + + +class EntityList(BaseModel): + """Response model containing a list of extracted entities.""" + + entities: List[Entity] + + +class LLMEntityExtractor(BaseEntityExtractor): + """Entity extractor that uses an LLM to identify entities in text.""" + + def __init__( + self, + system_prompt_template: str = "extract_entities_system.txt", + user_prompt_template: str = "extract_entities_user.txt", + ): + """Initialize the LLM entity extractor.""" + self.system_prompt_template = system_prompt_template + self.user_prompt_template = user_prompt_template + self._entity_type_cache = {} + + def _get_entity_type(self, type_name: str) -> EntityType: + """Get or create an EntityType object.""" + type_name = type_name.upper() + + if type_name not in self._entity_type_cache: + self._entity_type_cache[type_name] = EntityType( + name=type_name, description=f"Entity type for {type_name.lower()} entities" + ) + + return self._entity_type_cache[type_name] + + async def extract_entities(self, text: str) -> List[Entity]: + """Extract entities from text using an LLM.""" + if not text or not isinstance(text, str): + logger.warning("Invalid input text for entity extraction") + return [] + + try: + logger.info(f"Extracting entities from text: {text[:100]}...") + + llm_client = get_llm_client() + user_prompt = render_prompt(self.user_prompt_template, {"text": text}) + system_prompt = read_query_prompt(self.system_prompt_template) + + response = await llm_client.acreate_structured_output( + text_input=user_prompt, + system_prompt=system_prompt, + response_model=EntityList, + ) + + if not response.entities: + logger.warning("No entities were extracted from the text") + return [] + + logger.info(f"Extracted {len(response.entities)} entities") + return response.entities + + except Exception as e: + logger.error(f"Entity extraction failed: {str(e)}") + return []