feat: entity completion skeleton [COG-1318] (#552)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> - Modular implementation of entity completion search - Added base classes that define entity extractors and context providers - Created dummy implementations that return test data - Set up adapters that let us switch between different entity extractors and context providers using strings - Added configuration class to control which implementations to use - Entity completion: query → find entities → get context → interact with LLM → return answer ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced the query completion experience with integrated language model response generation, improved validation, and robust error handling. - Introduced sample modules for context retrieval and entity extraction that simulate key processing steps. - Established foundational abstractions to support flexible context and entity handling strategies. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
This commit is contained in:
parent
a788875117
commit
55411ff44b
10 changed files with 157 additions and 0 deletions
13
cognee/infrastructure/context/BaseContextProvider.py
Normal file
13
cognee/infrastructure/context/BaseContextProvider.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
||||
class BaseContextProvider(ABC):
|
||||
"""Base class for context retrieval strategies."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_context(self, entities: List[DataPoint], query: str) -> str:
|
||||
"""Get relevant context based on extracted entities and original query."""
|
||||
pass
|
||||
0
cognee/infrastructure/context/__init__.py
Normal file
0
cognee/infrastructure/context/__init__.py
Normal file
13
cognee/infrastructure/entities/BaseEntityExtractor.py
Normal file
13
cognee/infrastructure/entities/BaseEntityExtractor.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from cognee.modules.engine.models import Entity
|
||||
|
||||
|
||||
class BaseEntityExtractor(ABC):
|
||||
"""Base class for entity extraction strategies."""
|
||||
|
||||
@abstractmethod
|
||||
async def extract_entities(self, text: str) -> List[Entity]:
|
||||
"""Extract entities from the given text."""
|
||||
pass
|
||||
0
cognee/infrastructure/entities/__init__.py
Normal file
0
cognee/infrastructure/entities/__init__.py
Normal file
0
cognee/tasks/entity_completion/__init__.py
Normal file
0
cognee/tasks/entity_completion/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from typing import List
|
||||
|
||||
from cognee.modules.engine.models import Entity
|
||||
from cognee.infrastructure.context.BaseContextProvider import (
|
||||
BaseContextProvider,
|
||||
)
|
||||
|
||||
|
||||
class DummyContextProvider(BaseContextProvider):
|
||||
"""Simple context getter that returns a constant context."""
|
||||
|
||||
async def get_context(self, entities: List[Entity], query: str) -> str:
|
||||
return "Albert Einstein was a theoretical physicist."
|
||||
103
cognee/tasks/entity_completion/entity_completion.py
Normal file
103
cognee/tasks/entity_completion/entity_completion.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
from typing import List
|
||||
import logging
|
||||
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
|
||||
from cognee.infrastructure.entities.BaseEntityExtractor import (
|
||||
BaseEntityExtractor,
|
||||
)
|
||||
from cognee.infrastructure.context.BaseContextProvider import (
|
||||
BaseContextProvider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("entity_completion")
|
||||
|
||||
# Default prompt template paths
|
||||
DEFAULT_SYSTEM_PROMPT_TEMPLATE = "answer_simple_question.txt"
|
||||
DEFAULT_USER_PROMPT_TEMPLATE = "context_for_question.txt"
|
||||
|
||||
|
||||
async def get_llm_response(
|
||||
query: str,
|
||||
context: str,
|
||||
system_prompt_template: str = None,
|
||||
user_prompt_template: str = None,
|
||||
) -> str:
|
||||
"""Generate LLM response based on query and context."""
|
||||
try:
|
||||
args = {
|
||||
"question": query,
|
||||
"context": context,
|
||||
}
|
||||
user_prompt = render_prompt(user_prompt_template or DEFAULT_USER_PROMPT_TEMPLATE, args)
|
||||
system_prompt = read_query_prompt(system_prompt_template or DEFAULT_SYSTEM_PROMPT_TEMPLATE)
|
||||
|
||||
llm_client = get_llm_client()
|
||||
return await llm_client.acreate_structured_output(
|
||||
text_input=user_prompt,
|
||||
system_prompt=system_prompt,
|
||||
response_model=str,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"LLM response generation failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def entity_completion(
|
||||
query: str,
|
||||
extractor: BaseEntityExtractor,
|
||||
context_provider: BaseContextProvider,
|
||||
system_prompt_template: str = None,
|
||||
user_prompt_template: str = None,
|
||||
) -> List[str]:
|
||||
"""Execute entity-based completion using provided components."""
|
||||
if not query or not isinstance(query, str):
|
||||
logger.error("Invalid query type or empty query")
|
||||
return ["Invalid query input"]
|
||||
|
||||
try:
|
||||
logger.info(f"Processing query: {query[:100]}")
|
||||
|
||||
entities = await extractor.extract_entities(query)
|
||||
if not entities:
|
||||
logger.info("No entities extracted")
|
||||
return ["No entities found"]
|
||||
|
||||
context = await context_provider.get_context(entities, query)
|
||||
if not context:
|
||||
logger.info("No context retrieved")
|
||||
return ["No context found"]
|
||||
|
||||
response = await get_llm_response(
|
||||
query, context, system_prompt_template, user_prompt_template
|
||||
)
|
||||
return [response]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Entity completion failed: {str(e)}")
|
||||
return ["Entity completion failed"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# For testing purposes, will be removed by the end of the sprint
|
||||
import asyncio
|
||||
import logging
|
||||
from cognee.tasks.entity_completion.entity_extractors.dummy_entity_extractor import (
|
||||
DummyEntityExtractor,
|
||||
)
|
||||
from cognee.tasks.entity_completion.context_providers.dummy_context_provider import (
|
||||
DummyContextProvider,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
async def run_entity_completion():
|
||||
# Uses config defaults
|
||||
result = await entity_completion(
|
||||
"Tell me about Einstein",
|
||||
DummyEntityExtractor(),
|
||||
DummyContextProvider(),
|
||||
)
|
||||
print(f"Query Response: {result[0]}")
|
||||
|
||||
asyncio.run(run_entity_completion())
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
from typing import List
|
||||
|
||||
from cognee.modules.engine.models import Entity, EntityType
|
||||
from cognee.infrastructure.entities.BaseEntityExtractor import (
|
||||
BaseEntityExtractor,
|
||||
)
|
||||
|
||||
|
||||
class DummyEntityExtractor(BaseEntityExtractor):
|
||||
"""Simple entity extractor that returns a constant entity."""
|
||||
|
||||
async def extract_entities(self, text: str) -> List[Entity]:
|
||||
entity_type = EntityType(name="Person", description="A human individual")
|
||||
entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
|
||||
return [entity]
|
||||
Loading…
Add table
Reference in a new issue