feat: Implement first draft of nba agent

This commit is contained in:
paulpaliychuk 2024-08-30 14:03:26 -04:00
parent f1449ac69a
commit 699b815b19
3 changed files with 204 additions and 50 deletions

View file

@ -169,6 +169,186 @@
"player_id": 1630311,
"player_name": "Pat Spencer"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1641720,
"player_name": "Jalen Hood-Schifino"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1626156,
"player_name": "D'Angelo Russell"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1629020,
"player_name": "Jarred Vanderbilt"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 203076,
"player_name": "Anthony Davis"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1630219,
"player_name": "Skylar Mays"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1629629,
"player_name": "Cam Reddish"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1629216,
"player_name": "Gabe Vincent"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1631108,
"player_name": "Max Christie"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1629637,
"player_name": "Jaxson Hayes"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1627752,
"player_name": "Taurean Prince"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1630658,
"player_name": "Colin Castleton"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1630559,
"player_name": "Austin Reaves"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1628385,
"player_name": "Harry Giles III"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1641721,
"player_name": "Maxwell Lewis"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 2544,
"player_name": "LeBron James"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 203915,
"player_name": "Spencer Dinwiddie"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1629060,
"player_name": "Rui Hachimura"
},
{
"team_name": "Los Angeles Lakers",
"player_id": 1626174,
"player_name": "Christian Wood"
},
{
"team_name": "Miami Heat",
"player_id": 1626196,
"player_name": "Josh Richardson"
},
{
"team_name": "Miami Heat",
"player_id": 1626179,
"player_name": "Terry Rozier"
},
{
"team_name": "Miami Heat",
"player_id": 1626153,
"player_name": "Delon Wright"
},
{
"team_name": "Miami Heat",
"player_id": 1631107,
"player_name": "Nikola Jovic"
},
{
"team_name": "Miami Heat",
"player_id": 1631288,
"player_name": "Jamal Cain"
},
{
"team_name": "Miami Heat",
"player_id": 201988,
"player_name": "Patty Mills"
},
{
"team_name": "Miami Heat",
"player_id": 1631170,
"player_name": "Jaime Jaquez Jr."
},
{
"team_name": "Miami Heat",
"player_id": 1628389,
"player_name": "Bam Adebayo"
},
{
"team_name": "Miami Heat",
"player_id": 1629639,
"player_name": "Tyler Herro"
},
{
"team_name": "Miami Heat",
"player_id": 1631214,
"player_name": "Alondes Williams"
},
{
"team_name": "Miami Heat",
"player_id": 1628997,
"player_name": "Caleb Martin"
},
{
"team_name": "Miami Heat",
"player_id": 1631306,
"player_name": "Cole Swider"
},
{
"team_name": "Miami Heat",
"player_id": 202710,
"player_name": "Jimmy Butler"
},
{
"team_name": "Miami Heat",
"player_id": 1629312,
"player_name": "Haywood Highsmith"
},
{
"team_name": "Miami Heat",
"player_id": 1631115,
"player_name": "Orlando Robinson"
},
{
"team_name": "Miami Heat",
"player_id": 1628418,
"player_name": "Thomas Bryant"
},
{
"team_name": "Miami Heat",
"player_id": 201567,
"player_name": "Kevin Love"
},
{
"team_name": "Miami Heat",
"player_id": 1629130,
"player_name": "Duncan Robinson"
},
{
"team_name": "Toronto Raptors",
"player_id": 1642013,

View file

@ -82,6 +82,7 @@ def fetch_current_roster():
or name == 'Boston Celtics'
or name == 'Toronto Raptors'
or name == 'Los Angeles Lakers'
or name == 'Miami Heat'
):
roster = commonteamroster.CommonTeamRoster(team_id=t['id']).get_dict()
players_data = roster['resultSets'][0]
@ -111,7 +112,7 @@ def fetch_current_roster():
async def main():
# fetch_current_roster()
fetch_current_roster()
current_roster_from_file: list[PlayerInfo] = []
script_dir = Path(__file__).parent
filename = script_dir / 'current_nba_roster.json'

View file

@ -21,7 +21,13 @@ from graphiti_core.nodes import EpisodeType
logging.getLogger('langchain.callbacks.tracers.langchain').setLevel(logging.WARNING)
logging.getLogger('urllib3.connectionpool').setLevel(logging.ERROR)
DEFAULT_MODEL = 'gpt-4o-mini'
VALID_TEAMS = [
'Toronto Raptors',
'Boston Celtics',
'Golden State Warriors',
'Miami Heat',
'Los Angeles Lakers',
]
load_dotenv()
logging.basicConfig(
level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
@ -182,12 +188,6 @@ async def search_player_info(player_name: str):
return {'name': player_name, 'facts': relevant_facts}
@tool
async def propose_transfer(player_name: str, from_team: str, to_team: str, proposed_price: int):
"""Propose a player transfer from one team to another with a proposed price."""
return f'TRANSFER PROPOSAL: {to_team} wants to buy {player_name} from {from_team} for ${proposed_price:,}.'
@tool
async def execute_transfer(
player_name: str, from_team: str, to_team: str, price: int
@ -236,7 +236,7 @@ tools = [
# Define the team agent function
def create_team_agent(team_name: str, valid_teams: List[str]):
def create_team_agent(team_name: str):
llm = ChatOpenAI(temperature=0.3, model=DEFAULT_MODEL).bind(
response_format={'type': 'json_object'}
)
@ -245,7 +245,7 @@ def create_team_agent(team_name: str, valid_teams: List[str]):
Current event: {event}
Your task is to decide on an action based on the event. Use the available tools to gather information, but focus on making a decision quickly. If you think a player transfer would benefit your team, propose one following the guidelines below.
Ensure that you use the current budget info and the current state of your team to make the best decision.
Ensure that you use the current budget info and the current state of your team (use an appropriate tool to get the current state of your team) to make the best decision.
Current budget: ${budget}
Valid teams for transfers: {valid_teams}
@ -280,7 +280,7 @@ Do not ask for more information or clarification. Make a decision based on what
'team_name': team_name,
'event': state['event'],
'budget': team_data['budget'],
'valid_teams': ', '.join(valid_teams),
'valid_teams': ', '.join(VALID_TEAMS),
}
)
@ -289,8 +289,8 @@ Do not ask for more information or clarification. Make a decision based on what
if 'transfer_proposal' in json_result:
transfer_offer = json_result['transfer_proposal']
if (
transfer_offer['to_team'] not in valid_teams
or transfer_offer['from_team'] not in valid_teams
transfer_offer['to_team'] not in VALID_TEAMS
or transfer_offer['from_team'] not in VALID_TEAMS
):
logger.warning(f'Invalid transfer proposal: {transfer_offer}. Ignoring.')
transfer_offer = None
@ -302,29 +302,6 @@ Do not ask for more information or clarification. Make a decision based on what
return team_agent_function
def parse_transfer_proposal(proposal: str) -> Dict[str, Any]:
# Use regex to extract information
to_team_match = re.search(r'(.*?) wants to buy', proposal)
player_match = re.search(r'buy (.*?) from', proposal)
from_team_match = re.search(r'from (.*?) for', proposal)
price_match = re.search(r'\$([0-9,]+)', proposal)
if not all([to_team_match, player_match, from_team_match, price_match]):
raise ValueError(f'Unable to parse transfer proposal: {proposal}')
to_team = to_team_match.group(1)
player_name = player_match.group(1)
from_team = from_team_match.group(1)
proposed_price = int(price_match.group(1).replace(',', ''))
return {
'to_team': to_team,
'from_team': from_team,
'player_name': player_name,
'proposed_price': proposed_price,
}
async def process_event(state: SimulationState) -> SimulationState:
# await add_episode(state['event'])
return {
@ -407,8 +384,7 @@ simulator_prompt, simulator_llm = create_simulator_agent()
async def simulate_event(state: SimulationState) -> SimulationState:
teams = ['Toronto Raptors', 'Boston Celtics', 'Golden State Warriors']
teams_context = await fetch_all_teams_context.ainvoke({'teams': teams})
teams_context = await fetch_all_teams_context.ainvoke({'teams': VALID_TEAMS})
result = await simulator_llm.ainvoke(
simulator_prompt.format_prompt(teams_context=json.dumps(teams_context, indent=2))
@ -432,9 +408,8 @@ workflow = StateGraph(SimulationState)
# Add nodes
workflow.add_node('simulate_event', simulate_event)
workflow.add_node('process_event', process_event)
valid_teams = ['Toronto Raptors', 'Boston Celtics', 'Golden State Warriors']
for team in valid_teams:
workflow.add_node(f'agent_{team}', create_team_agent(team, valid_teams))
for team in VALID_TEAMS:
workflow.add_node(f'agent_{team}', create_team_agent(team))
workflow.add_node('process_transfers', process_transfers)
# Add edges
@ -442,10 +417,10 @@ workflow.add_edge(START, 'simulate_event')
workflow.add_edge('simulate_event', 'process_event')
# Add edges from process_event to all agent nodes
for team in valid_teams:
for team in VALID_TEAMS:
workflow.add_edge('process_event', f'agent_{team}')
for team in valid_teams:
for team in VALID_TEAMS:
workflow.add_edge(f'agent_{team}', 'process_transfers')
@ -469,14 +444,12 @@ print(app.get_graph().draw_mermaid())
async def run_simulation():
num_iterations = int(input('Enter the number of simulation iterations: '))
teams = {}
for t in VALID_TEAMS:
teams[t] = {'budget': 100_000_000}
initial_state = SimulationState(
messages=[],
teams={
'Toronto Raptors': {'budget': 100000000},
'Boston Celtics': {'budget': 100000000},
'Golden State Warriors': {'budget': 100000000},
},
teams=teams,
event='',
transfer_offers=[],
current_iteration=0,
@ -490,7 +463,7 @@ async def run_simulation():
print(f"{team_name} - Budget: ${team_data['budget']:,}")
print(f'Steps taken: {final_state["current_iteration"]}')
for event in final_state['all_events']:
print('/n')
print('\n')
print(event)
print('\n')