chore: Parse edges in tools

This commit is contained in:
paulpaliychuk 2024-08-27 12:16:19 -04:00
parent 888cccbb4f
commit 607affea8f
2 changed files with 28 additions and 24 deletions

View file

@ -117,26 +117,26 @@ async def main():
await clear_data(client.driver)
await client.build_indices_and_constraints()
episodes: list[RawEpisode] = [
RawEpisode(
name=f'Player {player["player_id"]}',
content=str(player),
source_description='NBA current roster',
source=EpisodeType.json,
reference_time=datetime.now(),
)
for i, player in enumerate(current_roster_from_file)
]
players_grouped_by_team = {}
for player in current_roster_from_file:
team_name = player['team_name']
if team_name not in players_grouped_by_team:
players_grouped_by_team[team_name] = []
players_grouped_by_team[team_name].append(player)
await client.add_episode_bulk(episodes)
# client.llm_client = AnthropicClient(LLMConfig(api_key=os.environ.get('ANTHROPIC_API_KEY')))
# await client.add_episode(
# name='Player Transfer',
# episode_body='DJ Carton got transeffered to Boston Celtics August 2nd',
# source_description='NBA transfer',
# reference_time=datetime.now(),
# source=EpisodeType.message,
# )
for _, _ in players_grouped_by_team.items():
episodes: list[RawEpisode] = [
RawEpisode(
name=f'Player {player["player_id"]}',
content=str(player),
source_description='NBA current roster',
source=EpisodeType.json,
reference_time=datetime.now(),
)
for player in players
]
await client.add_episode_bulk(episodes)
if __name__ == '__main__':

View file

@ -33,23 +33,27 @@ async def invoke_tool(tool_name: str, **kwargs):
return await tool.ainvoke(input=kwargs)
def get_fact_string(edge):
return f'{edge.fact} {edge.valid_at or edge.created_at}'
@tool
async def get_team_roster(team_name: str):
"""Get the current roster for a specific team."""
search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=10)
search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=30)
roster = []
for fact in search_result:
roster.append(fact)
for edge in search_result:
roster.append(get_fact_string(edge))
return roster
@tool
async def search_player_info(player_name: str):
"""Search for information about a specific player."""
search_result = await graphiti_client.search(f'{player_name}')
search_result = await graphiti_client.search(f'{player_name}', num_results=30)
player_info = {
'name': player_name,
'facts': search_result,
'facts': [get_fact_string(edge) for edge in search_result],
}
return player_info