parent 8b7ad6f84c
commit d2654003ff
10 changed files with 91 additions and 87 deletions
@@ -40,8 +40,8 @@ from graphiti_core.nodes import EpisodeType
 
 # Configure logging
 logging.basicConfig(
     level=INFO,
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
-    datefmt="%Y-%m-%d %H:%M:%S",
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S',
 )
 logger = logging.getLogger(__name__)
@@ -49,20 +49,20 @@ load_dotenv()
 
 # Neo4j connection parameters
 # Make sure Neo4j Desktop is running with a local DBMS started
-neo4j_uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
-neo4j_user = os.environ.get("NEO4J_USER", "neo4j")
-neo4j_password = os.environ.get("NEO4J_PASSWORD", "password")
+neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
+neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
+neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
 
 # Azure OpenAI connection parameters
-azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
-azure_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
-azure_deployment = os.environ.get("AZURE_OPENAI_DEPLOYMENT", "gpt-4.1")
+azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
+azure_api_key = os.environ.get('AZURE_OPENAI_API_KEY')
+azure_deployment = os.environ.get('AZURE_OPENAI_DEPLOYMENT', 'gpt-4.1')
 azure_embedding_deployment = os.environ.get(
-    "AZURE_OPENAI_EMBEDDING_DEPLOYMENT", "text-embedding-3-small"
+    'AZURE_OPENAI_EMBEDDING_DEPLOYMENT', 'text-embedding-3-small'
 )
 
 if not azure_endpoint or not azure_api_key:
-    raise ValueError("AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY must be set")
+    raise ValueError('AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY must be set')
 
 
 async def main():
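For reference, a minimal `.env` consumed by the block above might look like the following; the endpoint and key values are placeholders, and only the two Azure variables without defaults are strictly required:

    NEO4J_URI=bolt://localhost:7687
    NEO4J_USER=neo4j
    NEO4J_PASSWORD=password
    AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com
    AZURE_OPENAI_API_KEY=your-api-key
    AZURE_OPENAI_DEPLOYMENT=gpt-4.1
    AZURE_OPENAI_EMBEDDING_DEPLOYMENT=text-embedding-3-small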
@@ -76,7 +76,7 @@ async def main():
 
     # Initialize Azure OpenAI client
     azure_client = AsyncOpenAI(
-        base_url=f"{azure_endpoint}/openai/v1/",
+        base_url=f'{azure_endpoint}/openai/v1/',
         api_key=azure_api_key,
     )
 
@@ -112,40 +112,40 @@ async def main():
         # Episodes list containing both text and JSON episodes
         episodes = [
             {
-                "content": "Kamala Harris is the Attorney General of California. She was previously "
-                "the district attorney for San Francisco.",
-                "type": EpisodeType.text,
-                "description": "podcast transcript",
+                'content': 'Kamala Harris is the Attorney General of California. She was previously '
+                'the district attorney for San Francisco.',
+                'type': EpisodeType.text,
+                'description': 'podcast transcript',
             },
             {
-                "content": "As AG, Harris was in office from January 3, 2011 – January 3, 2017",
-                "type": EpisodeType.text,
-                "description": "podcast transcript",
+                'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
+                'type': EpisodeType.text,
+                'description': 'podcast transcript',
             },
             {
-                "content": {
-                    "name": "Gavin Newsom",
-                    "position": "Governor",
-                    "state": "California",
-                    "previous_role": "Lieutenant Governor",
-                    "previous_location": "San Francisco",
+                'content': {
+                    'name': 'Gavin Newsom',
+                    'position': 'Governor',
+                    'state': 'California',
+                    'previous_role': 'Lieutenant Governor',
+                    'previous_location': 'San Francisco',
                 },
-                "type": EpisodeType.json,
-                "description": "podcast metadata",
+                'type': EpisodeType.json,
+                'description': 'podcast metadata',
             },
         ]
 
         # Add episodes to the graph
         for i, episode in enumerate(episodes):
             await graphiti.add_episode(
-                name=f"California Politics {i}",
+                name=f'California Politics {i}',
                 episode_body=(
-                    episode["content"]
-                    if isinstance(episode["content"], str)
-                    else json.dumps(episode["content"])
+                    episode['content']
+                    if isinstance(episode['content'], str)
+                    else json.dumps(episode['content'])
                 ),
-                source=episode["type"],
-                source_description=episode["description"],
+                source=episode['type'],
+                source_description=episode['description'],
                 reference_time=datetime.now(timezone.utc),
             )
             print(f'Added episode: California Politics {i} ({episode["type"].value})')
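Since `add_episode` takes `episode_body` as a string, the dict-valued episode above is run through `json.dumps` before ingestion. A quick illustrative sketch of what that branch produces:

    import json

    content = {
        'name': 'Gavin Newsom',
        'position': 'Governor',
        'state': 'California',
        'previous_role': 'Lieutenant Governor',
        'previous_location': 'San Francisco',
    }
    episode_body = json.dumps(content)
    # '{"name": "Gavin Newsom", "position": "Governor", "state": "California", ...}'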
@@ -161,18 +161,18 @@ async def main():
 
         # Perform a hybrid search combining semantic similarity and BM25 retrieval
         print("\nSearching for: 'Who was the California Attorney General?'")
-        results = await graphiti.search("Who was the California Attorney General?")
+        results = await graphiti.search('Who was the California Attorney General?')
 
         # Print search results
-        print("\nSearch Results:")
+        print('\nSearch Results:')
         for result in results:
-            print(f"UUID: {result.uuid}")
-            print(f"Fact: {result.fact}")
-            if hasattr(result, "valid_at") and result.valid_at:
-                print(f"Valid from: {result.valid_at}")
-            if hasattr(result, "invalid_at") and result.invalid_at:
-                print(f"Valid until: {result.invalid_at}")
-            print("---")
+            print(f'UUID: {result.uuid}')
+            print(f'Fact: {result.fact}')
+            if hasattr(result, 'valid_at') and result.valid_at:
+                print(f'Valid from: {result.valid_at}')
+            if hasattr(result, 'invalid_at') and result.invalid_at:
+                print(f'Valid until: {result.invalid_at}')
+            print('---')
 
         #################################################
         # CENTER NODE SEARCH
@@ -187,26 +187,26 @@ async def main():
             # Get the source node UUID from the top result
             center_node_uuid = results[0].source_node_uuid
 
-            print("\nReranking search results based on graph distance:")
-            print(f"Using center node UUID: {center_node_uuid}")
+            print('\nReranking search results based on graph distance:')
+            print(f'Using center node UUID: {center_node_uuid}')
 
             reranked_results = await graphiti.search(
-                "Who was the California Attorney General?",
+                'Who was the California Attorney General?',
                 center_node_uuid=center_node_uuid,
             )
 
             # Print reranked search results
-            print("\nReranked Search Results:")
+            print('\nReranked Search Results:')
             for result in reranked_results:
-                print(f"UUID: {result.uuid}")
-                print(f"Fact: {result.fact}")
-                if hasattr(result, "valid_at") and result.valid_at:
-                    print(f"Valid from: {result.valid_at}")
-                if hasattr(result, "invalid_at") and result.invalid_at:
-                    print(f"Valid until: {result.invalid_at}")
-                print("---")
+                print(f'UUID: {result.uuid}')
+                print(f'Fact: {result.fact}')
+                if hasattr(result, 'valid_at') and result.valid_at:
+                    print(f'Valid from: {result.valid_at}')
+                if hasattr(result, 'invalid_at') and result.invalid_at:
+                    print(f'Valid until: {result.invalid_at}')
+                print('---')
         else:
-            print("No results found in the initial search to use as center node.")
+            print('No results found in the initial search to use as center node.')
 
     finally:
         #################################################
@@ -218,8 +218,8 @@ async def main():
 
         # Close the connection
         await graphiti.close()
-        print("\nConnection closed")
+        print('\nConnection closed')
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     asyncio.run(main())
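Everything in this file is a mechanical quote-style change: double-quoted strings become single-quoted, which matches ruff's formatter when configured for single quotes. A plausible `pyproject.toml` fragment producing this style (an assumption about the repo's config, which is not shown in this diff):

    [tool.ruff.format]
    quote-style = "single"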
@@ -33,7 +33,7 @@ class AzureOpenAIEmbedderClient(EmbedderClient):
     def __init__(
         self,
         azure_client: AsyncAzureOpenAI | AsyncOpenAI,
-        model: str = "text-embedding-3-small",
+        model: str = 'text-embedding-3-small',
     ):
         self.azure_client = azure_client
         self.model = model
@@ -44,22 +44,18 @@ class AzureOpenAIEmbedderClient(EmbedderClient):
             # Handle different input types
             if isinstance(input_data, str):
                 text_input = [input_data]
-            elif isinstance(input_data, list) and all(
-                isinstance(item, str) for item in input_data
-            ):
+            elif isinstance(input_data, list) and all(isinstance(item, str) for item in input_data):
                 text_input = input_data
             else:
                 # Convert to string list for other types
                 text_input = [str(input_data)]
 
-            response = await self.azure_client.embeddings.create(
-                model=self.model, input=text_input
-            )
+            response = await self.azure_client.embeddings.create(model=self.model, input=text_input)
 
             # Return the first embedding as a list of floats
             return response.data[0].embedding
         except Exception as e:
-            logger.error(f"Error in Azure OpenAI embedding: {e}")
+            logger.error(f'Error in Azure OpenAI embedding: {e}')
             raise
 
     async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
@@ -71,5 +67,5 @@ class AzureOpenAIEmbedderClient(EmbedderClient):
 
             return [embedding.embedding for embedding in response.data]
         except Exception as e:
-            logger.error(f"Error in Azure OpenAI batch embedding: {e}")
+            logger.error(f'Error in Azure OpenAI batch embedding: {e}')
             raise
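For context, a hedged sketch of how this embedder is wired into the Azure example earlier in the commit; the endpoint and key are placeholders, and the `create` call assumes the coroutine shown in the hunk above:

    from openai import AsyncOpenAI

    azure_client = AsyncOpenAI(
        base_url='https://your-resource.openai.azure.com/openai/v1/',  # placeholder
        api_key='your-api-key',  # placeholder
    )
    embedder = AzureOpenAIEmbedderClient(azure_client, model='text-embedding-3-small')
    # embedding = await embedder.create('some text')  # returns a list of floats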
@@ -66,21 +66,21 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
         """Create a structured completion using Azure OpenAI's responses.parse API."""
         supports_reasoning = self._supports_reasoning_features(model)
         request_kwargs = {
-            "model": model,
-            "input": messages,
-            "max_output_tokens": max_tokens,
-            "text_format": response_model,  # type: ignore
+            'model': model,
+            'input': messages,
+            'max_output_tokens': max_tokens,
+            'text_format': response_model,  # type: ignore
         }
 
         temperature_value = temperature if not supports_reasoning else None
         if temperature_value is not None:
-            request_kwargs["temperature"] = temperature_value
+            request_kwargs['temperature'] = temperature_value
 
         if supports_reasoning and reasoning:
-            request_kwargs["reasoning"] = {"effort": reasoning}  # type: ignore
+            request_kwargs['reasoning'] = {'effort': reasoning}  # type: ignore
 
         if supports_reasoning and verbosity:
-            request_kwargs["text"] = {"verbosity": verbosity}  # type: ignore
+            request_kwargs['text'] = {'verbosity': verbosity}  # type: ignore
 
         return await self.client.responses.parse(**request_kwargs)
@@ -96,20 +96,20 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
         supports_reasoning = self._supports_reasoning_features(model)
 
         request_kwargs = {
-            "model": model,
-            "messages": messages,
-            "max_tokens": max_tokens,
-            "response_format": {"type": "json_object"},
+            'model': model,
+            'messages': messages,
+            'max_tokens': max_tokens,
+            'response_format': {'type': 'json_object'},
         }
 
         temperature_value = temperature if not supports_reasoning else None
         if temperature_value is not None:
-            request_kwargs["temperature"] = temperature_value
+            request_kwargs['temperature'] = temperature_value
 
         return await self.client.chat.completions.create(**request_kwargs)
 
     @staticmethod
     def _supports_reasoning_features(model: str) -> bool:
         """Return True when the Azure model supports reasoning/verbosity options."""
-        reasoning_prefixes = ("o1", "o3", "gpt-5")
+        reasoning_prefixes = ('o1', 'o3', 'gpt-5')
         return model.startswith(reasoning_prefixes)
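The `_supports_reasoning_features` helper gates request parameters by model-name prefix: reasoning-capable families get the `reasoning`/`text` options and no `temperature`, while all other models get `temperature` only. A small self-contained check of the prefix logic (the model names here are examples only):

    # str.startswith accepts a tuple, so one call covers every prefix
    reasoning_prefixes = ('o1', 'o3', 'gpt-5')
    assert 'gpt-5-mini'.startswith(reasoning_prefixes)
    assert 'o3-mini'.startswith(reasoning_prefixes)
    assert not 'gpt-4.1'.startswith(reasoning_prefixes)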
@@ -166,13 +166,17 @@ class BaseOpenAIClient(LLMClient):
         except openai.RateLimitError as e:
             raise RateLimitError from e
         except openai.AuthenticationError as e:
-            logger.error(f'OpenAI Authentication Error: {e}. Please verify your API key is correct.')
+            logger.error(
+                f'OpenAI Authentication Error: {e}. Please verify your API key is correct.'
+            )
             raise
         except Exception as e:
             # Provide more context for connection errors
             error_msg = str(e)
             if 'Connection error' in error_msg or 'connection' in error_msg.lower():
-                logger.error(f'Connection error communicating with OpenAI API. Please check your network connection and API key. Error: {e}')
+                logger.error(
+                    f'Connection error communicating with OpenAI API. Please check your network connection and API key. Error: {e}'
+                )
             else:
                 logger.error(f'Error in generating LLM response: {e}')
             raise
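Note the asymmetry above: rate limits are converted to graphiti's own `RateLimitError`, while authentication and connection failures are logged and re-raised unchanged. A hedged sketch of how a caller might exploit that distinction (the retry policy and the `graphiti_core.llm_client.errors` import path are assumptions, not part of this commit):

    import asyncio

    from graphiti_core.llm_client.errors import RateLimitError

    async def add_with_backoff(graphiti, **episode_kwargs):
        # Retry only rate limits; let auth/connection errors surface immediately.
        for delay in (1, 2, 4):
            try:
                return await graphiti.add_episode(**episode_kwargs)
            except RateLimitError:
                await asyncio.sleep(delay)
        return await graphiti.add_episode(**episode_kwargs)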
@@ -74,7 +74,9 @@ class OpenAIClient(BaseOpenAIClient):
     ):
         """Create a structured completion using OpenAI's beta parse API."""
         # Reasoning models (gpt-5 family) don't support temperature
-        is_reasoning_model = model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
+        is_reasoning_model = (
+            model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
+        )
 
         response = await self.client.responses.parse(
             model=model,
@@ -100,7 +102,9 @@ class OpenAIClient(BaseOpenAIClient):
     ):
         """Create a regular completion with JSON format."""
         # Reasoning models (gpt-5 family) don't support temperature
-        is_reasoning_model = model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
+        is_reasoning_model = (
+            model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
+        )
 
         return await self.client.chat.completions.create(
             model=model,
@@ -17,7 +17,7 @@ limitations under the License.
 import re
 
 # Maximum length for entity/node summaries
-MAX_SUMMARY_CHARS = 250
+MAX_SUMMARY_CHARS = 500
 
 
 def truncate_at_sentence(text: str, max_chars: int) -> str:
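This is the commit's one behavioral change besides formatting: entity/node summaries may now run to 500 characters before truncation. A short sketch of the constant in use, assuming the module path `graphiti_core.utils.text_utils` and that `truncate_at_sentence` never returns more than `max_chars` characters:

    from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence

    summary = 'First sentence. Second sentence. ' * 30  # ~1000 characters
    truncated = truncate_at_sentence(summary, MAX_SUMMARY_CHARS)
    assert len(truncated) <= MAX_SUMMARY_CHARS  # 500 after this commit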
@@ -1,7 +1,7 @@
 [project]
 name = "graphiti-core"
 description = "A temporal graph building library"
-version = "0.24.0"
+version = "0.24.1"
 authors = [
     { name = "Paul Paliychuk", email = "paul@getzep.com" },
     { name = "Preston Rasmussen", email = "preston@getzep.com" },
@@ -87,7 +87,7 @@ def test_truncate_at_sentence_strips_trailing_whitespace():
 
 def test_max_summary_chars_constant():
     """Test that MAX_SUMMARY_CHARS is set to expected value."""
-    assert MAX_SUMMARY_CHARS == 250
+    assert MAX_SUMMARY_CHARS == 500
 
 
 def test_truncate_at_sentence_realistic_summary():
uv.lock (generated file, 4 additions and 4 deletions)
@@ -1,5 +1,5 @@
 version = 1
-revision = 3
+revision = 2
 requires-python = ">=3.10, <4"
 resolution-markers = [
     "python_full_version >= '3.14'",
@@ -808,7 +808,7 @@ wheels = [
 
 [[package]]
 name = "graphiti-core"
-version = "0.24.0"
+version = "0.24.1"
 source = { editable = "." }
 dependencies = [
     { name = "diskcache" },