chore: Parse edges in tools
This commit is contained in:
parent
888cccbb4f
commit
607affea8f
2 changed files with 28 additions and 24 deletions
|
|
@ -117,26 +117,26 @@ async def main():
|
||||||
await clear_data(client.driver)
|
await clear_data(client.driver)
|
||||||
await client.build_indices_and_constraints()
|
await client.build_indices_and_constraints()
|
||||||
|
|
||||||
episodes: list[RawEpisode] = [
|
players_grouped_by_team = {}
|
||||||
RawEpisode(
|
for player in current_roster_from_file:
|
||||||
name=f'Player {player["player_id"]}',
|
team_name = player['team_name']
|
||||||
content=str(player),
|
if team_name not in players_grouped_by_team:
|
||||||
source_description='NBA current roster',
|
players_grouped_by_team[team_name] = []
|
||||||
source=EpisodeType.json,
|
players_grouped_by_team[team_name].append(player)
|
||||||
reference_time=datetime.now(),
|
|
||||||
)
|
|
||||||
for i, player in enumerate(current_roster_from_file)
|
|
||||||
]
|
|
||||||
|
|
||||||
await client.add_episode_bulk(episodes)
|
for _, _ in players_grouped_by_team.items():
|
||||||
# client.llm_client = AnthropicClient(LLMConfig(api_key=os.environ.get('ANTHROPIC_API_KEY')))
|
episodes: list[RawEpisode] = [
|
||||||
# await client.add_episode(
|
RawEpisode(
|
||||||
# name='Player Transfer',
|
name=f'Player {player["player_id"]}',
|
||||||
# episode_body='DJ Carton got transeffered to Boston Celtics August 2nd',
|
content=str(player),
|
||||||
# source_description='NBA transfer',
|
source_description='NBA current roster',
|
||||||
# reference_time=datetime.now(),
|
source=EpisodeType.json,
|
||||||
# source=EpisodeType.message,
|
reference_time=datetime.now(),
|
||||||
# )
|
)
|
||||||
|
for player in players
|
||||||
|
]
|
||||||
|
|
||||||
|
await client.add_episode_bulk(episodes)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
||||||
|
|
@ -33,23 +33,27 @@ async def invoke_tool(tool_name: str, **kwargs):
|
||||||
return await tool.ainvoke(input=kwargs)
|
return await tool.ainvoke(input=kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_fact_string(edge):
|
||||||
|
return f'{edge.fact} {edge.valid_at or edge.created_at}'
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def get_team_roster(team_name: str):
|
async def get_team_roster(team_name: str):
|
||||||
"""Get the current roster for a specific team."""
|
"""Get the current roster for a specific team."""
|
||||||
search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=10)
|
search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=30)
|
||||||
roster = []
|
roster = []
|
||||||
for fact in search_result:
|
for edge in search_result:
|
||||||
roster.append(fact)
|
roster.append(get_fact_string(edge))
|
||||||
return roster
|
return roster
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def search_player_info(player_name: str):
|
async def search_player_info(player_name: str):
|
||||||
"""Search for information about a specific player."""
|
"""Search for information about a specific player."""
|
||||||
search_result = await graphiti_client.search(f'{player_name}')
|
search_result = await graphiti_client.search(f'{player_name}', num_results=30)
|
||||||
player_info = {
|
player_info = {
|
||||||
'name': player_name,
|
'name': player_name,
|
||||||
'facts': search_result,
|
'facts': [get_fact_string(edge) for edge in search_result],
|
||||||
}
|
}
|
||||||
return player_info
|
return player_info
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue