cognee/cognee/infrastructure/llm/openai/adapter.py
Boris 64b8aac86f
feat: code graph swe integration
Co-authored-by: hajdul88 <52442977+hajdul88@users.noreply.github.com>
Co-authored-by: hande-k <handekafkas7@gmail.com>
Co-authored-by: Igor Ilic <igorilic03@gmail.com>
Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
2024-11-27 09:32:29 +01:00

134 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import base64
from pathlib import Path
from typing import Type
import litellm
import instructor
from pydantic import BaseModel
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
class OpenAIAdapter(LLMInterface):
    """Adapter for OpenAI's GPT-3 / GPT-4 style APIs, routed through litellm.

    Provides structured (pydantic-validated) chat completions, audio
    transcription, and image description, plus a prompt-preview helper.
    """

    name = "OpenAI"
    model: str
    api_key: str
    api_version: str

    def __init__(
        self,
        api_key: str,
        endpoint: str,
        api_version: str,
        model: str,
        transcription_model: str,
        streaming: bool = False,
    ):
        """Store connection settings and build instructor-patched litellm clients.

        Args:
            api_key: Credential forwarded on every litellm call.
            endpoint: API base URL (passed as ``api_base``).
            api_version: API version string (relevant for Azure-style endpoints).
            model: Chat/completion model identifier.
            transcription_model: Model used by :meth:`create_transcript`.
            streaming: Stored flag; not read by the methods visible here.
        """
        # instructor patches litellm so completions can be parsed directly
        # into pydantic response models (used by *_structured_output below).
        self.aclient = instructor.from_litellm(litellm.acompletion)
        self.client = instructor.from_litellm(litellm.completion)
        self.transcription_model = transcription_model
        self.model = model
        self.api_key = api_key
        self.endpoint = endpoint
        self.api_version = api_version
        self.streaming = streaming

    @staticmethod
    def _structured_output_messages(text_input: str, system_prompt: str) -> list:
        """Build the shared message list for structured-output calls.

        Keeps the async and sync variants below in exact agreement.
        """
        return [{
            "role": "user",
            "content": f"""Use the given format to
extract information from the following input: {text_input}. """,
        }, {
            "role": "system",
            "content": system_prompt,
        }]

    async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
        """Asynchronously generate a structured response from a user query.

        Args:
            text_input: Raw user text to extract information from.
            system_prompt: System instructions for the model.
            response_model: Pydantic model the response is validated against.

        Returns:
            An instance of ``response_model`` populated by the LLM.
        """
        return await self.aclient.chat.completions.create(
            model = self.model,
            messages = self._structured_output_messages(text_input, system_prompt),
            api_key = self.api_key,
            api_base = self.endpoint,
            api_version = self.api_version,
            response_model = response_model,
            max_retries = 5,
        )

    def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
        """Synchronously generate a structured response from a user query.

        Mirrors :meth:`acreate_structured_output` using the blocking client.
        """
        return self.client.chat.completions.create(
            model = self.model,
            messages = self._structured_output_messages(text_input, system_prompt),
            api_key = self.api_key,
            api_base = self.endpoint,
            api_version = self.api_version,
            response_model = response_model,
            max_retries = 5,
        )

    def create_transcript(self, input):
        """Generate an audio transcript from a local audio file.

        Args:
            input: Path to the audio file to transcribe.

        Returns:
            The litellm transcription response object.

        Raises:
            FileNotFoundError: If ``input`` does not point to an existing file.
        """
        if not os.path.isfile(input):
            raise FileNotFoundError(f"The file {input} does not exist.")

        transcription = litellm.transcription(
            model = self.transcription_model,
            file = Path(input),
            api_key=self.api_key,
            api_base=self.endpoint,
            api_version=self.api_version,
            max_retries = 5,
        )

        return transcription

    def transcribe_image(self, input) -> BaseModel:
        """Describe the contents of a local image via a vision-capable model.

        The image is inlined as a base64 JPEG data URL; no upload endpoint
        is required.

        Args:
            input: Path to the image file.

        Returns:
            The litellm completion response describing the image.
        """
        with open(input, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode('utf-8')

        return litellm.completion(
            model = self.model,
            messages = [{
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Whats in this image?",
                    }, {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{encoded_image}",
                        },
                    },
                ],
            }],
            api_key=self.api_key,
            api_base=self.endpoint,
            api_version=self.api_version,
            max_tokens = 300,
            max_retries = 5,
        )

    def show_prompt(self, text_input: str, system_prompt: str) -> str:
        """Format and display the prompt for a user query.

        Args:
            text_input: User input to embed; a placeholder is used if empty.
            system_prompt: Path of the system prompt file to load.

        Returns:
            The formatted prompt, or ``None`` if the loaded system prompt
            is empty/falsy.

        Raises:
            ValueError: If no system prompt path is provided.
        """
        if not text_input:
            text_input = "No user input provided."
        if not system_prompt:
            raise ValueError("No system prompt path provided.")
        # Reassigned: from here on, system_prompt is the file's *contents*.
        system_prompt = read_query_prompt(system_prompt)

        formatted_prompt = f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" if system_prompt else None
        return formatted_prompt