feat: entity completion skeleton [COG-1318] (#552)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> - Modular implementation of entity completion search - Added base classes that define entity extractors and context providers - Created dummy implementations that return test data - Set up adapters that let us switch between different entity extractors and context providers using strings - Added configuration class to control which implementations to use - Entity completion: query → find entities → get context → interact with LLM → return answer ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced the query completion experience with integrated language model response generation, improved validation, and robust error handling. - Introduced sample modules for context retrieval and entity extraction that simulate key processing steps. - Established foundational abstractions to support flexible context and entity handling strategies. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
This commit is contained in:
parent
a788875117
commit
55411ff44b
10 changed files with 157 additions and 0 deletions
13
cognee/infrastructure/context/BaseContextProvider.py
Normal file
13
cognee/infrastructure/context/BaseContextProvider.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
|
|
||||||
|
class BaseContextProvider(ABC):
|
||||||
|
"""Base class for context retrieval strategies."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_context(self, entities: List[DataPoint], query: str) -> str:
|
||||||
|
"""Get relevant context based on extracted entities and original query."""
|
||||||
|
pass
|
||||||
0
cognee/infrastructure/context/__init__.py
Normal file
0
cognee/infrastructure/context/__init__.py
Normal file
13
cognee/infrastructure/entities/BaseEntityExtractor.py
Normal file
13
cognee/infrastructure/entities/BaseEntityExtractor.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from cognee.modules.engine.models import Entity
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEntityExtractor(ABC):
|
||||||
|
"""Base class for entity extraction strategies."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def extract_entities(self, text: str) -> List[Entity]:
|
||||||
|
"""Extract entities from the given text."""
|
||||||
|
pass
|
||||||
0
cognee/infrastructure/entities/__init__.py
Normal file
0
cognee/infrastructure/entities/__init__.py
Normal file
0
cognee/tasks/entity_completion/__init__.py
Normal file
0
cognee/tasks/entity_completion/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from cognee.modules.engine.models import Entity
|
||||||
|
from cognee.infrastructure.context.BaseContextProvider import (
|
||||||
|
BaseContextProvider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyContextProvider(BaseContextProvider):
|
||||||
|
"""Simple context getter that returns a constant context."""
|
||||||
|
|
||||||
|
async def get_context(self, entities: List[Entity], query: str) -> str:
|
||||||
|
return "Albert Einstein was a theoretical physicist."
|
||||||
103
cognee/tasks/entity_completion/entity_completion.py
Normal file
103
cognee/tasks/entity_completion/entity_completion.py
Normal file
|
|
@ -0,0 +1,103 @@
|
||||||
|
from typing import List
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||||
|
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
|
||||||
|
from cognee.infrastructure.entities.BaseEntityExtractor import (
|
||||||
|
BaseEntityExtractor,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.context.BaseContextProvider import (
|
||||||
|
BaseContextProvider,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger("entity_completion")
|
||||||
|
|
||||||
|
# Default prompt template paths
|
||||||
|
DEFAULT_SYSTEM_PROMPT_TEMPLATE = "answer_simple_question.txt"
|
||||||
|
DEFAULT_USER_PROMPT_TEMPLATE = "context_for_question.txt"
|
||||||
|
|
||||||
|
|
||||||
|
async def get_llm_response(
|
||||||
|
query: str,
|
||||||
|
context: str,
|
||||||
|
system_prompt_template: str = None,
|
||||||
|
user_prompt_template: str = None,
|
||||||
|
) -> str:
|
||||||
|
"""Generate LLM response based on query and context."""
|
||||||
|
try:
|
||||||
|
args = {
|
||||||
|
"question": query,
|
||||||
|
"context": context,
|
||||||
|
}
|
||||||
|
user_prompt = render_prompt(user_prompt_template or DEFAULT_USER_PROMPT_TEMPLATE, args)
|
||||||
|
system_prompt = read_query_prompt(system_prompt_template or DEFAULT_SYSTEM_PROMPT_TEMPLATE)
|
||||||
|
|
||||||
|
llm_client = get_llm_client()
|
||||||
|
return await llm_client.acreate_structured_output(
|
||||||
|
text_input=user_prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
response_model=str,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LLM response generation failed: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def entity_completion(
|
||||||
|
query: str,
|
||||||
|
extractor: BaseEntityExtractor,
|
||||||
|
context_provider: BaseContextProvider,
|
||||||
|
system_prompt_template: str = None,
|
||||||
|
user_prompt_template: str = None,
|
||||||
|
) -> List[str]:
|
||||||
|
"""Execute entity-based completion using provided components."""
|
||||||
|
if not query or not isinstance(query, str):
|
||||||
|
logger.error("Invalid query type or empty query")
|
||||||
|
return ["Invalid query input"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Processing query: {query[:100]}")
|
||||||
|
|
||||||
|
entities = await extractor.extract_entities(query)
|
||||||
|
if not entities:
|
||||||
|
logger.info("No entities extracted")
|
||||||
|
return ["No entities found"]
|
||||||
|
|
||||||
|
context = await context_provider.get_context(entities, query)
|
||||||
|
if not context:
|
||||||
|
logger.info("No context retrieved")
|
||||||
|
return ["No context found"]
|
||||||
|
|
||||||
|
response = await get_llm_response(
|
||||||
|
query, context, system_prompt_template, user_prompt_template
|
||||||
|
)
|
||||||
|
return [response]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Entity completion failed: {str(e)}")
|
||||||
|
return ["Entity completion failed"]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# For testing purposes, will be removed by the end of the sprint
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from cognee.tasks.entity_completion.entity_extractors.dummy_entity_extractor import (
|
||||||
|
DummyEntityExtractor,
|
||||||
|
)
|
||||||
|
from cognee.tasks.entity_completion.context_providers.dummy_context_provider import (
|
||||||
|
DummyContextProvider,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
async def run_entity_completion():
|
||||||
|
# Uses config defaults
|
||||||
|
result = await entity_completion(
|
||||||
|
"Tell me about Einstein",
|
||||||
|
DummyEntityExtractor(),
|
||||||
|
DummyContextProvider(),
|
||||||
|
)
|
||||||
|
print(f"Query Response: {result[0]}")
|
||||||
|
|
||||||
|
asyncio.run(run_entity_completion())
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from cognee.modules.engine.models import Entity, EntityType
|
||||||
|
from cognee.infrastructure.entities.BaseEntityExtractor import (
|
||||||
|
BaseEntityExtractor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyEntityExtractor(BaseEntityExtractor):
|
||||||
|
"""Simple entity extractor that returns a constant entity."""
|
||||||
|
|
||||||
|
async def extract_entities(self, text: str) -> List[Entity]:
|
||||||
|
entity_type = EntityType(name="Person", description="A human individual")
|
||||||
|
entity = Entity(name="Albert Einstein", is_a=entity_type, description="A famous physicist")
|
||||||
|
return [entity]
|
||||||
Loading…
Add table
Reference in a new issue