import logging import os from datetime import datetime from typing import Any, Dict, List from dotenv import load_dotenv from langchain.agents import AgentExecutor, create_openai_functions_agent from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.tools import tool from langchain_openai import ChatOpenAI from langgraph.graph import MessagesState, StateGraph from langgraph.prebuilt import ToolInvocation, ToolNode from graphiti_core import Graphiti from graphiti_core.nodes import EpisodeType load_dotenv() logging.basicConfig(level=logging.INFO, format='%(message)s') logger = logging.getLogger('nba_agent') for name in logging.root.manager.loggerDict: if name != 'nba_agent': logging.getLogger(name).setLevel(logging.WARNING) # Initialize Graphiti client neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687') neo4j_user = os.environ.get('NEO4J_USER', 'neo4j') neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password') graphiti_client = Graphiti(neo4j_uri, neo4j_user, neo4j_password) async def invoke_tool(tool_name: str, **kwargs): tool = next(t for t in tools if t.name == tool_name) return await tool.ainvoke(input=kwargs) def get_fact_string(edge): return f'{edge.fact} {edge.valid_at or edge.created_at}' @tool async def get_team_roster(team_name: str): """Get the current roster for a specific team.""" search_result = await graphiti_client.search(f'{team_name.lower()}', num_results=30) roster = [] for edge in search_result: roster.append(get_fact_string(edge)) return roster @tool async def search_player_info(player_name: str): """Search for information about a specific player.""" search_result = await graphiti_client.search(f'{player_name}', num_results=30) player_info = { 'name': player_name, 'facts': [get_fact_string(edge) for edge in search_result], } return player_info @tool async def verify_transfer_conditions(player_name: str, from_team: str, to_team: str): """Verify conditions for a player transfer.""" from_roster = await invoke_tool('get_team_roster', team_name=from_team) to_roster = await invoke_tool('get_team_roster', team_name=to_team) player_info = await invoke_tool('search_player_info', player_name=player_name) # Prepare context for LLM context = f""" Player: {player_name} From Team: {from_team} To Team: {to_team} From Team Roster: {from_roster} To Team Roster: {to_roster} Player Info: {player_info} """ # Use LLM to evaluate transfer conditions llm = ChatOpenAI(temperature=0) prompt = f""" Based on the following information, determine if the transfer conditions are met for {player_name} to move from {from_team} to {to_team}. Context: {context} Please consider the following conditions: 1. Is {from_team} a valid NBA team? 2. Is {to_team} a valid NBA team? 3. Is {player_name} currently on the roster of {from_team}? 4. Is there enough information about {player_name}? Important: Players can be transferred multiple times, including back to teams they've played for before. Provide a detailed analysis of each condition and conclude whether all conditions are met or not. Your response should end with one of these two statements: - TRANSFER APPROVED: All conditions are met. - TRANSFER DENIED: [Reason for denial] """ response = await llm.ainvoke(prompt) return response.content @tool async def transfer_player(player_name: str, from_team: str, to_team: str): """Transfer a player from one team to another.""" try: # Verify transfer conditions verification_result = await invoke_tool( 'verify_transfer_conditions', player_name=player_name, from_team=from_team, to_team=to_team, ) # Check if transfer is approved if 'TRANSFER APPROVED' in verification_result: logger.info(f'Transfer initiated: {player_name} from {from_team} to {to_team}') # Add episode await graphiti_client.add_episode( name=f'Transfer {player_name}', episode_body=f'{player_name} transferred from {from_team} to {to_team}', source_description='Player Transfer', reference_time=datetime.now(), source=EpisodeType.message, ) return f'Player {player_name} has been successfully transferred from {from_team} to {to_team}.' else: return f'Transfer denied: {verification_result}' except Exception as e: logger.error(f'Error in transfer_player: {str(e)}') return 'An error occurred while transferring the player. Please try again later or contact support.' # Main agent setup tools = [get_team_roster, search_player_info, verify_transfer_conditions, transfer_player] prompt = ChatPromptTemplate.from_messages( [ ( 'system', """You are an AI assistant for NBA team management. Your role is to help users manage team rosters, transfer players, and provide information about players and teams. Use the available tools to gather information and perform actions. When transferring players, always verify the following before proceeding: 1. The player is currently on the roster of the 'from' team. 2. Both the 'from' and 'to' teams are valid NBA teams. Use the get_team_roster and search_player_info tools to verify this information. Only proceed with the transfer if all conditions are met. IMPORTANT: Only use the information retrieved from the tools. Do not make assumptions or use information that hasn't been explicitly provided by the tools or the user. If you're unsure about any information, use the appropriate tool to verify it.""", ), MessagesPlaceholder(variable_name='chat_history'), ('human', '{input}'), MessagesPlaceholder(variable_name='agent_scratchpad'), ] ) llm = ChatOpenAI(temperature=0) agent = create_openai_functions_agent(llm, tools, prompt) agent_executor = AgentExecutor(agent=agent, tools=tools, handle_parsing_errors=True) # Graph setup workflow = StateGraph(MessagesState) async def agent_node(state): user_input = state['messages'][-1].content if state['messages'] else '' chat_history = state['messages'][:-1] # All messages except the last one result = await agent_executor.ainvoke({'input': user_input, 'chat_history': chat_history}) return { 'messages': state['messages'] + [AIMessage(content=result['output'], name='Manager')], 'agent_scratchpad': [], } workflow.add_node('agent', agent_node) workflow.set_entry_point('agent') workflow.add_edge('agent', '__end__') app = workflow.compile() # Run function async def run_workflow(input_text: str, chat_history: List[Dict[str, Any]] = []): result = await app.ainvoke( { 'messages': [ *[ HumanMessage(content=msg['content']) if msg['type'] == 'human' else AIMessage(content=msg['content']) for msg in chat_history ], HumanMessage(content=input_text), ], }, ) # Log only the latest human input and AI response logger.info(f'Human: {input_text}\n') latest_ai_message = next( (message for message in reversed(result['messages']) if isinstance(message, AIMessage)), None, ) if latest_ai_message: logger.info(f'AI: {latest_ai_message.content}\n') return result['messages'] # Main loop async def main(): chat_history = [] while True: user_input = input("Enter your request (or 'quit' to exit): ") if user_input.lower() == 'quit': break messages = await run_workflow(user_input, chat_history) # Update chat history with only the latest human input and AI response chat_history = [ {'type': 'human', 'content': user_input}, { 'type': 'ai', 'content': messages[-1].content if isinstance(messages[-1], AIMessage) else messages[-2].content, }, ] if __name__ == '__main__': import asyncio asyncio.run(main())