fix: Clean input before passing it to the llm (#238)

* fix: Clean input before passing it to the llm

* chore: Add license

* fix: typo

* chore: Bump graphiti version
This commit is contained in:
Pavlo Paliychuk 2024-12-10 21:27:05 -05:00 committed by GitHub
parent 6814cf7dc0
commit a9091b06ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 85 additions and 1 deletion

View file

@@ -56,6 +56,29 @@ class LLMClient(ABC):
self.cache_enabled = cache
self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
def _clean_input(self, input: str) -> str:
"""Clean input string of invalid unicode and control characters.
Args:
input: Raw input string to be cleaned
Returns:
Cleaned string safe for LLM processing
"""
# Clean any invalid Unicode
cleaned = input.encode('utf-8', errors='ignore').decode('utf-8')
# Remove zero-width characters and other invisible unicode
zero_width = '\u200b\u200c\u200d\ufeff\u2060'
for char in zero_width:
cleaned = cleaned.replace(char, '')
# Remove control characters except newlines, returns, and tabs
cleaned = ''.join(char for char in cleaned if ord(char) >= 32 or char in '\n\r\t')
return cleaned
@retry(
stop=stop_after_attempt(4),
wait=wait_random_exponential(multiplier=10, min=5, max=120),
@@ -106,6 +129,9 @@ class LLMClient(ABC):
logger.debug(f'Cache hit for {cache_key}')
return cached_response
for message in messages:
message.content = self._clean_input(message.content)
response = await self._generate_response_with_retry(messages, response_model)
if self.cache_enabled:

View file

@@ -88,6 +88,7 @@ class OpenAIClient(LLMClient):
) -> dict[str, typing.Any]:
openai_messages: list[ChatCompletionMessageParam] = []
for m in messages:
m.content = self._clean_input(m.content)
if m.role == 'user':
openai_messages.append({'role': 'user', 'content': m.content})
elif m.role == 'system':

View file

@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.5.0pre4"
version = "0.5.0pre5"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",

View file

@@ -0,0 +1,57 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from graphiti_core.llm_client.client import LLMClient
from graphiti_core.llm_client.config import LLMConfig
class TestLLMClient(LLMClient):
"""Concrete implementation of LLMClient for testing"""
async def _generate_response(self, messages, response_model=None):
return {'content': 'test'}
def test_clean_input():
client = TestLLMClient(LLMConfig())
test_cases = [
# Basic text should remain unchanged
('Hello World', 'Hello World'),
# Control characters should be removed
('Hello\x00World', 'HelloWorld'),
# Newlines, tabs, returns should be preserved
('Hello\nWorld\tTest\r', 'Hello\nWorld\tTest\r'),
# Invalid Unicode should be removed
('Hello\udcdeWorld', 'HelloWorld'),
# Zero-width characters should be removed
('Hello\u200bWorld', 'HelloWorld'),
('Test\ufeffWord', 'TestWord'),
# Multiple issues combined
('Hello\x00\u200b\nWorld\udcde', 'Hello\nWorld'),
# Empty string should remain empty
('', ''),
# Form feed and other control characters from the error case
('{"edges":[{"relation_typ...\f\x04Hn\\?"}]}', '{"edges":[{"relation_typ...Hn\\?"}]}'),
# More specific control character tests
('Hello\x0cWorld', 'HelloWorld'), # form feed \f
('Hello\x04World', 'HelloWorld'), # end of transmission
# Combined JSON-like string with control characters
('{"test": "value\f\x00\x04"}', '{"test": "value"}'),
]
for input_str, expected in test_cases:
assert client._clean_input(input_str) == expected, f'Failed for input: {repr(input_str)}'