chore: Parse edges in tools
This commit is contained in:
parent
888cccbb4f
commit
607affea8f
2 changed files with 28 additions and 24 deletions
|
|
@ -117,26 +117,26 @@ async def main():
|
|||
await clear_data(client.driver)
|
||||
await client.build_indices_and_constraints()
|
||||
|
||||
episodes: list[RawEpisode] = [
|
||||
RawEpisode(
|
||||
name=f'Player {player["player_id"]}',
|
||||
content=str(player),
|
||||
source_description='NBA current roster',
|
||||
source=EpisodeType.json,
|
||||
reference_time=datetime.now(),
|
||||
)
|
||||
for i, player in enumerate(current_roster_from_file)
|
||||
]
|
||||
players_grouped_by_team = {}
|
||||
for player in current_roster_from_file:
|
||||
team_name = player['team_name']
|
||||
if team_name not in players_grouped_by_team:
|
||||
players_grouped_by_team[team_name] = []
|
||||
players_grouped_by_team[team_name].append(player)
|
||||
|
||||
await client.add_episode_bulk(episodes)
|
||||
# client.llm_client = AnthropicClient(LLMConfig(api_key=os.environ.get('ANTHROPIC_API_KEY')))
|
||||
# await client.add_episode(
|
||||
# name='Player Transfer',
|
||||
# episode_body='DJ Carton got transeffered to Boston Celtics August 2nd',
|
||||
# source_description='NBA transfer',
|
||||
# reference_time=datetime.now(),
|
||||
# source=EpisodeType.message,
|
||||
# )
|
||||
for _, _ in players_grouped_by_team.items():
|
||||
episodes: list[RawEpisode] = [
|
||||
RawEpisode(
|
||||
name=f'Player {player["player_id"]}',
|
||||
content=str(player),
|
||||
source_description='NBA current roster',
|
||||
source=EpisodeType.json,
|
||||
reference_time=datetime.now(),
|
||||
)
|
||||
for player in players
|
||||
]
|
||||
|
||||
await client.add_episode_bulk(episodes)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -33,23 +33,27 @@ async def invoke_tool(tool_name: str, **kwargs):
|
|||
return await tool.ainvoke(input=kwargs)
|
||||
|
||||
|
||||
def get_fact_string(edge):
|
||||
return f'{edge.fact} {edge.valid_at or edge.created_at}'
|
||||
|
||||
|
||||
@tool
|
||||
async def get_team_roster(team_name: str):
|
||||
"""Get the current roster for a specific team."""
|
||||
search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=10)
|
||||
search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=30)
|
||||
roster = []
|
||||
for fact in search_result:
|
||||
roster.append(fact)
|
||||
for edge in search_result:
|
||||
roster.append(get_fact_string(edge))
|
||||
return roster
|
||||
|
||||
|
||||
@tool
|
||||
async def search_player_info(player_name: str):
|
||||
"""Search for information about a specific player."""
|
||||
search_result = await graphiti_client.search(f'{player_name}')
|
||||
search_result = await graphiti_client.search(f'{player_name}', num_results=30)
|
||||
player_info = {
|
||||
'name': player_name,
|
||||
'facts': search_result,
|
||||
'facts': [get_fact_string(edge) for edge in search_result],
|
||||
}
|
||||
return player_info
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue