cognee/cognee/infrastructure/llm/gemini/adapter.py
Igor Ilic 9305f43d8e
Revert "feat: Change Cognee data models to work with Gemini [COG-1352]" (#596)
Reverts topoteretes/cognee#594

DCO Affirmation

I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced AI responses now deliver structured JSON output with clearly
defined sections, improving clarity and consistency.
- Standardized knowledge graph definitions provide a uniform
representation, simplifying integration and interpretation.



<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-03-03 17:52:51 +01:00

165 lines
6.1 KiB
Python

from typing import Type, Optional
from pydantic import BaseModel
import logging
import litellm
import asyncio
from litellm import acompletion, JSONSchemaValidationError
from cognee.shared.data_models import MonitoringTool
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.base_config import get_base_config
logger = logging.getLogger(__name__)

# Read once at import time; decides whether Langfuse tracing is wired in.
monitoring = get_base_config().monitoring_tool

if monitoring == MonitoringTool.LANGFUSE:
    from langfuse.decorators import observe
else:

    def observe(*args, **kwargs):
        """No-op stand-in for ``langfuse.decorators.observe``.

        Used when Langfuse monitoring is not configured, so that methods
        decorated with ``@observe(...)`` (or bare ``@observe``) can still be
        defined. Without this fallback, the name ``observe`` would be unbound
        and class definitions using it would fail with ``NameError`` at import.
        """
        # Bare form: ``@observe`` passes the decorated function directly.
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]

        # Parameterized form: ``@observe(as_type=...)`` must return a decorator.
        def decorator(func):
            return func

        return decorator
class GeminiAdapter(LLMInterface):
    """LLM adapter that sends chat completions to a Gemini model via ``litellm.acompletion``.

    Structured calls ask the model for a JSON object that follows a fixed
    knowledge-graph schema (summary, description, nodes, edges) and parse the
    reply into the caller-supplied pydantic ``response_model``.
    """

    # Forwarded to litellm as ``num_retries`` on every completion request.
    MAX_RETRIES = 5

    def __init__(
        self,
        api_key: str,
        model: str,
        max_tokens: int,
        endpoint: Optional[str] = None,
        api_version: Optional[str] = None,
        streaming: bool = False,
    ) -> None:
        """Store connection and generation settings.

        Args:
            api_key: API key forwarded to ``acompletion``.
            model: Model identifier forwarded to ``acompletion``.
            max_tokens: Upper bound on generated tokens, forwarded to ``acompletion``.
            endpoint: Stored on the instance; not used by any method of this class.
            api_version: Stored on the instance; not used by any method of this class.
            streaming: Stored on the instance; not used by any method of this class.
        """
        self.api_key = api_key
        self.model = model
        self.endpoint = endpoint
        self.api_version = api_version
        self.streaming = streaming
        self.max_tokens = max_tokens

    @staticmethod
    def _graph_response_schema() -> dict:
        """Return a fresh JSON schema describing the expected knowledge-graph reply."""
        # Built per call so no shared dict is handed to (and possibly mutated by) litellm.
        return {
            "type": "object",
            "properties": {
                "summary": {"type": "string"},
                "description": {"type": "string"},
                "nodes": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "name": {"type": "string"},
                            "type": {"type": "string"},
                            "description": {"type": "string"},
                            "id": {"type": "string"},
                            "label": {"type": "string"},
                        },
                        "required": ["name", "type", "description", "id", "label"],
                    },
                },
                "edges": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "source_node_id": {"type": "string"},
                            "target_node_id": {"type": "string"},
                            "relationship_name": {"type": "string"},
                        },
                        "required": [
                            "source_node_id",
                            "target_node_id",
                            "relationship_name",
                        ],
                    },
                },
            },
            "required": ["summary", "description", "nodes", "edges"],
        }

    @classmethod
    def _build_prompt_and_schema(cls, system_prompt: str, response_model) -> tuple:
        """Return ``(system_message, response_schema)`` for the requested output type.

        For ``response_model is str`` the system prompt is used unchanged with a
        plain string schema. Otherwise the prompt is extended with explicit JSON
        instructions and an example matching ``_graph_response_schema``.
        """
        if response_model is str:
            return system_prompt, {"type": "string"}

        simplified_prompt = f"""
{system_prompt}
IMPORTANT: Your response must be a valid JSON object with these required fields:
1. summary: A brief summary
2. description: A detailed description
3. nodes: Array of nodes with name, type, description, id, and label
4. edges: Array of edges with source_node_id, target_node_id, and relationship_name
Example structure:
{{
"summary": "Brief summary",
"description": "Detailed description",
"nodes": [
{{
"name": "Example Node",
"type": "Concept",
"description": "Node description",
"id": "example-id",
"label": "Concept"
}}
],
"edges": [
{{
"source_node_id": "source-id",
"target_node_id": "target-id",
"relationship_name": "relates_to"
}}
]
}}"""
        return simplified_prompt, cls._graph_response_schema()

    @observe(as_type="generation")
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """Ask the model for structured output and parse it.

        Args:
            text_input: User message content.
            system_prompt: System instructions; for non-str ``response_model`` they
                are extended with JSON-format instructions.
            response_model: Pydantic model class used to parse the reply, or the
                builtin ``str`` to get the raw message content back instead.

        Returns:
            A ``response_model`` instance built with ``model_validate_json``, or the
            raw content string when ``response_model is str``.

        Raises:
            ValueError: On a litellm ``BadRequestError``, on a litellm
                ``JSONSchemaValidationError``, or when the response has no content.
                The original exception is chained as ``__cause__``.
        """
        try:
            simplified_prompt, response_schema = self._build_prompt_and_schema(
                system_prompt, response_model
            )
            messages = [
                {"role": "system", "content": simplified_prompt},
                {"role": "user", "content": text_input},
            ]
            try:
                # NOTE(review): a JSON response_format is requested even when
                # response_model is str; confirm Gemini returns plain text there
                # rather than a JSON-encoded string literal.
                response = await acompletion(
                    model=self.model,
                    messages=messages,
                    api_key=self.api_key,
                    max_tokens=self.max_tokens,
                    temperature=0.1,
                    response_format={"type": "json_object", "schema": response_schema},
                    timeout=100,
                    num_retries=self.MAX_RETRIES,
                )
                if response.choices and response.choices[0].message.content:
                    content = response.choices[0].message.content
                    if response_model is str:
                        return content
                    # Parsing errors from pydantic propagate to the caller unchanged.
                    return response_model.model_validate_json(content)
            except litellm.exceptions.BadRequestError as e:
                logger.error("Bad request error: %s", e)
                raise ValueError(f"Invalid request: {str(e)}") from e
            # Reached only when the response carried no choices or empty content.
            raise ValueError("Failed to get valid response after retries")
        except JSONSchemaValidationError as e:
            logger.error("Schema validation failed: %s", e)
            logger.debug("Raw response: %s", e.raw_response)
            raise ValueError(f"Response failed schema validation: {str(e)}") from e

    def show_prompt(self, text_input: str, system_prompt: str) -> str:
        """Format and display the prompt for a user query.

        Args:
            text_input: User input; replaced by a placeholder when empty.
            system_prompt: Prompt name/path passed to ``read_query_prompt``.

        Returns:
            The formatted "System Prompt / User Input" text, or ``None`` when
            ``read_query_prompt`` yields a falsy value.

        Raises:
            InvalidValueError: If ``system_prompt`` is empty.
        """
        if not text_input:
            text_input = "No user input provided."
        if not system_prompt:
            raise InvalidValueError(message="No system prompt path provided.")
        system_prompt = read_query_prompt(system_prompt)
        formatted_prompt = (
            f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
            if system_prompt
            else None
        )
        return formatted_prompt