chore: Parse edges in tools

This commit is contained in:
paulpaliychuk 2024-08-27 12:16:19 -04:00
parent 888cccbb4f
commit 607affea8f
2 changed files with 28 additions and 24 deletions

View file

@ -117,6 +117,14 @@ async def main():
await clear_data(client.driver) await clear_data(client.driver)
await client.build_indices_and_constraints() await client.build_indices_and_constraints()
players_grouped_by_team = {}
for player in current_roster_from_file:
team_name = player['team_name']
if team_name not in players_grouped_by_team:
players_grouped_by_team[team_name] = []
players_grouped_by_team[team_name].append(player)
for _, _ in players_grouped_by_team.items():
episodes: list[RawEpisode] = [ episodes: list[RawEpisode] = [
RawEpisode( RawEpisode(
name=f'Player {player["player_id"]}', name=f'Player {player["player_id"]}',
@ -125,18 +133,10 @@ async def main():
source=EpisodeType.json, source=EpisodeType.json,
reference_time=datetime.now(), reference_time=datetime.now(),
) )
for i, player in enumerate(current_roster_from_file) for player in players
] ]
await client.add_episode_bulk(episodes) await client.add_episode_bulk(episodes)
# client.llm_client = AnthropicClient(LLMConfig(api_key=os.environ.get('ANTHROPIC_API_KEY')))
# await client.add_episode(
# name='Player Transfer',
# episode_body='DJ Carton got transeffered to Boston Celtics August 2nd',
# source_description='NBA transfer',
# reference_time=datetime.now(),
# source=EpisodeType.message,
# )
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -33,23 +33,27 @@ async def invoke_tool(tool_name: str, **kwargs):
return await tool.ainvoke(input=kwargs) return await tool.ainvoke(input=kwargs)
def get_fact_string(edge):
return f'{edge.fact} {edge.valid_at or edge.created_at}'
@tool @tool
async def get_team_roster(team_name: str): async def get_team_roster(team_name: str):
"""Get the current roster for a specific team.""" """Get the current roster for a specific team."""
search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=10) search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=30)
roster = [] roster = []
for fact in search_result: for edge in search_result:
roster.append(fact) roster.append(get_fact_string(edge))
return roster return roster
@tool @tool
async def search_player_info(player_name: str): async def search_player_info(player_name: str):
"""Search for information about a specific player.""" """Search for information about a specific player."""
search_result = await graphiti_client.search(f'{player_name}') search_result = await graphiti_client.search(f'{player_name}', num_results=30)
player_info = { player_info = {
'name': player_name, 'name': player_name,
'facts': search_result, 'facts': [get_fact_string(edge) for edge in search_result],
} }
return player_info return player_info