This commit is contained in:
paulpaliychuk 2024-08-28 15:17:51 -04:00
parent 8a72fe7cec
commit c511a819d3
3 changed files with 196 additions and 171 deletions

View file

@ -124,7 +124,7 @@ async def main():
players_grouped_by_team[team_name] = []
players_grouped_by_team[team_name].append(player)
for _, _ in players_grouped_by_team.items():
for _, players in players_grouped_by_team.items():
episodes: list[RawEpisode] = [
RawEpisode(
name=f'Player {player["player_id"]}',

View file

@ -1,3 +1,4 @@
import asyncio
import logging
import os
from datetime import datetime
@ -21,6 +22,7 @@ logger = logging.getLogger('nba_agent')
for name in logging.root.manager.loggerDict:
if name != 'nba_agent':
logging.getLogger(name).setLevel(logging.WARNING)
# Initialize Graphiti client
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
@ -28,6 +30,87 @@ neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
graphiti_client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
class TeamAgent:
def __init__(self, team_name: str, tools: List[Any], budget: int = 100000000):
self.team_name = team_name
self.tools = tools
self.budget = budget
self.prompt = ChatPromptTemplate.from_messages(
[
(
'system',
f"""You are an AI assistant managing the {team_name} NBA team.
Your role is to make strategic decisions for your team, react to events, and interact with other teams.
Use the available tools to gather information and perform actions.
When an event occurs, decide how to react. You can:
1. Use tools to gather more information about players or team situations.
2. Propose player transfers (buy or sell) based on events and your team's needs.
3. Set transfer prices based on player performance and your available budget.
4. Negotiate with other teams on transfer prices.
Always consider what's best for your team in the long term. Be strategic and competitive.""",
),
('human', '{input}'),
MessagesPlaceholder(variable_name='agent_scratchpad'),
]
)
self.llm = ChatOpenAI(temperature=0.2)
self.agent = create_openai_functions_agent(self.llm, self.tools, self.prompt)
self.executor = AgentExecutor(agent=self.agent, tools=self.tools, verbose=True)
async def update_budget(self, amount: int):
self.budget += amount
return f"{self.team_name}'s new budget: ${self.budget:,}"
async def process_event(self, event: str):
result = await self.executor.ainvoke(
{
'input': f'Event: {event}\n\nCurrent team: {self.team_name}\nCurrent budget: ${self.budget:,}\n\nReact to this event. If there are transfer proposals, consider them and respond appropriately. Make decisions and take actions without asking for confirmation. Ensure transfer prices are realistic (in millions of dollars).',
'agent_scratchpad': [],
}
)
return result['output']
async def handle_tool_use(self, response):
if 'Action:' in response and 'Action Input:' in response:
action = response.split('Action:')[1].split('Action Input:')[0].strip()
action_input = response.split('Action Input:')[1].strip()
try:
tool = next(t for t in self.tools if t.name.lower() == action.lower())
result = await tool.ainvoke(**eval(action_input))
return f'Tool execution result: {result}'
except Exception as e:
return f'Error executing tool {action}: {e}'
return None
class AgentManager:
def __init__(self):
self.agents = {}
def add_agent(self, team_name: str, budget: int):
if team_name not in self.agents:
self.agents[team_name] = TeamAgent(team_name, tools, budget)
async def process_event(self, event: str):
responses = []
for team_name, agent in self.agents.items():
response = await agent.process_event(event)
responses.append(f'{team_name}: {response}')
return responses
async def add_episode(event_description: str):
"""Add a new episode to the Graphiti client."""
result = await graphiti_client.add_episode(
name='New Event',
episode_body=event_description,
source_description='User Input',
reference_time=datetime.now(),
source=EpisodeType.message,
)
return f"Episode '{event_description}' added successfully."
async def invoke_tool(tool_name: str, **kwargs):
tool = next(t for t in tools if t.name == tool_name)
return await tool.ainvoke(input=kwargs)
@ -40,11 +123,13 @@ def get_fact_string(edge):
@tool
async def get_team_roster(team_name: str):
"""Get the current roster for a specific team."""
search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=30)
roster = []
for edge in search_result:
roster.append(get_fact_string(edge))
return roster
search_result = await graphiti_client.search(f'plays for {team_name}', num_results=30)
roster = [
edge.fact.split(' plays for ')[0]
for edge in search_result
if 'plays for' in edge.fact.lower()
]
return f"{team_name}'s roster: {', '.join(roster)}"
@tool
@ -59,186 +144,126 @@ async def search_player_info(player_name: str):
@tool
async def verify_transfer_conditions(player_name: str, from_team: str, to_team: str):
"""Verify conditions for a player transfer."""
from_roster = await invoke_tool('get_team_roster', team_name=from_team)
to_roster = await invoke_tool('get_team_roster', team_name=to_team)
player_info = await invoke_tool('search_player_info', player_name=player_name)
# Prepare context for LLM
context = f"""
Player: {player_name}
From Team: {from_team}
To Team: {to_team}
From Team Roster:
{from_roster}
To Team Roster:
{to_roster}
Player Info:
{player_info}
"""
# Use LLM to evaluate transfer conditions
llm = ChatOpenAI(temperature=0)
prompt = f"""
Based on the following information, determine if the transfer conditions are met for {player_name} to move from {from_team} to {to_team}.
Context:
{context}
Please consider the following conditions:
1. Is {from_team} a valid NBA team?
2. Is {to_team} a valid NBA team?
3. Is {player_name} currently on the roster of {from_team}?
4. Is there enough information about {player_name}?
Important: Players can be transferred multiple times, including back to teams they've played for before.
Provide a detailed analysis of each condition and conclude whether all conditions are met or not.
Your response should end with one of these two statements:
- TRANSFER APPROVED: All conditions are met.
- TRANSFER DENIED: [Reason for denial]
"""
response = await llm.ainvoke(prompt)
return response.content
async def propose_transfer(player_name: str, from_team: str, to_team: str, proposed_price: int):
"""Propose a player transfer from one team to another with a proposed price."""
return f'Transfer proposal: {to_team} wants to buy {player_name} from {from_team} for ${proposed_price:,}.'
@tool
async def transfer_player(player_name: str, from_team: str, to_team: str):
"""Transfer a player from one team to another."""
try:
# Verify transfer conditions
verification_result = await invoke_tool(
'verify_transfer_conditions',
player_name=player_name,
from_team=from_team,
to_team=to_team,
)
# Check if transfer is approved
if 'TRANSFER APPROVED' in verification_result:
logger.info(f'Transfer initiated: {player_name} from {from_team} to {to_team}')
# Add episode
await graphiti_client.add_episode(
name=f'Transfer {player_name}',
episode_body=f'{player_name} transferred from {from_team} to {to_team}',
source_description='Player Transfer',
reference_time=datetime.now(),
source=EpisodeType.message,
)
return f'Player {player_name} has been successfully transferred from {from_team} to {to_team}.'
else:
return f'Transfer denied: {verification_result}'
except Exception as e:
logger.error(f'Error in transfer_player: {str(e)}')
return 'An error occurred while transferring the player. Please try again later or contact support.'
async def respond_to_transfer(
player_name: str, from_team: str, to_team: str, response: str, counter_offer: int = None
):
"""Respond to a transfer proposal with an accept, reject, or counter-offer."""
response_message = f'{from_team} {response}s the transfer of {player_name} to {to_team}'
if counter_offer:
response_message += f' with a counter-offer of ${counter_offer:,}'
return f'Transfer response: {response_message}.'
# Main agent setup
tools = [get_team_roster, search_player_info, verify_transfer_conditions, transfer_player]
@tool
async def execute_transfer(player_name: str, from_team: str, to_team: str, final_price: int):
"""Execute a player transfer from one team to another with the final agreed price."""
from_agent = agent_manager.agents.get(from_team)
to_agent = agent_manager.agents.get(to_team)
prompt = ChatPromptTemplate.from_messages(
[
(
'system',
"""You are an AI assistant for NBA team management. Your role is to help users manage team rosters, transfer players, and provide information about players and teams. Use the available tools to gather information and perform actions.
if not from_agent or not to_agent:
return 'One or both teams not found.'
When transferring players, always verify the following before proceeding:
1. The player is currently on the roster of the 'from' team.
2. Both the 'from' and 'to' teams are valid NBA teams.
Use the get_team_roster and search_player_info tools to verify this information. Only proceed with the transfer if all conditions are met.
if to_agent.budget < final_price:
return f"{to_team} doesn't have enough budget for this transfer."
IMPORTANT: Only use the information retrieved from the tools. Do not make assumptions or use information that hasn't been explicitly provided by the tools or the user. If you're unsure about any information, use the appropriate tool to verify it.""",
),
MessagesPlaceholder(variable_name='chat_history'),
('human', '{input}'),
MessagesPlaceholder(variable_name='agent_scratchpad'),
]
)
# Update budgets
await from_agent.update_budget(final_price)
await to_agent.update_budget(-final_price)
llm = ChatOpenAI(temperature=0)
agent = create_openai_functions_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, handle_parsing_errors=True)
# Graph setup
workflow = StateGraph(MessagesState)
async def agent_node(state):
user_input = state['messages'][-1].content if state['messages'] else ''
chat_history = state['messages'][:-1] # All messages except the last one
result = await agent_executor.ainvoke({'input': user_input, 'chat_history': chat_history})
return {
'messages': state['messages'] + [AIMessage(content=result['output'], name='Manager')],
'agent_scratchpad': [],
}
workflow.add_node('agent', agent_node)
workflow.set_entry_point('agent')
workflow.add_edge('agent', '__end__')
app = workflow.compile()
# Run function
async def run_workflow(input_text: str, chat_history: List[Dict[str, Any]] = []):
result = await app.ainvoke(
{
'messages': [
*[
HumanMessage(content=msg['content'])
if msg['type'] == 'human'
else AIMessage(content=msg['content'])
for msg in chat_history
],
HumanMessage(content=input_text),
],
},
# Add the transfer as an episode
await add_episode(
event_description=f'{player_name} transferred from {from_team} to {to_team} for ${final_price:,}'
)
return f'Transfer executed: {player_name} has been transferred from {from_team} to {to_team} for ${final_price:,}.'
# Log only the latest human input and AI response
logger.info(f'Human: {input_text}\n')
latest_ai_message = next(
(message for message in reversed(result['messages']) if isinstance(message, AIMessage)),
None,
)
if latest_ai_message:
logger.info(f'AI: {latest_ai_message.content}\n')
@tool
async def check_team_budget(team_name: str):
"""Check the current budget of a team."""
agent = agent_manager.agents.get(team_name)
if agent:
return f"{team_name}'s current budget: ${agent.budget:,}"
return f'Team {team_name} not found.'
return result['messages']
# Update the tools list
tools = [
get_team_roster,
search_player_info,
propose_transfer,
respond_to_transfer,
execute_transfer,
check_team_budget,
]
agent_manager = AgentManager()
# Add your teams here
agent_manager.add_agent('Toronto Raptors', budget=100000000)
agent_manager.add_agent('Boston Celtics', budget=100000000)
agent_manager.add_agent('Golden State Warriors', budget=100000000)
# Main loop
async def main():
chat_history = []
print('Welcome to the NBA Team Management Simulation!')
print('Enter events, and watch how the teams react.')
print("Type 'quit' to exit the simulation.\n")
while True:
user_input = input("Enter your request (or 'quit' to exit): ")
user_input = input("Enter an event (or 'quit' to exit): ")
if user_input.lower() == 'quit':
break
messages = await run_workflow(user_input, chat_history)
# Update chat history with only the latest human input and AI response
chat_history = [
{'type': 'human', 'content': user_input},
{
'type': 'ai',
'content': messages[-1].content
if isinstance(messages[-1], AIMessage)
else messages[-2].content,
},
]
try:
# Add the event as an episode
result = await add_episode(event_description=user_input)
print(result)
except Exception as e:
print(f'Error adding episode: {e}')
# Process the event with all team agents for multiple rounds
transfer_proposals = []
for round in range(3): # You can adjust the number of rounds as needed
print(f'\nRound {round + 1}:')
for team_name, agent in agent_manager.agents.items():
print(f'\n{team_name} reaction:')
try:
response = await agent.process_event(user_input)
print(response)
# Handle tool use
tool_result = await agent.handle_tool_use(response)
if tool_result:
print(tool_result)
# If a transfer was proposed or responded to, add it to the list
if (
'Transfer proposal:' in tool_result
or 'Transfer response:' in tool_result
):
transfer_proposals.append(tool_result)
except Exception as e:
print(f'Error processing event for {team_name}: {e}')
# After each round, update the user_input to include transfer proposals
if transfer_proposals:
user_input = (
f'Previous event: {user_input}\nTransfer proposals and responses:\n'
+ '\n'.join(transfer_proposals)
)
else:
break # If no new proposals or responses, end the rounds
print('\n')
if __name__ == '__main__':
import asyncio
asyncio.run(main())

View file

@ -63,12 +63,12 @@ class SearchResults(BaseModel):
async def hybrid_search(
driver: AsyncDriver,
embedder,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
driver: AsyncDriver,
embedder,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
) -> SearchResults:
start = time()