chore: Add nba agent example

This commit is contained in:
paulpaliychuk 2024-08-26 20:04:35 -04:00
parent 895afc7be1
commit f95026763e
5 changed files with 698 additions and 0 deletions

View file

@ -0,0 +1,257 @@
[
{
"team_name": "Boston Celtics",
"player_id": 1628369,
"player_name": "Jayson Tatum"
},
{
"team_name": "Boston Celtics",
"player_id": 201950,
"player_name": "Jrue Holiday"
},
{
"team_name": "Boston Celtics",
"player_id": 1627759,
"player_name": "Jaylen Brown"
},
{
"team_name": "Boston Celtics",
"player_id": 204001,
"player_name": "Kristaps Porzingis"
},
{
"team_name": "Boston Celtics",
"player_id": 1628401,
"player_name": "Derrick White"
},
{
"team_name": "Boston Celtics",
"player_id": 1630202,
"player_name": "Payton Pritchard"
},
{
"team_name": "Boston Celtics",
"player_id": 1629052,
"player_name": "Oshae Brissett"
},
{
"team_name": "Boston Celtics",
"player_id": 1641809,
"player_name": "Drew Peterson"
},
{
"team_name": "Boston Celtics",
"player_id": 1631120,
"player_name": "JD Davison"
},
{
"team_name": "Boston Celtics",
"player_id": 1630214,
"player_name": "Xavier Tillman"
},
{
"team_name": "Boston Celtics",
"player_id": 1641775,
"player_name": "Jordan Walsh"
},
{
"team_name": "Boston Celtics",
"player_id": 1630573,
"player_name": "Sam Hauser"
},
{
"team_name": "Boston Celtics",
"player_id": 1628436,
"player_name": "Luke Kornet"
},
{
"team_name": "Boston Celtics",
"player_id": 201143,
"player_name": "Al Horford"
},
{
"team_name": "Boston Celtics",
"player_id": 1630531,
"player_name": "Jaden Springer"
},
{
"team_name": "Boston Celtics",
"player_id": 1629004,
"player_name": "Svi Mykhailiuk"
},
{
"team_name": "Boston Celtics",
"player_id": 1629674,
"player_name": "Neemias Queta"
},
{
"team_name": "Golden State Warriors",
"player_id": 1627780,
"player_name": "Gary Payton II"
},
{
"team_name": "Golden State Warriors",
"player_id": 1630228,
"player_name": "Jonathan Kuminga"
},
{
"team_name": "Golden State Warriors",
"player_id": 1641764,
"player_name": "Brandin Podziemski"
},
{
"team_name": "Golden State Warriors",
"player_id": 101108,
"player_name": "Chris Paul"
},
{
"team_name": "Golden State Warriors",
"player_id": 1630541,
"player_name": "Moses Moody"
},
{
"team_name": "Golden State Warriors",
"player_id": 1626172,
"player_name": "Kevon Looney"
},
{
"team_name": "Golden State Warriors",
"player_id": 202691,
"player_name": "Klay Thompson"
},
{
"team_name": "Golden State Warriors",
"player_id": 1630586,
"player_name": "Usman Garuba"
},
{
"team_name": "Golden State Warriors",
"player_id": 1630611,
"player_name": "Gui Santos"
},
{
"team_name": "Golden State Warriors",
"player_id": 1629010,
"player_name": "Jerome Robinson"
},
{
"team_name": "Golden State Warriors",
"player_id": 203967,
"player_name": "Dario Saric"
},
{
"team_name": "Golden State Warriors",
"player_id": 203952,
"player_name": "Andrew Wiggins"
},
{
"team_name": "Golden State Warriors",
"player_id": 203110,
"player_name": "Draymond Green"
},
{
"team_name": "Golden State Warriors",
"player_id": 1631311,
"player_name": "Lester Quinones"
},
{
"team_name": "Golden State Warriors",
"player_id": 201939,
"player_name": "Stephen Curry"
},
{
"team_name": "Golden State Warriors",
"player_id": 1631218,
"player_name": "Trayce Jackson-Davis"
},
{
"team_name": "Golden State Warriors",
"player_id": 1630311,
"player_name": "Pat Spencer"
},
{
"team_name": "Toronto Raptors",
"player_id": 1642013,
"player_name": "Malik Williams"
},
{
"team_name": "Toronto Raptors",
"player_id": 1631241,
"player_name": "Javon Freeman-Liberty"
},
{
"team_name": "Toronto Raptors",
"player_id": 1641711,
"player_name": "Gradey Dick"
},
{
"team_name": "Toronto Raptors",
"player_id": 1629667,
"player_name": "Jalen McDaniels"
},
{
"team_name": "Toronto Raptors",
"player_id": 1630618,
"player_name": "DJ Carton"
},
{
"team_name": "Toronto Raptors",
"player_id": 1630567,
"player_name": "Scottie Barnes"
},
{
"team_name": "Toronto Raptors",
"player_id": 1630193,
"player_name": "Immanuel Quickley"
},
{
"team_name": "Toronto Raptors",
"player_id": 1629628,
"player_name": "RJ Barrett"
},
{
"team_name": "Toronto Raptors",
"player_id": 1628971,
"player_name": "Bruce Brown"
},
{
"team_name": "Toronto Raptors",
"player_id": 1629670,
"player_name": "Jordan Nwora"
},
{
"team_name": "Toronto Raptors",
"player_id": 1631338,
"player_name": "Mouhamadou Gueye"
},
{
"team_name": "Toronto Raptors",
"player_id": 202066,
"player_name": "Garrett Temple"
},
{
"team_name": "Toronto Raptors",
"player_id": 1627751,
"player_name": "Jakob Poeltl"
},
{
"team_name": "Toronto Raptors",
"player_id": 1628449,
"player_name": "Chris Boucher"
},
{
"team_name": "Toronto Raptors",
"player_id": 1630534,
"player_name": "Ochai Agbaji"
},
{
"team_name": "Toronto Raptors",
"player_id": 1629018,
"player_name": "Gary Trent Jr."
},
{
"team_name": "Toronto Raptors",
"player_id": 203482,
"player_name": "Kelly Olynyk"
}
]

View file

@ -0,0 +1,143 @@
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import json
import logging
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import TypedDict
from dotenv import load_dotenv
from nba_api.stats.endpoints import commonteamroster, teamdetails
from nba_api.stats.static import players, teams
from graphiti_core import Graphiti
from graphiti_core.llm_client.anthropic_client import AnthropicClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.bulk_utils import RawEpisode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
class PlayerInfo(TypedDict):
team_name: str
player_id: int
player_name: str
player_number: str
player_position: str
player_school: str
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
def fetch_current_roster():
all_teams = teams.get_teams()
players_json = []
for t in all_teams:
name = t['full_name']
print(name)
if name == 'Golden State Warriors' or name == 'Boston Celtics' or name == 'Toronto Raptors':
roster = commonteamroster.CommonTeamRoster(team_id=t['id']).get_dict()
players_data = roster['resultSets'][0]
headers = players_data['headers']
row_set = players_data['rowSet']
for row in row_set:
player_dict = dict(zip(headers, row))
player_dict['team_name'] = name
print(player_dict)
meaningful_data = {
'team_name': name,
'player_id': player_dict['PLAYER_ID'],
'player_name': player_dict['PLAYER'],
# 'player_number': player_dict['NUM'],
# 'player_position': player_dict['POSITION'],
# 'player_school': player_dict['SCHOOL'],
}
players_json.append(meaningful_data)
script_dir = Path(__file__).parent
filename = script_dir / 'current_nba_roster.json'
print(players_json)
with open(filename, 'w') as f:
# write the players_json to the file and clear the file before doing so
f.truncate(0)
json.dump(players_json, f, indent=2)
async def main():
# fetch_current_roster()
current_roster_from_file: list[PlayerInfo] = []
script_dir = Path(__file__).parent
filename = script_dir / 'current_nba_roster.json'
with open(filename) as f:
current_roster_from_file = json.load(f)
print(current_roster_from_file)
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
await clear_data(client.driver)
await client.build_indices_and_constraints()
episodes: list[RawEpisode] = [
RawEpisode(
name=f'Player {player["player_id"]}',
content=str(player),
source_description='NBA current roster',
source=EpisodeType.json,
reference_time=datetime.now(),
)
for i, player in enumerate(current_roster_from_file)
]
await client.add_episode_bulk(episodes)
# client.llm_client = AnthropicClient(LLMConfig(api_key=os.environ.get('ANTHROPIC_API_KEY')))
# await client.add_episode(
# name='Player Transfer',
# episode_body='DJ Carton got transeffered to Boston Celtics August 2nd',
# source_description='NBA transfer',
# reference_time=datetime.now(),
# source=EpisodeType.message,
# )
if __name__ == '__main__':
asyncio.run(main())

242
examples/nba/nba_agent.py Normal file
View file

@ -0,0 +1,242 @@
import logging
import os
from datetime import datetime
from typing import Any, Dict, List
from dotenv import load_dotenv
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState, StateGraph
from langgraph.prebuilt import ToolInvocation, ToolNode
from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
load_dotenv()
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger('nba_agent')
for name in logging.root.manager.loggerDict:
if name != 'nba_agent':
logging.getLogger(name).setLevel(logging.WARNING)
# Initialize Graphiti client
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
graphiti_client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
async def invoke_tool(tool_name: str, **kwargs):
tool = next(t for t in tools if t.name == tool_name)
return await tool.ainvoke(input=kwargs)
@tool
async def get_team_roster(team_name: str):
"""Get the current roster for a specific team."""
search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=1)
print(search_result)
print(team_name.lower())
roster = []
for fact in search_result:
roster.append(fact)
return roster
@tool
async def search_player_info(player_name: str):
"""Search for information about a specific player."""
search_result = await graphiti_client.search(f'{player_name}')
player_info = {
'name': player_name,
'facts': search_result,
}
return player_info
@tool
async def verify_transfer_conditions(player_name: str, from_team: str, to_team: str):
"""Verify conditions for a player transfer."""
from_roster = await invoke_tool('get_team_roster', team_name=from_team)
to_roster = await invoke_tool('get_team_roster', team_name=to_team)
player_info = await invoke_tool('search_player_info', player_name=player_name)
# Prepare context for LLM
context = f"""
Player: {player_name}
From Team: {from_team}
To Team: {to_team}
From Team Roster:
{from_roster}
To Team Roster:
{to_roster}
Player Info:
{player_info}
"""
# Use LLM to evaluate transfer conditions
llm = ChatOpenAI(temperature=0)
prompt = f"""
Based on the following information, determine if the transfer conditions are met for {player_name} to move from {from_team} to {to_team}.
Context:
{context}
Please consider the following conditions:
1. Is {from_team} a valid NBA team?
2. Is {to_team} a valid NBA team?
3. Is {player_name} currently on the roster of {from_team}?
4. Is there enough information about {player_name}?
Important: Players can be transferred multiple times, including back to teams they've played for before.
Provide a detailed analysis of each condition and conclude whether all conditions are met or not.
Your response should end with one of these two statements:
- TRANSFER APPROVED: All conditions are met.
- TRANSFER DENIED: [Reason for denial]
"""
response = await llm.ainvoke(prompt)
return response.content
@tool
async def transfer_player(player_name: str, from_team: str, to_team: str):
"""Transfer a player from one team to another."""
try:
# Verify transfer conditions
verification_result = await invoke_tool(
'verify_transfer_conditions',
player_name=player_name,
from_team=from_team,
to_team=to_team,
)
# Check if transfer is approved
if 'TRANSFER APPROVED' in verification_result:
logger.info(f'Transfer initiated: {player_name} from {from_team} to {to_team}')
# Add episode
await graphiti_client.add_episode(
name=f'Transfer {player_name}',
episode_body=f'{player_name} transferred from {from_team} to {to_team}',
source_description='Player Transfer',
reference_time=datetime.now(),
source=EpisodeType.message,
)
return f'Player {player_name} has been successfully transferred from {from_team} to {to_team}.'
else:
return f'Transfer denied: {verification_result}'
except Exception as e:
logger.error(f'Error in transfer_player: {str(e)}')
return 'An error occurred while transferring the player. Please try again later or contact support.'
# Main agent setup
tools = [get_team_roster, search_player_info, verify_transfer_conditions, transfer_player]
prompt = ChatPromptTemplate.from_messages(
[
(
'system',
"""You are an AI assistant for NBA team management. Your role is to help users manage team rosters, transfer players, and provide information about players and teams. Use the available tools to gather information and perform actions.
When transferring players, always verify the following before proceeding:
1. The player is currently on the roster of the 'from' team.
2. Both the 'from' and 'to' teams are valid NBA teams.
Use the get_team_roster and search_player_info tools to verify this information. Only proceed with the transfer if all conditions are met.
IMPORTANT: Only use the information retrieved from the tools. Do not make assumptions or use information that hasn't been explicitly provided by the tools or the user. If you're unsure about any information, use the appropriate tool to verify it.""",
),
MessagesPlaceholder(variable_name='chat_history'),
('human', '{input}'),
MessagesPlaceholder(variable_name='agent_scratchpad'),
]
)
llm = ChatOpenAI(temperature=0)
agent = create_openai_functions_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, handle_parsing_errors=True)
# Graph setup
workflow = StateGraph(MessagesState)
async def agent_node(state):
user_input = state['messages'][-1].content if state['messages'] else ''
chat_history = state['messages'][:-1] # All messages except the last one
result = await agent_executor.ainvoke({'input': user_input, 'chat_history': chat_history})
return {
'messages': state['messages'] + [AIMessage(content=result['output'], name='Manager')],
'agent_scratchpad': [],
}
workflow.add_node('agent', agent_node)
workflow.set_entry_point('agent')
workflow.add_edge('agent', '__end__')
app = workflow.compile()
# Run function
async def run_workflow(input_text: str, chat_history: List[Dict[str, Any]] = []):
result = await app.ainvoke(
{
'messages': [
*[
HumanMessage(content=msg['content'])
if msg['type'] == 'human'
else AIMessage(content=msg['content'])
for msg in chat_history
],
HumanMessage(content=input_text),
],
},
)
# Log only the latest human input and AI response
logger.info(f'Human: {input_text}\n')
latest_ai_message = next(
(message for message in reversed(result['messages']) if isinstance(message, AIMessage)),
None,
)
if latest_ai_message:
logger.info(f'AI: {latest_ai_message.content}\n')
return result['messages']
# Main loop
async def main():
chat_history = []
while True:
user_input = input("Enter your request (or 'quit' to exit): ")
if user_input.lower() == 'quit':
break
messages = await run_workflow(user_input, chat_history)
# Update chat history with only the latest human input and AI response
chat_history = [
{'type': 'human', 'content': user_input},
{
'type': 'ai',
'content': messages[-1].content
if isinstance(messages[-1], AIMessage)
else messages[-2].content,
},
]
if __name__ == '__main__':
import asyncio
asyncio.run(main())

21
examples/nba/poetry.toml Normal file
View file

@ -0,0 +1,21 @@
[tool.poetry]
name = "graphiti-nba-example"
version = "0.1.0"
description = "NBA roster management example using Graphiti and LangGraph"
authors = ["Your Name <your.email@example.com>"]
[tool.poetry.dependencies]
python = "^3.10"
graphiti-core = { path = "../..", develop = true }
python-dotenv = "^1.0.0"
nba-api = "^1.5.0"
langgraph = "^0.2.14"
langchain = "^0.2.14"
langchain-openai = "^0.1.22"
[tool.poetry.dev-dependencies]
pytest = "^7.3.1"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

35
examples/nba/runner.py Normal file
View file

@ -0,0 +1,35 @@
if __name__ == '__main__':
# {'id': 2544, 'full_name': 'LeBron James', 'first_name': 'LeBron', 'last_name': 'James', 'is_active': True}
all_teams = teams.get_teams()
all_players = players.get_players()
players_json = []
for t in all_teams:
name = t['full_name']
print(name)
if name == 'Golden State Warriors' or name == 'Boston Celtics' or name == 'Toronto Raptors':
roster = commonteamroster.CommonTeamRoster(team_id=t['id']).get_dict()
players_data = roster['resultSets'][0]
headers = players_data['headers']
row_set = players_data['rowSet']
players_json = []
for row in row_set:
player_dict = dict(zip(headers, row))
player_dict['team_name'] = name
print(player_dict)
meaningful_data = {
'team_name': name,
'player_id': player_dict['PLAYER_ID'],
'player_name': player_dict['PLAYER'],
'player_number': player_dict['NUM'],
'player_position': player_dict['POSITION'],
'player_school': player_dict['SCHOOL'],
}
players_json.append(meaningful_data)
print(len(players_json))
players_json.extend(players_json)
script_dir = Path(__file__).parent
filename = script_dir / 'current_nba_roster.json'
print(players_json)
with open(filename, 'w') as f:
json.dump(players_json, f, indent=2)