chore: Parse edges in tools

This commit is contained in:
paulpaliychuk 2024-08-27 12:16:19 -04:00
parent 888cccbb4f
commit 607affea8f
2 changed files with 28 additions and 24 deletions

View file

@ -117,26 +117,26 @@ async def main():
await clear_data(client.driver) await clear_data(client.driver)
await client.build_indices_and_constraints() await client.build_indices_and_constraints()
episodes: list[RawEpisode] = [ players_grouped_by_team = {}
RawEpisode( for player in current_roster_from_file:
name=f'Player {player["player_id"]}', team_name = player['team_name']
content=str(player), if team_name not in players_grouped_by_team:
source_description='NBA current roster', players_grouped_by_team[team_name] = []
source=EpisodeType.json, players_grouped_by_team[team_name].append(player)
reference_time=datetime.now(),
)
for i, player in enumerate(current_roster_from_file)
]
await client.add_episode_bulk(episodes) for _, _ in players_grouped_by_team.items():
# client.llm_client = AnthropicClient(LLMConfig(api_key=os.environ.get('ANTHROPIC_API_KEY'))) episodes: list[RawEpisode] = [
# await client.add_episode( RawEpisode(
# name='Player Transfer', name=f'Player {player["player_id"]}',
# episode_body='DJ Carton got transeffered to Boston Celtics August 2nd', content=str(player),
# source_description='NBA transfer', source_description='NBA current roster',
# reference_time=datetime.now(), source=EpisodeType.json,
# source=EpisodeType.message, reference_time=datetime.now(),
# ) )
for player in players
]
await client.add_episode_bulk(episodes)
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -33,23 +33,27 @@ async def invoke_tool(tool_name: str, **kwargs):
return await tool.ainvoke(input=kwargs) return await tool.ainvoke(input=kwargs)
def get_fact_string(edge):
return f'{edge.fact} {edge.valid_at or edge.created_at}'
@tool @tool
async def get_team_roster(team_name: str): async def get_team_roster(team_name: str):
"""Get the current roster for a specific team.""" """Get the current roster for a specific team."""
search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=10) search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=30)
roster = [] roster = []
for fact in search_result: for edge in search_result:
roster.append(fact) roster.append(get_fact_string(edge))
return roster return roster
@tool @tool
async def search_player_info(player_name: str): async def search_player_info(player_name: str):
"""Search for information about a specific player.""" """Search for information about a specific player."""
search_result = await graphiti_client.search(f'{player_name}') search_result = await graphiti_client.search(f'{player_name}', num_results=30)
player_info = { player_info = {
'name': player_name, 'name': player_name,
'facts': search_result, 'facts': [get_fact_string(edge) for edge in search_result],
} }
return player_info return player_info