Set max tokens by prompt (#255)

* set max tokens

* update generic openai client

* mypy updates

* fix: dockerfile

---------

Co-authored-by: paulpaliychuk <pavlo.paliychuk.ca@gmail.com>
This commit is contained in:
Preston Rasmussen 2025-01-24 10:14:49 -05:00 committed by GitHub
parent 77cb67cdfe
commit 0f50b74735
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 1373 additions and 1488 deletions

View file

@ -23,7 +23,7 @@ RUN poetry build && pip install dist/*.whl
# Install server dependencies # Install server dependencies
WORKDIR /app/server WORKDIR /app/server
RUN poetry install --no-interaction --no-ansi --no-dev RUN poetry install --no-interaction --no-ansi --only main --no-root
FROM python:3.12-slim FROM python:3.12-slim

View file

@ -73,12 +73,12 @@ def lucene_sanitize(query: str) -> str:
return sanitized return sanitized
def normalize_l2(embedding: list[float]) -> list[float]: def normalize_l2(embedding: list[float]):
embedding_array = np.array(embedding) embedding_array = np.array(embedding)
if embedding_array.ndim == 1: if embedding_array.ndim == 1:
norm = np.linalg.norm(embedding_array) norm = np.linalg.norm(embedding_array)
if norm == 0: if norm == 0:
return embedding_array.tolist() return [0.0] * len(embedding)
return (embedding_array / norm).tolist() return (embedding_array / norm).tolist()
else: else:
norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True) norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)

View file

@ -48,7 +48,10 @@ class AnthropicClient(LLMClient):
) )
async def _generate_response( async def _generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, typing.Any]: ) -> dict[str, typing.Any]:
system_message = messages[0] system_message = messages[0]
user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [ user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
@ -59,7 +62,7 @@ class AnthropicClient(LLMClient):
result = await self.client.messages.create( result = await self.client.messages.create(
system='Only include JSON in the response. Do not include any additional text or explanation of the content.\n' system='Only include JSON in the response. Do not include any additional text or explanation of the content.\n'
+ system_message.content, + system_message.content,
max_tokens=self.max_tokens, max_tokens=max_tokens or self.max_tokens,
temperature=self.temperature, temperature=self.temperature,
messages=user_messages, # type: ignore messages=user_messages, # type: ignore
model=self.model or DEFAULT_MODEL, model=self.model or DEFAULT_MODEL,

View file

@ -26,7 +26,7 @@ from pydantic import BaseModel
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
from ..prompts.models import Message from ..prompts.models import Message
from .config import LLMConfig from .config import DEFAULT_MAX_TOKENS, LLMConfig
from .errors import RateLimitError from .errors import RateLimitError
DEFAULT_TEMPERATURE = 0 DEFAULT_TEMPERATURE = 0
@ -90,16 +90,22 @@ class LLMClient(ABC):
reraise=True, reraise=True,
) )
async def _generate_response_with_retry( async def _generate_response_with_retry(
self, messages: list[Message], response_model: type[BaseModel] | None = None self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, typing.Any]: ) -> dict[str, typing.Any]:
try: try:
return await self._generate_response(messages, response_model) return await self._generate_response(messages, response_model, max_tokens)
except (httpx.HTTPStatusError, RateLimitError) as e: except (httpx.HTTPStatusError, RateLimitError) as e:
raise e raise e
@abstractmethod @abstractmethod
async def _generate_response( async def _generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, typing.Any]: ) -> dict[str, typing.Any]:
pass pass
@ -110,7 +116,10 @@ class LLMClient(ABC):
return hashlib.md5(key_str.encode()).hexdigest() return hashlib.md5(key_str.encode()).hexdigest()
async def generate_response( async def generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, typing.Any]: ) -> dict[str, typing.Any]:
if response_model is not None: if response_model is not None:
serialized_model = json.dumps(response_model.model_json_schema()) serialized_model = json.dumps(response_model.model_json_schema())
@ -131,7 +140,7 @@ class LLMClient(ABC):
for message in messages: for message in messages:
message.content = self._clean_input(message.content) message.content = self._clean_input(message.content)
response = await self._generate_response_with_retry(messages, response_model) response = await self._generate_response_with_retry(messages, response_model, max_tokens)
if self.cache_enabled: if self.cache_enabled:
self.cache_dir.set(cache_key, response) self.cache_dir.set(cache_key, response)

View file

@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
DEFAULT_MAX_TOKENS = 16384 DEFAULT_MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0 DEFAULT_TEMPERATURE = 0

View file

@ -45,7 +45,10 @@ class GroqClient(LLMClient):
self.client = AsyncGroq(api_key=config.api_key) self.client = AsyncGroq(api_key=config.api_key)
async def _generate_response( async def _generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, typing.Any]: ) -> dict[str, typing.Any]:
msgs: list[ChatCompletionMessageParam] = [] msgs: list[ChatCompletionMessageParam] = []
for m in messages: for m in messages:
@ -58,7 +61,7 @@ class GroqClient(LLMClient):
model=self.model or DEFAULT_MODEL, model=self.model or DEFAULT_MODEL,
messages=msgs, messages=msgs,
temperature=self.temperature, temperature=self.temperature,
max_tokens=self.max_tokens, max_tokens=max_tokens or self.max_tokens,
response_format={'type': 'json_object'}, response_format={'type': 'json_object'},
) )
result = response.choices[0].message.content or '' result = response.choices[0].message.content or ''

View file

@ -25,7 +25,7 @@ from pydantic import BaseModel
from ..prompts.models import Message from ..prompts.models import Message
from .client import LLMClient from .client import LLMClient
from .config import LLMConfig from .config import DEFAULT_MAX_TOKENS, LLMConfig
from .errors import RateLimitError, RefusalError from .errors import RateLimitError, RefusalError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -58,7 +58,11 @@ class OpenAIClient(LLMClient):
MAX_RETRIES: ClassVar[int] = 2 MAX_RETRIES: ClassVar[int] = 2
def __init__( def __init__(
self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None self,
config: LLMConfig | None = None,
cache: bool = False,
client: typing.Any = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
): ):
""" """
Initialize the OpenAIClient with the provided configuration, cache setting, and client. Initialize the OpenAIClient with the provided configuration, cache setting, and client.
@ -84,7 +88,10 @@ class OpenAIClient(LLMClient):
self.client = client self.client = client
async def _generate_response( async def _generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, typing.Any]: ) -> dict[str, typing.Any]:
openai_messages: list[ChatCompletionMessageParam] = [] openai_messages: list[ChatCompletionMessageParam] = []
for m in messages: for m in messages:
@ -98,7 +105,7 @@ class OpenAIClient(LLMClient):
model=self.model or DEFAULT_MODEL, model=self.model or DEFAULT_MODEL,
messages=openai_messages, messages=openai_messages,
temperature=self.temperature, temperature=self.temperature,
max_tokens=self.max_tokens, max_tokens=max_tokens or self.max_tokens,
response_format=response_model, # type: ignore response_format=response_model, # type: ignore
) )
@ -119,14 +126,17 @@ class OpenAIClient(LLMClient):
raise raise
async def generate_response( async def generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, typing.Any]: ) -> dict[str, typing.Any]:
retry_count = 0 retry_count = 0
last_error = None last_error = None
while retry_count <= self.MAX_RETRIES: while retry_count <= self.MAX_RETRIES:
try: try:
response = await self._generate_response(messages, response_model) response = await self._generate_response(messages, response_model, max_tokens)
return response return response
except (RateLimitError, RefusalError): except (RateLimitError, RefusalError):
# These errors should not trigger retries # These errors should not trigger retries

View file

@ -26,7 +26,7 @@ from pydantic import BaseModel
from ..prompts.models import Message from ..prompts.models import Message
from .client import LLMClient from .client import LLMClient
from .config import LLMConfig from .config import DEFAULT_MAX_TOKENS, LLMConfig
from .errors import RateLimitError, RefusalError from .errors import RateLimitError, RefusalError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -85,7 +85,10 @@ class OpenAIGenericClient(LLMClient):
self.client = client self.client = client
async def _generate_response( async def _generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, typing.Any]: ) -> dict[str, typing.Any]:
openai_messages: list[ChatCompletionMessageParam] = [] openai_messages: list[ChatCompletionMessageParam] = []
for m in messages: for m in messages:
@ -111,7 +114,10 @@ class OpenAIGenericClient(LLMClient):
raise raise
async def generate_response( async def generate_response(
self, messages: list[Message], response_model: type[BaseModel] | None = None self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, typing.Any]: ) -> dict[str, typing.Any]:
retry_count = 0 retry_count = 0
last_error = None last_error = None
@ -126,7 +132,9 @@ class OpenAIGenericClient(LLMClient):
while retry_count <= self.MAX_RETRIES: while retry_count <= self.MAX_RETRIES:
try: try:
response = await self._generate_response(messages, response_model) response = await self._generate_response(
messages, response_model, max_tokens=max_tokens
)
return response return response
except (RateLimitError, RefusalError): except (RateLimitError, RefusalError):
# These errors should not trigger retries # These errors should not trigger retries

View file

@ -79,6 +79,8 @@ async def extract_edges(
) -> list[EntityEdge]: ) -> list[EntityEdge]:
start = time() start = time()
EXTRACT_EDGES_MAX_TOKENS = 16384
node_uuids_by_name_map = {node.name: node.uuid for node in nodes} node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
# Prepare context for LLM # Prepare context for LLM
@ -93,7 +95,9 @@ async def extract_edges(
reflexion_iterations = 0 reflexion_iterations = 0
while facts_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS: while facts_missed and reflexion_iterations < MAX_REFLEXION_ITERATIONS:
llm_response = await llm_client.generate_response( llm_response = await llm_client.generate_response(
prompt_library.extract_edges.edge(context), response_model=ExtractedEdges prompt_library.extract_edges.edge(context),
response_model=ExtractedEdges,
max_tokens=EXTRACT_EDGES_MAX_TOKENS,
) )
edges_data = llm_response.get('edges', []) edges_data = llm_response.get('edges', [])

2772
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "graphiti-core" name = "graphiti-core"
version = "0.5.1" version = "0.5.2"
description = "A temporal graph building library" description = "A temporal graph building library"
authors = [ authors = [
"Paul Paliychuk <paul@getzep.com>", "Paul Paliychuk <paul@getzep.com>",