Set max tokens by prompt (#255)
* set max tokens * update generic openai client * mypy updates * fix: dockerfile --------- Co-authored-by: paulpaliychuk <pavlo.paliychuk.ca@gmail.com>
This commit is contained in:
parent
77cb67cdfe
commit
0f50b74735
11 changed files with 1373 additions and 1488 deletions
|
|
@ -23,7 +23,7 @@ RUN poetry build && pip install dist/*.whl
|
||||||
|
|
||||||
# Install server dependencies
|
# Install server dependencies
|
||||||
WORKDIR /app/server
|
WORKDIR /app/server
|
||||||
RUN poetry install --no-interaction --no-ansi --no-dev
|
RUN poetry install --no-interaction --no-ansi --only main --no-root
|
||||||
|
|
||||||
FROM python:3.12-slim
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -73,12 +73,12 @@ def lucene_sanitize(query: str) -> str:
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
def normalize_l2(embedding: list[float]) -> list[float]:
|
def normalize_l2(embedding: list[float]):
|
||||||
embedding_array = np.array(embedding)
|
embedding_array = np.array(embedding)
|
||||||
if embedding_array.ndim == 1:
|
if embedding_array.ndim == 1:
|
||||||
norm = np.linalg.norm(embedding_array)
|
norm = np.linalg.norm(embedding_array)
|
||||||
if norm == 0:
|
if norm == 0:
|
||||||
return embedding_array.tolist()
|
return [0.0] * len(embedding)
|
||||||
return (embedding_array / norm).tolist()
|
return (embedding_array / norm).tolist()
|
||||||
else:
|
else:
|
||||||
norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
|
norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,10 @@ class AnthropicClient(LLMClient):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _generate_response(
|
async def _generate_response(
|
||||||
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
system_message = messages[0]
|
system_message = messages[0]
|
||||||
user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
|
user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
|
||||||
|
|
@ -59,7 +62,7 @@ class AnthropicClient(LLMClient):
|
||||||
result = await self.client.messages.create(
|
result = await self.client.messages.create(
|
||||||
system='Only include JSON in the response. Do not include any additional text or explanation of the content.\n'
|
system='Only include JSON in the response. Do not include any additional text or explanation of the content.\n'
|
||||||
+ system_message.content,
|
+ system_message.content,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=max_tokens or self.max_tokens,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
messages=user_messages, # type: ignore
|
messages=user_messages, # type: ignore
|
||||||
model=self.model or DEFAULT_MODEL,
|
model=self.model or DEFAULT_MODEL,
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from pydantic import BaseModel
|
||||||
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
|
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .config import LLMConfig
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
||||||
from .errors import RateLimitError
|
from .errors import RateLimitError
|
||||||
|
|
||||||
DEFAULT_TEMPERATURE = 0
|
DEFAULT_TEMPERATURE = 0
|
||||||
|
|
@ -90,16 +90,22 @@ class LLMClient(ABC):
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def _generate_response_with_retry(
|
async def _generate_response_with_retry(
|
||||||
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
try:
|
try:
|
||||||
return await self._generate_response(messages, response_model)
|
return await self._generate_response(messages, response_model, max_tokens)
|
||||||
except (httpx.HTTPStatusError, RateLimitError) as e:
|
except (httpx.HTTPStatusError, RateLimitError) as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def _generate_response(
|
async def _generate_response(
|
||||||
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -110,7 +116,10 @@ class LLMClient(ABC):
|
||||||
return hashlib.md5(key_str.encode()).hexdigest()
|
return hashlib.md5(key_str.encode()).hexdigest()
|
||||||
|
|
||||||
async def generate_response(
|
async def generate_response(
|
||||||
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
if response_model is not None:
|
if response_model is not None:
|
||||||
serialized_model = json.dumps(response_model.model_json_schema())
|
serialized_model = json.dumps(response_model.model_json_schema())
|
||||||
|
|
@ -131,7 +140,7 @@ class LLMClient(ABC):
|
||||||
for message in messages:
|
for message in messages:
|
||||||
message.content = self._clean_input(message.content)
|
message.content = self._clean_input(message.content)
|
||||||
|
|
||||||
response = await self._generate_response_with_retry(messages, response_model)
|
response = await self._generate_response_with_retry(messages, response_model, max_tokens)
|
||||||
|
|
||||||
if self.cache_enabled:
|
if self.cache_enabled:
|
||||||
self.cache_dir.set(cache_key, response)
|
self.cache_dir.set(cache_key, response)
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_MAX_TOKENS = 16384
|
DEFAULT_MAX_TOKENS = 1024
|
||||||
DEFAULT_TEMPERATURE = 0
|
DEFAULT_TEMPERATURE = 0
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,10 @@ class GroqClient(LLMClient):
|
||||||
self.client = AsyncGroq(api_key=config.api_key)
|
self.client = AsyncGroq(api_key=config.api_key)
|
||||||
|
|
||||||
async def _generate_response(
|
async def _generate_response(
|
||||||
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
msgs: list[ChatCompletionMessageParam] = []
|
msgs: list[ChatCompletionMessageParam] = []
|
||||||
for m in messages:
|
for m in messages:
|
||||||
|
|
@ -58,7 +61,7 @@ class GroqClient(LLMClient):
|
||||||
model=self.model or DEFAULT_MODEL,
|
model=self.model or DEFAULT_MODEL,
|
||||||
messages=msgs,
|
messages=msgs,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=max_tokens or self.max_tokens,
|
||||||
response_format={'type': 'json_object'},
|
response_format={'type': 'json_object'},
|
||||||
)
|
)
|
||||||
result = response.choices[0].message.content or ''
|
result = response.choices[0].message.content or ''
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .client import LLMClient
|
from .client import LLMClient
|
||||||
from .config import LLMConfig
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
||||||
from .errors import RateLimitError, RefusalError
|
from .errors import RateLimitError, RefusalError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -58,7 +58,11 @@ class OpenAIClient(LLMClient):
|
||||||
MAX_RETRIES: ClassVar[int] = 2
|
MAX_RETRIES: ClassVar[int] = 2
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
|
self,
|
||||||
|
config: LLMConfig | None = None,
|
||||||
|
cache: bool = False,
|
||||||
|
client: typing.Any = None,
|
||||||
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the OpenAIClient with the provided configuration, cache setting, and client.
|
Initialize the OpenAIClient with the provided configuration, cache setting, and client.
|
||||||
|
|
@ -84,7 +88,10 @@ class OpenAIClient(LLMClient):
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
async def _generate_response(
|
async def _generate_response(
|
||||||
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
openai_messages: list[ChatCompletionMessageParam] = []
|
openai_messages: list[ChatCompletionMessageParam] = []
|
||||||
for m in messages:
|
for m in messages:
|
||||||
|
|
@ -98,7 +105,7 @@ class OpenAIClient(LLMClient):
|
||||||
model=self.model or DEFAULT_MODEL,
|
model=self.model or DEFAULT_MODEL,
|
||||||
messages=openai_messages,
|
messages=openai_messages,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=max_tokens or self.max_tokens,
|
||||||
response_format=response_model, # type: ignore
|
response_format=response_model, # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -119,14 +126,17 @@ class OpenAIClient(LLMClient):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def generate_response(
|
async def generate_response(
|
||||||
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
last_error = None
|
last_error = None
|
||||||
|
|
||||||
while retry_count <= self.MAX_RETRIES:
|
while retry_count <= self.MAX_RETRIES:
|
||||||
try:
|
try:
|
||||||
response = await self._generate_response(messages, response_model)
|
response = await self._generate_response(messages, response_model, max_tokens)
|
||||||
return response
|
return response
|
||||||
except (RateLimitError, RefusalError):
|
except (RateLimitError, RefusalError):
|
||||||
# These errors should not trigger retries
|
# These errors should not trigger retries
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .client import LLMClient
|
from .client import LLMClient
|
||||||
from .config import LLMConfig
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
||||||
from .errors import RateLimitError, RefusalError
|
from .errors import RateLimitError, RefusalError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -85,7 +85,10 @@ class OpenAIGenericClient(LLMClient):
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
async def _generate_response(
|
async def _generate_response(
|
||||||
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
openai_messages: list[ChatCompletionMessageParam] = []
|
openai_messages: list[ChatCompletionMessageParam] = []
|
||||||
for m in messages:
|
for m in messages:
|
||||||
|
|
@ -111,7 +114,10 @@ class OpenAIGenericClient(LLMClient):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def generate_response(
|
async def generate_response(
|
||||||
self, messages: list[Message], response_model: type[BaseModel] | None = None
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
max_tokens: int = DEFAULT_MAX_TOKENS,
|
||||||
) -> dict[str, typing.Any]:
|
) -> dict[str, typing.Any]:
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
last_error = None
|
last_error = None
|
||||||
|
|
@ -126,7 +132,9 @@ class OpenAIGenericClient(LLMClient):
|
||||||
|
|
||||||
while retry_count <= self.MAX_RETRIES:
|
while retry_count <= self.MAX_RETRIES:
|
||||||
try:
|
try:
|
||||||
response = await self._generate_response(messages, response_model)
|
response = await self._generate_response(
|
||||||
|
messages, response_model, max_tokens=max_tokens
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
except (RateLimitError, RefusalError):
|
except (RateLimitError, RefusalError):
|
||||||
# These errors should not trigger retries
|
# These errors should not trigger retries
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,8 @@ async def extract_edges(
|
||||||
) -> list[EntityEdge]:
|
) -> list[EntityEdge]:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
|
EXTRACT_EDGES_MAX_TOKENS = 16384
|
||||||
|
|
||||||
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
|
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
|
||||||
|
|
||||||
# Prepare context for LLM
|
# Prepare context for LLM
|
||||||
|
|
@ -93,7 +95,9 @@ async def extract_edges(
|
||||||
reflexion_iterations = 0
|
reflexion_iterations = 0
|
||||||
while facts_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS:
|
while facts_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS:
|
||||||
llm_response = await llm_client.generate_response(
|
llm_response = await llm_client.generate_response(
|
||||||
prompt_library.extract_edges.edge(context), response_model=ExtractedEdges
|
prompt_library.extract_edges.edge(context),
|
||||||
|
response_model=ExtractedEdges,
|
||||||
|
max_tokens=EXTRACT_EDGES_MAX_TOKENS,
|
||||||
)
|
)
|
||||||
edges_data = llm_response.get('edges', [])
|
edges_data = llm_response.get('edges', [])
|
||||||
|
|
||||||
|
|
|
||||||
2772
poetry.lock
generated
2772
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.5.1"
|
version = "0.5.2"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
authors = [
|
authors = [
|
||||||
"Paul Paliychuk <paul@getzep.com>",
|
"Paul Paliychuk <paul@getzep.com>",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue