cognee/cognee/infrastructure/llm/openai/adapter.py
2024-11-28 11:49:28 +01:00

122 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import mimetypes
import os
from pathlib import Path
from typing import Type

import instructor
import litellm
from pydantic import BaseModel

from cognee.infrastructure.llm.llm_interface import LLMInterface
class OpenAIAdapter(LLMInterface):
    """Adapter for OpenAI's GPT-3 / GPT-4 API, called through litellm.

    Structured-output methods use ``instructor``-wrapped litellm completion
    functions; transcription and image description call litellm directly.
    Every call forwards this adapter's API key, endpoint and API version.
    """

    name = "OpenAI"
    model: str
    api_key: str
    api_version: str
    endpoint: str
    transcription_model: str
    streaming: bool

    # Retry budget passed to every instructor/litellm call made by this adapter.
    _MAX_RETRIES = 5

    def __init__(
        self,
        api_key: str,
        endpoint: str,
        api_version: str,
        model: str,
        transcription_model: str,
        streaming: bool = False,
    ):
        """Store connection settings and build the sync/async instructor clients.

        Args:
            api_key: API key forwarded to every litellm call.
            endpoint: Base URL forwarded to litellm as ``api_base``.
            api_version: API version string forwarded to litellm.
            model: Model used for structured output and image description.
            transcription_model: Model used by ``create_transcript``.
            streaming: Stored on the instance; not read by any method here.
        """
        self.aclient = instructor.from_litellm(litellm.acompletion)
        self.client = instructor.from_litellm(litellm.completion)
        self.transcription_model = transcription_model
        self.model = model
        self.api_key = api_key
        self.endpoint = endpoint
        self.api_version = api_version
        self.streaming = streaming

    @staticmethod
    def _build_messages(text_input: str, system_prompt: str) -> list:
        """Build the chat messages shared by both structured-output methods."""
        # NOTE: the user message deliberately precedes the system message; this
        # is the existing prompt layout and is kept to avoid changing outputs.
        return [
            {
                "role": "user",
                "content": (
                    "Use the given format to\n"
                    f"extract information from the following input: {text_input}. "
                ),
            },
            {
                "role": "system",
                "content": system_prompt,
            },
        ]

    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """Asynchronously extract structured data from ``text_input``.

        Args:
            text_input: Text to extract information from.
            system_prompt: System instructions sent alongside the input.
            response_model: Pydantic model class instructor parses the reply into.

        Returns:
            The parsed ``response_model`` instance returned by instructor.
        """
        return await self.aclient.chat.completions.create(
            model=self.model,
            messages=self._build_messages(text_input, system_prompt),
            api_key=self.api_key,
            api_base=self.endpoint,
            api_version=self.api_version,
            response_model=response_model,
            max_retries=self._MAX_RETRIES,
        )

    def create_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """Synchronously extract structured data from ``text_input``.

        Blocking counterpart of ``acreate_structured_output``; it takes the
        same arguments and sends the same messages.

        Returns:
            The parsed ``response_model`` instance returned by instructor.
        """
        return self.client.chat.completions.create(
            model=self.model,
            messages=self._build_messages(text_input, system_prompt),
            api_key=self.api_key,
            api_base=self.endpoint,
            api_version=self.api_version,
            response_model=response_model,
            max_retries=self._MAX_RETRIES,
        )

    def create_transcript(self, input):
        """Generate an audio transcript for the file at path ``input``.

        Args:
            input: Path to the audio file (name kept for caller compatibility,
                even though it shadows the ``input`` builtin).

        Returns:
            The result of ``litellm.transcription``.

        Raises:
            FileNotFoundError: If ``input`` is not an existing regular file.
        """
        # Fail fast with a clear message before handing the path to litellm.
        if not os.path.isfile(input):
            raise FileNotFoundError(f"The file {input} does not exist.")

        return litellm.transcription(
            model=self.transcription_model,
            file=Path(input),
            api_key=self.api_key,
            api_base=self.endpoint,
            api_version=self.api_version,
            max_retries=self._MAX_RETRIES,
        )

    def transcribe_image(self, input) -> BaseModel:
        """Ask ``self.model`` to describe the image stored at path ``input``.

        The image bytes are inlined as a base64 ``data:`` URL. The declared MIME
        type is guessed from the file extension and falls back to
        ``image/jpeg`` when the extension is unknown or not an image type.

        Args:
            input: Path to the image file.

        Returns:
            The raw ``litellm.completion`` response (capped at 300 tokens).
        """
        with open(input, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

        # Don't label every image as JPEG: some providers may check the declared
        # media type against the actual bytes.
        mime_type, _ = mimetypes.guess_type(str(input))
        if mime_type is None or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"

        return litellm.completion(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Whats in this image?",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{encoded_image}",
                            },
                        },
                    ],
                }
            ],
            api_key=self.api_key,
            api_base=self.endpoint,
            api_version=self.api_version,
            max_tokens=300,
            max_retries=self._MAX_RETRIES,
        )