"""Tools for interacting with OpenAI's GPT-3, GPT-4 API"""
|
|
import asyncio
import os
from typing import List

import openai
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|
|
|
# Optional override for the API endpoint (e.g. a proxy or locally hosted
# OpenAI-compatible server).
HOST = os.getenv("OPENAI_API_BASE")

if HOST is not None:
    # NOTE(review): `api_base` is the legacy (openai<1.0) setting, while the
    # rest of this module uses the >=1.0 interface, which reads `base_url` —
    # confirm this override actually takes effect on the default client.
    openai.api_base = HOST
|
|
|
|
|
|
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(5))
def completions_with_backoff(**kwargs):
    """Wrapper around ``chat.completions.create`` with exponential backoff.

    :param kwargs: forwarded verbatim to ``openai.chat.completions.create``
        (e.g. ``model``, ``messages``).
    :return: the chat-completion response object from the API.
    """
    # Without a ``wait`` policy tenacity retries immediately, which defeats
    # the "backoff" promise and hammers a rate-limited endpoint; the jittered
    # exponential wait spaces the 5 attempts out.
    return openai.chat.completions.create(**kwargs)
|
|
|
|
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(5))
async def acompletions_with_backoff(**kwargs):
    """Async wrapper around ``chat.completions.create`` with exponential backoff.

    The openai>=1.0 interface used throughout this module has no
    ``chat.completions.acreate``; async calls go through an ``AsyncOpenAI``
    client instead (same pattern as ``acreate_embedding_with_backoff``).

    :param kwargs: forwarded verbatim to ``client.chat.completions.create``.
    :return: the chat-completion response object from the API.
    """
    client = openai.AsyncOpenAI(
        # Default lookup, spelled out for clarity.
        api_key=os.environ.get("OPENAI_API_KEY"),
        # Honor the module-level OPENAI_API_BASE override; ``None`` falls
        # back to the default endpoint.
        base_url=HOST,
    )
    return await client.chat.completions.create(**kwargs)
|
|
|
|
|
|
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(5))
async def acreate_embedding_with_backoff(**kwargs):
    """Async wrapper around ``embeddings.create`` with exponential backoff.

    :param kwargs: forwarded verbatim to ``client.embeddings.create``
        (e.g. ``input``, ``model``).
    :return: the embedding response object from the API.
    """
    client = openai.AsyncOpenAI(
        # Default lookup, spelled out for clarity.
        api_key=os.environ.get("OPENAI_API_KEY"),
        # Previously this client silently ignored the module-level
        # OPENAI_API_BASE override; ``None`` falls back to the default
        # endpoint.
        base_url=HOST,
    )
    return await client.embeddings.create(**kwargs)
|
|
|
|
|
|
async def async_get_embedding_with_backoff(text, model="text-embedding-ada-002"):
    """Embed *text* asynchronously, with rate-limit retries handled.

    Newlines are flattened to spaces before the request, and only the raw
    embedding vector of the single input is returned.
    """
    cleaned = text.replace("\n", " ")
    response = await acreate_embedding_with_backoff(input=[cleaned], model=model)
    return response.data[0].embedding
|
|
|
|
|
|
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(5))
def create_embedding_with_backoff(**kwargs):
    """Wrapper around ``embeddings.create`` with exponential backoff.

    :param kwargs: forwarded verbatim to ``openai.embeddings.create``
        (e.g. ``input``, ``model``).
    :return: the embedding response object from the API.
    """
    # Jittered exponential wait so the 5 attempts actually back off instead
    # of retrying immediately against a rate-limited endpoint.
    return openai.embeddings.create(**kwargs)
|
|
|
|
|
|
def get_embedding_with_backoff(text: str, model: str = "text-embedding-ada-002"):
    """Return the embedding vector for *text*, retrying on failures.

    :param text: input to embed; newlines are collapsed to spaces first.
    :param model: name of the embedding model to use.
    """
    response = create_embedding_with_backoff(
        input=[text.replace("\n", " ")], model=model
    )
    return response.data[0].embedding
|
|
|
|
|
|
async def async_get_batch_embeddings_with_backoff(texts: List[str], models: List[str]):
    """Embed many texts concurrently, pairing each text with its model.

    Each ``(text, model)`` pair is dispatched through
    ``async_get_embedding_with_backoff``; the resulting vectors are gathered
    in input order.
    """
    tasks = [
        async_get_embedding_with_backoff(text, model)
        for text, model in zip(texts, models)
    ]
    return await asyncio.gather(*tasks)
|