parent 8b7ad6f84c
commit d2654003ff
10 changed files with 91 additions and 87 deletions
@@ -40,8 +40,8 @@ from graphiti_core.nodes import EpisodeType
 # Configure logging
 logging.basicConfig(
     level=INFO,
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
-    datefmt="%Y-%m-%d %H:%M:%S",
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S',
 )
 logger = logging.getLogger(__name__)

@@ -49,20 +49,20 @@ load_dotenv()

 # Neo4j connection parameters
 # Make sure Neo4j Desktop is running with a local DBMS started
-neo4j_uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
-neo4j_user = os.environ.get("NEO4J_USER", "neo4j")
-neo4j_password = os.environ.get("NEO4J_PASSWORD", "password")
+neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
+neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
+neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')

 # Azure OpenAI connection parameters
-azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
-azure_api_key = os.environ.get("AZURE_OPENAI_API_KEY")
-azure_deployment = os.environ.get("AZURE_OPENAI_DEPLOYMENT", "gpt-4.1")
+azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
+azure_api_key = os.environ.get('AZURE_OPENAI_API_KEY')
+azure_deployment = os.environ.get('AZURE_OPENAI_DEPLOYMENT', 'gpt-4.1')
 azure_embedding_deployment = os.environ.get(
-    "AZURE_OPENAI_EMBEDDING_DEPLOYMENT", "text-embedding-3-small"
+    'AZURE_OPENAI_EMBEDDING_DEPLOYMENT', 'text-embedding-3-small'
 )

 if not azure_endpoint or not azure_api_key:
-    raise ValueError("AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY must be set")
+    raise ValueError('AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY must be set')


 async def main():
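Note: every connection setting above is read through os.environ.get, so the Neo4j values fall back to local defaults while the two required Azure values come back as None and trip the ValueError guard. A quick illustration of the pattern (values are placeholders):

    import os

    # With a default, a missing variable falls back silently.
    uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')

    # Without a default, a missing variable yields None, which the
    # ValueError guard above turns into a fail-fast error at startup.
    endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
    if endpoint is None:
        raise ValueError('AZURE_OPENAI_ENDPOINT must be set')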
@@ -76,7 +76,7 @@ async def main():

    # Initialize Azure OpenAI client
    azure_client = AsyncOpenAI(
-        base_url=f"{azure_endpoint}/openai/v1/",
+        base_url=f'{azure_endpoint}/openai/v1/',
        api_key=azure_api_key,
    )

@@ -112,40 +112,40 @@ async def main():
        # Episodes list containing both text and JSON episodes
        episodes = [
            {
-                "content": "Kamala Harris is the Attorney General of California. She was previously "
-                "the district attorney for San Francisco.",
-                "type": EpisodeType.text,
-                "description": "podcast transcript",
+                'content': 'Kamala Harris is the Attorney General of California. She was previously '
+                'the district attorney for San Francisco.',
+                'type': EpisodeType.text,
+                'description': 'podcast transcript',
            },
            {
-                "content": "As AG, Harris was in office from January 3, 2011 – January 3, 2017",
-                "type": EpisodeType.text,
-                "description": "podcast transcript",
+                'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
+                'type': EpisodeType.text,
+                'description': 'podcast transcript',
            },
            {
-                "content": {
-                    "name": "Gavin Newsom",
-                    "position": "Governor",
-                    "state": "California",
-                    "previous_role": "Lieutenant Governor",
-                    "previous_location": "San Francisco",
+                'content': {
+                    'name': 'Gavin Newsom',
+                    'position': 'Governor',
+                    'state': 'California',
+                    'previous_role': 'Lieutenant Governor',
+                    'previous_location': 'San Francisco',
                },
-                "type": EpisodeType.json,
-                "description": "podcast metadata",
+                'type': EpisodeType.json,
+                'description': 'podcast metadata',
            },
        ]

        # Add episodes to the graph
        for i, episode in enumerate(episodes):
            await graphiti.add_episode(
-                name=f"California Politics {i}",
+                name=f'California Politics {i}',
                episode_body=(
-                    episode["content"]
-                    if isinstance(episode["content"], str)
-                    else json.dumps(episode["content"])
+                    episode['content']
+                    if isinstance(episode['content'], str)
+                    else json.dumps(episode['content'])
                ),
-                source=episode["type"],
-                source_description=episode["description"],
+                source=episode['type'],
+                source_description=episode['description'],
                reference_time=datetime.now(timezone.utc),
            )
            print(f'Added episode: California Politics {i} ({episode["type"].value})')
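Note: the episode_body expression normalizes the mixed episode list, passing text content through unchanged and serializing dict content with json.dumps. The same normalization in isolation, as a minimal sketch (sample data is illustrative):

    import json

    def normalize_body(content):
        # Text episodes pass through; structured episodes are serialized to JSON.
        return content if isinstance(content, str) else json.dumps(content)

    assert normalize_body('plain text') == 'plain text'
    assert normalize_body({'name': 'Gavin Newsom'}) == '{"name": "Gavin Newsom"}'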
@@ -161,18 +161,18 @@ async def main():

        # Perform a hybrid search combining semantic similarity and BM25 retrieval
        print("\nSearching for: 'Who was the California Attorney General?'")
-        results = await graphiti.search("Who was the California Attorney General?")
+        results = await graphiti.search('Who was the California Attorney General?')

        # Print search results
-        print("\nSearch Results:")
+        print('\nSearch Results:')
        for result in results:
-            print(f"UUID: {result.uuid}")
-            print(f"Fact: {result.fact}")
-            if hasattr(result, "valid_at") and result.valid_at:
-                print(f"Valid from: {result.valid_at}")
-            if hasattr(result, "invalid_at") and result.invalid_at:
-                print(f"Valid until: {result.invalid_at}")
-            print("---")
+            print(f'UUID: {result.uuid}')
+            print(f'Fact: {result.fact}')
+            if hasattr(result, 'valid_at') and result.valid_at:
+                print(f'Valid from: {result.valid_at}')
+            if hasattr(result, 'invalid_at') and result.invalid_at:
+                print(f'Valid until: {result.invalid_at}')
+            print('---')

        #################################################
        # CENTER NODE SEARCH
@@ -187,26 +187,26 @@ async def main():
            # Get the source node UUID from the top result
            center_node_uuid = results[0].source_node_uuid

-            print("\nReranking search results based on graph distance:")
-            print(f"Using center node UUID: {center_node_uuid}")
+            print('\nReranking search results based on graph distance:')
+            print(f'Using center node UUID: {center_node_uuid}')

            reranked_results = await graphiti.search(
-                "Who was the California Attorney General?",
+                'Who was the California Attorney General?',
                center_node_uuid=center_node_uuid,
            )

            # Print reranked search results
-            print("\nReranked Search Results:")
+            print('\nReranked Search Results:')
            for result in reranked_results:
-                print(f"UUID: {result.uuid}")
-                print(f"Fact: {result.fact}")
-                if hasattr(result, "valid_at") and result.valid_at:
-                    print(f"Valid from: {result.valid_at}")
-                if hasattr(result, "invalid_at") and result.invalid_at:
-                    print(f"Valid until: {result.invalid_at}")
-                print("---")
+                print(f'UUID: {result.uuid}')
+                print(f'Fact: {result.fact}')
+                if hasattr(result, 'valid_at') and result.valid_at:
+                    print(f'Valid from: {result.valid_at}')
+                if hasattr(result, 'invalid_at') and result.invalid_at:
+                    print(f'Valid until: {result.invalid_at}')
+                print('---')
        else:
-            print("No results found in the initial search to use as center node.")
+            print('No results found in the initial search to use as center node.')

    finally:
        #################################################
@@ -218,8 +218,8 @@ async def main():

        # Close the connection
        await graphiti.close()
-        print("\nConnection closed")
+        print('\nConnection closed')


-if __name__ == "__main__":
+if __name__ == '__main__':
    asyncio.run(main())

@@ -33,7 +33,7 @@ class AzureOpenAIEmbedderClient(EmbedderClient):
    def __init__(
        self,
        azure_client: AsyncAzureOpenAI | AsyncOpenAI,
-        model: str = "text-embedding-3-small",
+        model: str = 'text-embedding-3-small',
    ):
        self.azure_client = azure_client
        self.model = model
@@ -44,22 +44,18 @@ class AzureOpenAIEmbedderClient(EmbedderClient):
            # Handle different input types
            if isinstance(input_data, str):
                text_input = [input_data]
-            elif isinstance(input_data, list) and all(
-                isinstance(item, str) for item in input_data
-            ):
+            elif isinstance(input_data, list) and all(isinstance(item, str) for item in input_data):
                text_input = input_data
            else:
                # Convert to string list for other types
                text_input = [str(input_data)]

-            response = await self.azure_client.embeddings.create(
-                model=self.model, input=text_input
-            )
+            response = await self.azure_client.embeddings.create(model=self.model, input=text_input)

            # Return the first embedding as a list of floats
            return response.data[0].embedding
        except Exception as e:
-            logger.error(f"Error in Azure OpenAI embedding: {e}")
+            logger.error(f'Error in Azure OpenAI embedding: {e}')
            raise

    async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
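Note: collapsing the elif condition and the embeddings.create call onto single lines does not change behavior; create still returns the first embedding as a list of floats. A sketch of driving the embedder, assuming an AsyncOpenAI client pointed at an Azure endpoint (the URL and key are placeholders):

    import asyncio

    from openai import AsyncOpenAI

    async def demo() -> None:
        azure_client = AsyncOpenAI(
            base_url='https://example.openai.azure.com/openai/v1/',  # placeholder endpoint
            api_key='<your-key>',  # placeholder key
        )
        embedder = AzureOpenAIEmbedderClient(azure_client, model='text-embedding-3-small')
        vector = await embedder.create('Kamala Harris is the Attorney General of California.')
        print(len(vector))  # dimensionality of the returned embedding

    asyncio.run(demo())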
@@ -71,5 +67,5 @@ class AzureOpenAIEmbedderClient(EmbedderClient):

            return [embedding.embedding for embedding in response.data]
        except Exception as e:
-            logger.error(f"Error in Azure OpenAI batch embedding: {e}")
+            logger.error(f'Error in Azure OpenAI batch embedding: {e}')
            raise

@@ -66,21 +66,21 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
        """Create a structured completion using Azure OpenAI's responses.parse API."""
        supports_reasoning = self._supports_reasoning_features(model)
        request_kwargs = {
-            "model": model,
-            "input": messages,
-            "max_output_tokens": max_tokens,
-            "text_format": response_model,  # type: ignore
+            'model': model,
+            'input': messages,
+            'max_output_tokens': max_tokens,
+            'text_format': response_model,  # type: ignore
        }

        temperature_value = temperature if not supports_reasoning else None
        if temperature_value is not None:
-            request_kwargs["temperature"] = temperature_value
+            request_kwargs['temperature'] = temperature_value

        if supports_reasoning and reasoning:
-            request_kwargs["reasoning"] = {"effort": reasoning}  # type: ignore
+            request_kwargs['reasoning'] = {'effort': reasoning}  # type: ignore

        if supports_reasoning and verbosity:
-            request_kwargs["text"] = {"verbosity": verbosity}  # type: ignore
+            request_kwargs['text'] = {'verbosity': verbosity}  # type: ignore

        return await self.client.responses.parse(**request_kwargs)

@@ -96,20 +96,20 @@ class AzureOpenAILLMClient(BaseOpenAIClient):
        supports_reasoning = self._supports_reasoning_features(model)

        request_kwargs = {
-            "model": model,
-            "messages": messages,
-            "max_tokens": max_tokens,
-            "response_format": {"type": "json_object"},
+            'model': model,
+            'messages': messages,
+            'max_tokens': max_tokens,
+            'response_format': {'type': 'json_object'},
        }

        temperature_value = temperature if not supports_reasoning else None
        if temperature_value is not None:
-            request_kwargs["temperature"] = temperature_value
+            request_kwargs['temperature'] = temperature_value

        return await self.client.chat.completions.create(**request_kwargs)

    @staticmethod
    def _supports_reasoning_features(model: str) -> bool:
        """Return True when the Azure model supports reasoning/verbosity options."""
-        reasoning_prefixes = ("o1", "o3", "gpt-5")
+        reasoning_prefixes = ('o1', 'o3', 'gpt-5')
        return model.startswith(reasoning_prefixes)
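Note: str.startswith accepts a tuple of prefixes, so the helper matches every o1-, o3-, and gpt-5-prefixed deployment name in one call:

    reasoning_prefixes = ('o1', 'o3', 'gpt-5')

    assert 'gpt-5-mini'.startswith(reasoning_prefixes)   # reasoning model
    assert not 'gpt-4.1'.startswith(reasoning_prefixes)  # temperature allowed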
@@ -166,13 +166,17 @@ class BaseOpenAIClient(LLMClient):
        except openai.RateLimitError as e:
            raise RateLimitError from e
        except openai.AuthenticationError as e:
-            logger.error(f'OpenAI Authentication Error: {e}. Please verify your API key is correct.')
+            logger.error(
+                f'OpenAI Authentication Error: {e}. Please verify your API key is correct.'
+            )
            raise
        except Exception as e:
            # Provide more context for connection errors
            error_msg = str(e)
            if 'Connection error' in error_msg or 'connection' in error_msg.lower():
-                logger.error(f'Connection error communicating with OpenAI API. Please check your network connection and API key. Error: {e}')
+                logger.error(
+                    f'Connection error communicating with OpenAI API. Please check your network connection and API key. Error: {e}'
+                )
            else:
                logger.error(f'Error in generating LLM response: {e}')
            raise

@@ -74,7 +74,9 @@ class OpenAIClient(BaseOpenAIClient):
    ):
        """Create a structured completion using OpenAI's beta parse API."""
        # Reasoning models (gpt-5 family) don't support temperature
-        is_reasoning_model = model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
+        is_reasoning_model = (
+            model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
+        )

        response = await self.client.responses.parse(
            model=model,

@@ -100,7 +102,9 @@ class OpenAIClient(BaseOpenAIClient):
    ):
        """Create a regular completion with JSON format."""
        # Reasoning models (gpt-5 family) don't support temperature
-        is_reasoning_model = model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
+        is_reasoning_model = (
+            model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
+        )

        return await self.client.chat.completions.create(
            model=model,

@@ -17,7 +17,7 @@ limitations under the License.
 import re

 # Maximum length for entity/node summaries
-MAX_SUMMARY_CHARS = 250
+MAX_SUMMARY_CHARS = 500


 def truncate_at_sentence(text: str, max_chars: int) -> str:
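Note: the diff only raises the cap from 250 to 500 characters; truncate_at_sentence itself is unchanged and its body is not shown here. A hypothetical sketch of sentence-boundary truncation, purely to illustrate what the constant bounds (this is not the library's implementation):

    def truncate_at_sentence_sketch(text: str, max_chars: int) -> str:
        # Hypothetical: keep whole sentences that fit within max_chars.
        if len(text) <= max_chars:
            return text
        cut = text[:max_chars]
        end = max(cut.rfind('. '), cut.rfind('! '), cut.rfind('? '))
        return cut[: end + 1].rstrip() if end != -1 else cut.rstrip()

    print(truncate_at_sentence_sketch('First sentence. Second sentence runs long.', 20))
    # -> 'First sentence.'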
@@ -1,7 +1,7 @@
 [project]
 name = "graphiti-core"
 description = "A temporal graph building library"
-version = "0.24.0"
+version = "0.24.1"
 authors = [
     { name = "Paul Paliychuk", email = "paul@getzep.com" },
     { name = "Preston Rasmussen", email = "preston@getzep.com" },

@@ -87,7 +87,7 @@ def test_truncate_at_sentence_strips_trailing_whitespace():

 def test_max_summary_chars_constant():
     """Test that MAX_SUMMARY_CHARS is set to expected value."""
-    assert MAX_SUMMARY_CHARS == 250
+    assert MAX_SUMMARY_CHARS == 500


 def test_truncate_at_sentence_realistic_summary():

uv.lock (generated)

@@ -1,5 +1,5 @@
 version = 1
-revision = 3
+revision = 2
 requires-python = ">=3.10, <4"
 resolution-markers = [
     "python_full_version >= '3.14'",

@@ -808,7 +808,7 @@ wheels = [

 [[package]]
 name = "graphiti-core"
-version = "0.24.0"
+version = "0.24.1"
 source = { editable = "." }
 dependencies = [
     { name = "diskcache" },