cognee/cognitive_architecture/infrastructure/llm/openai/openai_tools.py
2024-03-13 10:19:03 +01:00

74 lines
2.4 KiB
Python

"""Tools for interacting with OpenAI's GPT-3, GPT-4 API"""
import asyncio
import os
from typing import List
from tenacity import retry, stop_after_attempt
import openai
HOST = os.getenv("OPENAI_API_BASE")
if HOST is not None:
openai.api_base = HOST
@retry(stop = stop_after_attempt(5))
def completions_with_backoff(**kwargs):
"""Wrapper around ChatCompletion.create w/ backoff"""
# Local model
return openai.chat.completions.create(**kwargs)
@retry(stop = stop_after_attempt(5))
async def acompletions_with_backoff(**kwargs):
"""Wrapper around ChatCompletion.acreate w/ backoff"""
return await openai.chat.completions.acreate(**kwargs)
@retry(stop = stop_after_attempt(5))
async def acreate_embedding_with_backoff(**kwargs):
"""Wrapper around Embedding.acreate w/ backoff"""
client = openai.AsyncOpenAI(
# This is the default and can be omitted
api_key=os.environ.get("OPENAI_API_KEY"),
)
return await client.embeddings.create(**kwargs)
async def async_get_embedding_with_backoff(text, model="text-embedding-ada-002"):
"""To get text embeddings, import/call this function
It specifies defaults + handles rate-limiting + is async"""
text = text.replace("\n", " ")
response = await acreate_embedding_with_backoff(input=[text], model=model)
embedding = response.data[0].embedding
return embedding
@retry(stop = stop_after_attempt(5))
def create_embedding_with_backoff(**kwargs):
"""Wrapper around Embedding.create w/ backoff"""
return openai.embeddings.create(**kwargs)
def get_embedding_with_backoff(text:str, model:str="text-embedding-ada-002"):
"""To get text embeddings, import/call this function
It specifies defaults + handles rate-limiting
:param text: str
:param model: str
"""
text = text.replace("\n", " ")
response = create_embedding_with_backoff(input=[text], model=model)
embedding = response.data[0].embedding
return embedding
async def async_get_batch_embeddings_with_backoff(texts: List[str], models: List[str]) :
"""To get multiple text embeddings in parallel, import/call this function
It specifies defaults + handles rate-limiting + is async"""
# Create a generator of coroutines
coroutines = (async_get_embedding_with_backoff(text, model) for text, model in zip(texts, models))
# Run the coroutines in parallel and gather the results
embeddings = await asyncio.gather(*coroutines)
return embeddings