211 lines
7.2 KiB
Python
211 lines
7.2 KiB
Python
import base64
import os
from typing import Optional, Type

import instructor
import litellm
from openai import ContentFilterFinishReasonError
from pydantic import BaseModel

from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.rate_limiter import (
    rate_limit_async,
    rate_limit_sync,
    sleep_and_retry_async,
    sleep_and_retry_sync,
)
from cognee.modules.data.processing.document_types.open_data_file import open_data_file
from cognee.modules.observability.get_observe import get_observe
|
|
|
# Module-level observability decorator, resolved once at import time.
# NOTE(review): presumably a no-op decorator when tracing is disabled —
# confirm in get_observe().
observe = get_observe()
|
|
|
|
|
|
class OpenAIAdapter(LLMInterface):
    """Adapter for OpenAI-style chat APIs (GPT-3.5, GPT-4, ...).

    Uses litellm for transport and instructor for structured (Pydantic)
    outputs. When the primary model's content filter rejects a request,
    optionally retries once against a configured fallback model/key.
    """

    name = "OpenAI"
    model: str
    api_key: str
    api_version: str

    # Retries delegated to instructor/litellm for validation or transient failures.
    MAX_RETRIES = 5

    def __init__(
        self,
        api_key: str,
        endpoint: str,
        api_version: str,
        model: str,
        transcription_model: str,
        max_tokens: int,
        streaming: bool = False,
        fallback_model: Optional[str] = None,
        fallback_api_key: Optional[str] = None,
        fallback_endpoint: Optional[str] = None,
    ):
        """Store connection settings and build sync/async instructor clients.

        Parameters
        ----------
        api_key / endpoint / api_version / model:
            Credentials and routing for the primary chat model.
        transcription_model:
            Model used by create_transcript for audio files.
        max_tokens:
            Stored token budget (not applied per-request in this adapter).
        streaming:
            Stored streaming preference.
        fallback_model / fallback_api_key / fallback_endpoint:
            Optional secondary model used only on content-filter rejection.
        """
        self.aclient = instructor.from_litellm(litellm.acompletion)
        self.client = instructor.from_litellm(litellm.completion)
        self.transcription_model = transcription_model
        self.model = model
        self.api_key = api_key
        self.endpoint = endpoint
        self.api_version = api_version
        self.max_tokens = max_tokens
        self.streaming = streaming

        # Fallback is used only when the primary model's content filter
        # rejects the input (see acreate_structured_output).
        self.fallback_model = fallback_model
        self.fallback_api_key = fallback_api_key
        # NOTE(review): fallback_endpoint is stored but never forwarded to the
        # fallback request (the api_base kwarg was commented out upstream) —
        # confirm whether litellm should receive it.
        self.fallback_endpoint = fallback_endpoint

    @staticmethod
    def _extraction_messages(text_input: str, system_prompt: str) -> list:
        """Build the two-message chat payload used for structured extraction."""
        return [
            {
                "role": "user",
                "content": f"""Use the given format to
                extract information from the following input: {text_input}. """,
            },
            {
                "role": "system",
                "content": system_prompt,
            },
        ]

    @observe(as_type="generation")
    @sleep_and_retry_async()
    @rate_limit_async
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """Asynchronously extract a `response_model` instance from `text_input`.

        Retries once against the fallback model (if both fallback_model and
        fallback_api_key are configured) when the primary model's content
        filter rejects the request.

        Raises:
            ContentPolicyFilterError: when the content filter rejects the
                input and no fallback is configured (or the fallback is
                rejected as well).
        """
        messages = self._extraction_messages(text_input, system_prompt)

        try:
            return await self.aclient.chat.completions.create(
                model=self.model,
                messages=messages,
                api_key=self.api_key,
                api_base=self.endpoint,
                api_version=self.api_version,
                response_model=response_model,
                max_retries=self.MAX_RETRIES,
            )
        except ContentFilterFinishReasonError as error:
            if not (self.fallback_model and self.fallback_api_key):
                raise ContentPolicyFilterError(
                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
                ) from error

            try:
                return await self.aclient.chat.completions.create(
                    model=self.fallback_model,
                    messages=messages,
                    api_key=self.fallback_api_key,
                    response_model=response_model,
                    max_retries=self.MAX_RETRIES,
                )
            except ContentFilterFinishReasonError as fallback_error:
                raise ContentPolicyFilterError(
                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
                ) from fallback_error

    @observe
    @sleep_and_retry_sync()
    @rate_limit_sync
    def create_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """Synchronously extract a `response_model` instance from `text_input`.

        Unlike the async variant, this path has no content-filter fallback.
        """
        return self.client.chat.completions.create(
            model=self.model,
            messages=self._extraction_messages(text_input, system_prompt),
            api_key=self.api_key,
            api_base=self.endpoint,
            api_version=self.api_version,
            response_model=response_model,
            max_retries=self.MAX_RETRIES,
        )

    @rate_limit_sync
    def create_transcript(self, input):
        """Generate an audio transcript for a local path or s3:// URI.

        Raises:
            FileNotFoundError: when a local (non-s3) path does not exist.
        """
        # s3:// URIs are resolved by open_data_file; only local paths can be
        # checked for existence up front.
        if not input.startswith("s3://") and not os.path.isfile(input):
            raise FileNotFoundError(f"The file {input} does not exist.")

        with open_data_file(input, mode="rb") as audio_file:
            transcription = litellm.transcription(
                model=self.transcription_model,
                file=audio_file,
                api_key=self.api_key,
                api_base=self.endpoint,
                api_version=self.api_version,
                max_retries=self.MAX_RETRIES,
            )

        return transcription

    @rate_limit_sync
    def transcribe_image(self, input) -> BaseModel:
        """Describe the image at `input` using the chat model's vision input."""
        with open_data_file(input, mode="rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

        return litellm.completion(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "What's in this image?",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                # NOTE(review): media type is always declared
                                # as JPEG regardless of the actual file format
                                # — confirm providers tolerate a mismatch.
                                "url": f"data:image/jpeg;base64,{encoded_image}",
                            },
                        },
                    ],
                }
            ],
            api_key=self.api_key,
            api_base=self.endpoint,
            api_version=self.api_version,
            max_tokens=300,
            max_retries=self.MAX_RETRIES,
        )

    def show_prompt(self, text_input: str, system_prompt: str) -> str:
        """Format the system prompt (read from its path) and user input for display.

        Raises:
            InvalidValueError: when no system prompt path is provided.
        """
        if not text_input:
            text_input = "No user input provided."
        if not system_prompt:
            raise InvalidValueError(message="No system prompt path provided.")
        system_prompt = read_query_prompt(system_prompt)

        # read_query_prompt may yield an empty/None prompt; return None rather
        # than a half-formed display string in that case.
        formatted_prompt = (
            f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
            if system_prompt
            else None
        )
        return formatted_prompt