diff --git a/examples/nba/ingest_current_roster.py b/examples/nba/ingest_current_roster.py index 10fe9179..c62de3f2 100644 --- a/examples/nba/ingest_current_roster.py +++ b/examples/nba/ingest_current_roster.py @@ -117,26 +117,26 @@ async def main(): await clear_data(client.driver) await client.build_indices_and_constraints() - episodes: list[RawEpisode] = [ - RawEpisode( - name=f'Player {player["player_id"]}', - content=str(player), - source_description='NBA current roster', - source=EpisodeType.json, - reference_time=datetime.now(), - ) - for i, player in enumerate(current_roster_from_file) - ] + players_grouped_by_team = {} + for player in current_roster_from_file: + team_name = player['team_name'] + if team_name not in players_grouped_by_team: + players_grouped_by_team[team_name] = [] + players_grouped_by_team[team_name].append(player) - await client.add_episode_bulk(episodes) - # client.llm_client = AnthropicClient(LLMConfig(api_key=os.environ.get('ANTHROPIC_API_KEY'))) - # await client.add_episode( - # name='Player Transfer', - # episode_body='DJ Carton got transeffered to Boston Celtics August 2nd', - # source_description='NBA transfer', - # reference_time=datetime.now(), - # source=EpisodeType.message, - # ) + for _, _ in players_grouped_by_team.items(): + episodes: list[RawEpisode] = [ + RawEpisode( + name=f'Player {player["player_id"]}', + content=str(player), + source_description='NBA current roster', + source=EpisodeType.json, + reference_time=datetime.now(), + ) + for player in players + ] + + await client.add_episode_bulk(episodes) if __name__ == '__main__': diff --git a/examples/nba/nba_agent.py b/examples/nba/nba_agent.py index 7a10b593..58d0186a 100644 --- a/examples/nba/nba_agent.py +++ b/examples/nba/nba_agent.py @@ -33,23 +33,27 @@ async def invoke_tool(tool_name: str, **kwargs): return await tool.ainvoke(input=kwargs) +def get_fact_string(edge): + return f'{edge.fact} {edge.valid_at or edge.created_at}' + + @tool async def get_team_roster(team_name: str): """Get the current roster for a specific team.""" - search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=10) + search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=30) roster = [] - for fact in search_result: - roster.append(fact) + for edge in search_result: + roster.append(get_fact_string(edge)) return roster @tool async def search_player_info(player_name: str): """Search for information about a specific player.""" - search_result = await graphiti_client.search(f'{player_name}') + search_result = await graphiti_client.search(f'{player_name}', num_results=30) player_info = { 'name': player_name, - 'facts': search_result, + 'facts': [get_fact_string(edge) for edge in search_result], } return player_info