wip
This commit is contained in:
parent
8a72fe7cec
commit
c511a819d3
3 changed files with 196 additions and 171 deletions
|
|
@ -124,7 +124,7 @@ async def main():
|
|||
players_grouped_by_team[team_name] = []
|
||||
players_grouped_by_team[team_name].append(player)
|
||||
|
||||
for _, _ in players_grouped_by_team.items():
|
||||
for _, players in players_grouped_by_team.items():
|
||||
episodes: list[RawEpisode] = [
|
||||
RawEpisode(
|
||||
name=f'Player {player["player_id"]}',
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
|
@ -21,6 +22,7 @@ logger = logging.getLogger('nba_agent')
|
|||
for name in logging.root.manager.loggerDict:
|
||||
if name != 'nba_agent':
|
||||
logging.getLogger(name).setLevel(logging.WARNING)
|
||||
|
||||
# Initialize Graphiti client
|
||||
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
|
||||
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
|
||||
|
|
@ -28,6 +30,87 @@ neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
|
|||
graphiti_client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
|
||||
|
||||
|
||||
class TeamAgent:
|
||||
def __init__(self, team_name: str, tools: List[Any], budget: int = 100000000):
|
||||
self.team_name = team_name
|
||||
self.tools = tools
|
||||
self.budget = budget
|
||||
self.prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
'system',
|
||||
f"""You are an AI assistant managing the {team_name} NBA team.
|
||||
Your role is to make strategic decisions for your team, react to events, and interact with other teams.
|
||||
Use the available tools to gather information and perform actions.
|
||||
When an event occurs, decide how to react. You can:
|
||||
1. Use tools to gather more information about players or team situations.
|
||||
2. Propose player transfers (buy or sell) based on events and your team's needs.
|
||||
3. Set transfer prices based on player performance and your available budget.
|
||||
4. Negotiate with other teams on transfer prices.
|
||||
Always consider what's best for your team in the long term. Be strategic and competitive.""",
|
||||
),
|
||||
('human', '{input}'),
|
||||
MessagesPlaceholder(variable_name='agent_scratchpad'),
|
||||
]
|
||||
)
|
||||
self.llm = ChatOpenAI(temperature=0.2)
|
||||
self.agent = create_openai_functions_agent(self.llm, self.tools, self.prompt)
|
||||
self.executor = AgentExecutor(agent=self.agent, tools=self.tools, verbose=True)
|
||||
|
||||
async def update_budget(self, amount: int):
|
||||
self.budget += amount
|
||||
return f"{self.team_name}'s new budget: ${self.budget:,}"
|
||||
|
||||
async def process_event(self, event: str):
|
||||
result = await self.executor.ainvoke(
|
||||
{
|
||||
'input': f'Event: {event}\n\nCurrent team: {self.team_name}\nCurrent budget: ${self.budget:,}\n\nReact to this event. If there are transfer proposals, consider them and respond appropriately. Make decisions and take actions without asking for confirmation. Ensure transfer prices are realistic (in millions of dollars).',
|
||||
'agent_scratchpad': [],
|
||||
}
|
||||
)
|
||||
return result['output']
|
||||
|
||||
async def handle_tool_use(self, response):
|
||||
if 'Action:' in response and 'Action Input:' in response:
|
||||
action = response.split('Action:')[1].split('Action Input:')[0].strip()
|
||||
action_input = response.split('Action Input:')[1].strip()
|
||||
try:
|
||||
tool = next(t for t in self.tools if t.name.lower() == action.lower())
|
||||
result = await tool.ainvoke(**eval(action_input))
|
||||
return f'Tool execution result: {result}'
|
||||
except Exception as e:
|
||||
return f'Error executing tool {action}: {e}'
|
||||
return None
|
||||
|
||||
|
||||
class AgentManager:
|
||||
def __init__(self):
|
||||
self.agents = {}
|
||||
|
||||
def add_agent(self, team_name: str, budget: int):
|
||||
if team_name not in self.agents:
|
||||
self.agents[team_name] = TeamAgent(team_name, tools, budget)
|
||||
|
||||
async def process_event(self, event: str):
|
||||
responses = []
|
||||
for team_name, agent in self.agents.items():
|
||||
response = await agent.process_event(event)
|
||||
responses.append(f'{team_name}: {response}')
|
||||
return responses
|
||||
|
||||
|
||||
async def add_episode(event_description: str):
|
||||
"""Add a new episode to the Graphiti client."""
|
||||
result = await graphiti_client.add_episode(
|
||||
name='New Event',
|
||||
episode_body=event_description,
|
||||
source_description='User Input',
|
||||
reference_time=datetime.now(),
|
||||
source=EpisodeType.message,
|
||||
)
|
||||
return f"Episode '{event_description}' added successfully."
|
||||
|
||||
|
||||
async def invoke_tool(tool_name: str, **kwargs):
|
||||
tool = next(t for t in tools if t.name == tool_name)
|
||||
return await tool.ainvoke(input=kwargs)
|
||||
|
|
@ -40,11 +123,13 @@ def get_fact_string(edge):
|
|||
@tool
|
||||
async def get_team_roster(team_name: str):
|
||||
"""Get the current roster for a specific team."""
|
||||
search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=30)
|
||||
roster = []
|
||||
for edge in search_result:
|
||||
roster.append(get_fact_string(edge))
|
||||
return roster
|
||||
search_result = await graphiti_client.search(f'plays for {team_name}', num_results=30)
|
||||
roster = [
|
||||
edge.fact.split(' plays for ')[0]
|
||||
for edge in search_result
|
||||
if 'plays for' in edge.fact.lower()
|
||||
]
|
||||
return f"{team_name}'s roster: {', '.join(roster)}"
|
||||
|
||||
|
||||
@tool
|
||||
|
|
@ -59,186 +144,126 @@ async def search_player_info(player_name: str):
|
|||
|
||||
|
||||
@tool
|
||||
async def verify_transfer_conditions(player_name: str, from_team: str, to_team: str):
|
||||
"""Verify conditions for a player transfer."""
|
||||
from_roster = await invoke_tool('get_team_roster', team_name=from_team)
|
||||
to_roster = await invoke_tool('get_team_roster', team_name=to_team)
|
||||
player_info = await invoke_tool('search_player_info', player_name=player_name)
|
||||
|
||||
# Prepare context for LLM
|
||||
context = f"""
|
||||
Player: {player_name}
|
||||
From Team: {from_team}
|
||||
To Team: {to_team}
|
||||
|
||||
From Team Roster:
|
||||
{from_roster}
|
||||
|
||||
To Team Roster:
|
||||
{to_roster}
|
||||
|
||||
Player Info:
|
||||
{player_info}
|
||||
"""
|
||||
|
||||
# Use LLM to evaluate transfer conditions
|
||||
llm = ChatOpenAI(temperature=0)
|
||||
prompt = f"""
|
||||
Based on the following information, determine if the transfer conditions are met for {player_name} to move from {from_team} to {to_team}.
|
||||
|
||||
Context:
|
||||
{context}
|
||||
|
||||
Please consider the following conditions:
|
||||
1. Is {from_team} a valid NBA team?
|
||||
2. Is {to_team} a valid NBA team?
|
||||
3. Is {player_name} currently on the roster of {from_team}?
|
||||
4. Is there enough information about {player_name}?
|
||||
|
||||
Important: Players can be transferred multiple times, including back to teams they've played for before.
|
||||
|
||||
Provide a detailed analysis of each condition and conclude whether all conditions are met or not.
|
||||
Your response should end with one of these two statements:
|
||||
- TRANSFER APPROVED: All conditions are met.
|
||||
- TRANSFER DENIED: [Reason for denial]
|
||||
"""
|
||||
|
||||
response = await llm.ainvoke(prompt)
|
||||
|
||||
return response.content
|
||||
async def propose_transfer(player_name: str, from_team: str, to_team: str, proposed_price: int):
|
||||
"""Propose a player transfer from one team to another with a proposed price."""
|
||||
return f'Transfer proposal: {to_team} wants to buy {player_name} from {from_team} for ${proposed_price:,}.'
|
||||
|
||||
|
||||
@tool
|
||||
async def transfer_player(player_name: str, from_team: str, to_team: str):
|
||||
"""Transfer a player from one team to another."""
|
||||
try:
|
||||
# Verify transfer conditions
|
||||
verification_result = await invoke_tool(
|
||||
'verify_transfer_conditions',
|
||||
player_name=player_name,
|
||||
from_team=from_team,
|
||||
to_team=to_team,
|
||||
)
|
||||
|
||||
# Check if transfer is approved
|
||||
if 'TRANSFER APPROVED' in verification_result:
|
||||
logger.info(f'Transfer initiated: {player_name} from {from_team} to {to_team}')
|
||||
|
||||
# Add episode
|
||||
await graphiti_client.add_episode(
|
||||
name=f'Transfer {player_name}',
|
||||
episode_body=f'{player_name} transferred from {from_team} to {to_team}',
|
||||
source_description='Player Transfer',
|
||||
reference_time=datetime.now(),
|
||||
source=EpisodeType.message,
|
||||
)
|
||||
|
||||
return f'Player {player_name} has been successfully transferred from {from_team} to {to_team}.'
|
||||
else:
|
||||
return f'Transfer denied: {verification_result}'
|
||||
except Exception as e:
|
||||
logger.error(f'Error in transfer_player: {str(e)}')
|
||||
return 'An error occurred while transferring the player. Please try again later or contact support.'
|
||||
async def respond_to_transfer(
|
||||
player_name: str, from_team: str, to_team: str, response: str, counter_offer: int = None
|
||||
):
|
||||
"""Respond to a transfer proposal with an accept, reject, or counter-offer."""
|
||||
response_message = f'{from_team} {response}s the transfer of {player_name} to {to_team}'
|
||||
if counter_offer:
|
||||
response_message += f' with a counter-offer of ${counter_offer:,}'
|
||||
return f'Transfer response: {response_message}.'
|
||||
|
||||
|
||||
# Main agent setup
|
||||
tools = [get_team_roster, search_player_info, verify_transfer_conditions, transfer_player]
|
||||
@tool
|
||||
async def execute_transfer(player_name: str, from_team: str, to_team: str, final_price: int):
|
||||
"""Execute a player transfer from one team to another with the final agreed price."""
|
||||
from_agent = agent_manager.agents.get(from_team)
|
||||
to_agent = agent_manager.agents.get(to_team)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
'system',
|
||||
"""You are an AI assistant for NBA team management. Your role is to help users manage team rosters, transfer players, and provide information about players and teams. Use the available tools to gather information and perform actions.
|
||||
if not from_agent or not to_agent:
|
||||
return 'One or both teams not found.'
|
||||
|
||||
When transferring players, always verify the following before proceeding:
|
||||
1. The player is currently on the roster of the 'from' team.
|
||||
2. Both the 'from' and 'to' teams are valid NBA teams.
|
||||
|
||||
Use the get_team_roster and search_player_info tools to verify this information. Only proceed with the transfer if all conditions are met.
|
||||
if to_agent.budget < final_price:
|
||||
return f"{to_team} doesn't have enough budget for this transfer."
|
||||
|
||||
IMPORTANT: Only use the information retrieved from the tools. Do not make assumptions or use information that hasn't been explicitly provided by the tools or the user. If you're unsure about any information, use the appropriate tool to verify it.""",
|
||||
),
|
||||
MessagesPlaceholder(variable_name='chat_history'),
|
||||
('human', '{input}'),
|
||||
MessagesPlaceholder(variable_name='agent_scratchpad'),
|
||||
]
|
||||
)
|
||||
# Update budgets
|
||||
await from_agent.update_budget(final_price)
|
||||
await to_agent.update_budget(-final_price)
|
||||
|
||||
llm = ChatOpenAI(temperature=0)
|
||||
agent = create_openai_functions_agent(llm, tools, prompt)
|
||||
agent_executor = AgentExecutor(agent=agent, tools=tools, handle_parsing_errors=True)
|
||||
|
||||
# Graph setup
|
||||
workflow = StateGraph(MessagesState)
|
||||
|
||||
|
||||
async def agent_node(state):
|
||||
user_input = state['messages'][-1].content if state['messages'] else ''
|
||||
chat_history = state['messages'][:-1] # All messages except the last one
|
||||
result = await agent_executor.ainvoke({'input': user_input, 'chat_history': chat_history})
|
||||
return {
|
||||
'messages': state['messages'] + [AIMessage(content=result['output'], name='Manager')],
|
||||
'agent_scratchpad': [],
|
||||
}
|
||||
|
||||
|
||||
workflow.add_node('agent', agent_node)
|
||||
workflow.set_entry_point('agent')
|
||||
workflow.add_edge('agent', '__end__')
|
||||
|
||||
app = workflow.compile()
|
||||
|
||||
|
||||
# Run function
|
||||
async def run_workflow(input_text: str, chat_history: List[Dict[str, Any]] = []):
|
||||
result = await app.ainvoke(
|
||||
{
|
||||
'messages': [
|
||||
*[
|
||||
HumanMessage(content=msg['content'])
|
||||
if msg['type'] == 'human'
|
||||
else AIMessage(content=msg['content'])
|
||||
for msg in chat_history
|
||||
],
|
||||
HumanMessage(content=input_text),
|
||||
],
|
||||
},
|
||||
# Add the transfer as an episode
|
||||
await add_episode(
|
||||
event_description=f'{player_name} transferred from {from_team} to {to_team} for ${final_price:,}'
|
||||
)
|
||||
return f'Transfer executed: {player_name} has been transferred from {from_team} to {to_team} for ${final_price:,}.'
|
||||
|
||||
# Log only the latest human input and AI response
|
||||
logger.info(f'Human: {input_text}\n')
|
||||
|
||||
latest_ai_message = next(
|
||||
(message for message in reversed(result['messages']) if isinstance(message, AIMessage)),
|
||||
None,
|
||||
)
|
||||
if latest_ai_message:
|
||||
logger.info(f'AI: {latest_ai_message.content}\n')
|
||||
@tool
|
||||
async def check_team_budget(team_name: str):
|
||||
"""Check the current budget of a team."""
|
||||
agent = agent_manager.agents.get(team_name)
|
||||
if agent:
|
||||
return f"{team_name}'s current budget: ${agent.budget:,}"
|
||||
return f'Team {team_name} not found.'
|
||||
|
||||
return result['messages']
|
||||
|
||||
# Update the tools list
|
||||
tools = [
|
||||
get_team_roster,
|
||||
search_player_info,
|
||||
propose_transfer,
|
||||
respond_to_transfer,
|
||||
execute_transfer,
|
||||
check_team_budget,
|
||||
]
|
||||
|
||||
agent_manager = AgentManager()
|
||||
|
||||
# Add your teams here
|
||||
agent_manager.add_agent('Toronto Raptors', budget=100000000)
|
||||
agent_manager.add_agent('Boston Celtics', budget=100000000)
|
||||
agent_manager.add_agent('Golden State Warriors', budget=100000000)
|
||||
|
||||
|
||||
# Main loop
|
||||
async def main():
|
||||
chat_history = []
|
||||
print('Welcome to the NBA Team Management Simulation!')
|
||||
print('Enter events, and watch how the teams react.')
|
||||
print("Type 'quit' to exit the simulation.\n")
|
||||
|
||||
while True:
|
||||
user_input = input("Enter your request (or 'quit' to exit): ")
|
||||
user_input = input("Enter an event (or 'quit' to exit): ")
|
||||
if user_input.lower() == 'quit':
|
||||
break
|
||||
messages = await run_workflow(user_input, chat_history)
|
||||
# Update chat history with only the latest human input and AI response
|
||||
chat_history = [
|
||||
{'type': 'human', 'content': user_input},
|
||||
{
|
||||
'type': 'ai',
|
||||
'content': messages[-1].content
|
||||
if isinstance(messages[-1], AIMessage)
|
||||
else messages[-2].content,
|
||||
},
|
||||
]
|
||||
|
||||
try:
|
||||
# Add the event as an episode
|
||||
result = await add_episode(event_description=user_input)
|
||||
print(result)
|
||||
except Exception as e:
|
||||
print(f'Error adding episode: {e}')
|
||||
|
||||
# Process the event with all team agents for multiple rounds
|
||||
transfer_proposals = []
|
||||
for round in range(3): # You can adjust the number of rounds as needed
|
||||
print(f'\nRound {round + 1}:')
|
||||
for team_name, agent in agent_manager.agents.items():
|
||||
print(f'\n{team_name} reaction:')
|
||||
try:
|
||||
response = await agent.process_event(user_input)
|
||||
print(response)
|
||||
|
||||
# Handle tool use
|
||||
tool_result = await agent.handle_tool_use(response)
|
||||
if tool_result:
|
||||
print(tool_result)
|
||||
|
||||
# If a transfer was proposed or responded to, add it to the list
|
||||
if (
|
||||
'Transfer proposal:' in tool_result
|
||||
or 'Transfer response:' in tool_result
|
||||
):
|
||||
transfer_proposals.append(tool_result)
|
||||
|
||||
except Exception as e:
|
||||
print(f'Error processing event for {team_name}: {e}')
|
||||
|
||||
# After each round, update the user_input to include transfer proposals
|
||||
if transfer_proposals:
|
||||
user_input = (
|
||||
f'Previous event: {user_input}\nTransfer proposals and responses:\n'
|
||||
+ '\n'.join(transfer_proposals)
|
||||
)
|
||||
else:
|
||||
break # If no new proposals or responses, end the rounds
|
||||
|
||||
print('\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -63,12 +63,12 @@ class SearchResults(BaseModel):
|
|||
|
||||
|
||||
async def hybrid_search(
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
query: str,
|
||||
timestamp: datetime,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
driver: AsyncDriver,
|
||||
embedder,
|
||||
query: str,
|
||||
timestamp: datetime,
|
||||
config: SearchConfig,
|
||||
center_node_uuid: str | None = None,
|
||||
) -> SearchResults:
|
||||
start = time()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue