Chess game? Do I have to EndEpisode() after I SetReward? No examples of Chess with ML? :(

Discussion in 'ML-Agents' started by markashburner, Jan 23, 2021.

  1. markashburner
     Joined: Aug 14, 2015
     Posts: 212
    Hi, I'm trying to integrate ML-Agents with a chess game I've developed.

    I'm just not sure what I should do after every move the agent makes. After every move, should I SetReward and then EndEpisode? Or can I SetReward or AddReward after every move and only call EndEpisode at the end of the game, when it's either a win for the agent or a draw?

    Basically, what I'm asking is: do I have to call EndEpisode() after I SetReward?
     
  2. celion_unity
     Joined: Jun 12, 2019
     Posts: 289
    The only difference between AddReward and SetReward is what happens if you call them multiple times in the same step: AddReward accumulates, while SetReward overwrites the step's reward. So in general, it's fine to call SetReward on a step and not call EndEpisode that step.

    For zero-sum games like chess, you should treat each game as an episode, and give +1 reward to the winner and -1 to the loser (and then end the episode).

    You could also try giving small rewards (using AddReward) for capturing pieces, such as 0.01 for a pawn, 0.02 for a knight, etc. This wouldn't affect the Elo rating, but it might teach the agent that capturing pieces is valuable (it also might not help, or might actually hurt the agent in the long run).
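    In code, the zero-sum scheme described above might look something like this (a minimal sketch; `whiteAgent`, `blackAgent`, and the game-over callback are assumed names for your own project, not ML-Agents API):

    ```csharp
    using Unity.MLAgents;

    public class ChessRewards
    {
        // Assumed fields: the two agents playing on this board.
        public Agent whiteAgent;
        public Agent blackAgent;

        // Call this once, when the game actually ends -- not after every move.
        public void OnGameOver(bool whiteWon, bool isDraw)
        {
            if (isDraw)
            {
                whiteAgent.SetReward(0f);
                blackAgent.SetReward(0f);
            }
            else
            {
                whiteAgent.SetReward(whiteWon ? 1f : -1f);
                blackAgent.SetReward(whiteWon ? -1f : 1f);
            }
            // The episode ends only here, at the end of the game.
            whiteAgent.EndEpisode();
            blackAgent.EndEpisode();
        }
    }
    ```

    SetReward is used for the result so that it overwrites any small shaping reward added earlier in that final step; during the game you'd use AddReward for the capture bonuses.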
     
  3. markashburner

    Yeah, that's roughly what I had. I used AddReward(0.1) for a pawn and so on for capturing pieces, and then AddReward(+1) for winning a game. I only ever called SetReward when penalising an agent for making the same move too many times. It worked OK, but the agent could never work out how to checkmate.

    So then I changed the capture reward to 0.01 for a pawn, etc. Honestly, I'm close to giving up. There's barely anyone to help out with machine learning, and I'm not getting any good results. I was getting decent results when I used AddReward only and never called SetReward; the agent was OK at taking pieces, but not great at the endgame and working out checkmate.

    Basically, all I did was gather all the legal moves the agent could make, give the agent one discrete branch with a branch size of 120, and then mask all the moves it couldn't make.
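    For reference, the masking step described here looks roughly like this with the `IDiscreteActionMask` API used elsewhere in this thread (a sketch; `legalMoves` and the branch size of 120 are taken from the post, and would come from your own move generator):

    ```csharp
    using System.Linq;
    using Unity.MLAgents.Actuators;

    public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
    {
        const int branchSize = 120;
        // Actions 0..legalMoves.Count-1 map to real moves this turn;
        // every index above that is illegal and gets masked off.
        var illegal = Enumerable.Range(legalMoves.Count, branchSize - legalMoves.Count);
        actionMask.WriteMask(0, illegal);
    }
    ```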

    To be honest, there isn't a single clear example of Unity machine learning with board games, which is ridiculous. I always thought the best way to learn machine learning was with board games, so I'm surprised Unity hasn't provided any examples, and there's virtually no one to help you out. Machine learning is so exciting, but barely anyone in the Unity community is willing to help with it.

    Anyways this is my config file.

    Code (YAML):
    ---
    behaviors:
      ChessMLAgent:
        trainer_type: ppo
        hyperparameters:
          batch_size: 32
          buffer_size: 2048
          learning_rate: 3.0e-4
          beta: 5.0e-4
          epsilon: 0.3
          lambd: 0.99
          num_epoch: 3
          learning_rate_schedule: constant
        network_settings:
          normalize: false
          hidden_units: 128
          num_layers: 3
        reward_signals:
          extrinsic:
            gamma: 0.99
            strength: 1.0
        keep_checkpoints: 5
        max_steps: 10000000
        time_horizon: 32
        summary_freq: 1000
        threaded: true
        self_play:
          save_steps: 2000
          team_change: 10000
          swap_steps: 1000
          window: 30
          play_against_latest_model_ratio: 0.5
          initial_elo: 1200.0

    This has been an extremely frustrating endeavour, with no clear examples from Unity on how to set up a simple board game like chess with machine learning. None whatsoever.
     
  4. markashburner
    So do you reckon it would work better if I added no reward when the agent takes a piece, and only called SetReward when the agent wins a game? I haven't tried that yet. But to be honest, my enthusiasm for Unity ML-Agents has been shattered. I've been working on this for three or four days now and haven't achieved any significant results; the agent is subpar, and a conventional chess algorithm has been significantly better. Really disappointed with Unity ML-Agents.
     
  5. markashburner
    Code (CSharp):
    using System.Collections.Generic;
    using System.Linq;
    using UnityEngine;
    using Unity.MLAgents;
    using Unity.MLAgents.Actuators;
    using Unity.MLAgents.Sensors;

    public class ChessMLAgent : Agent
    {
        public ChessAgentManager manager;
        public cgChessBoardScript board;
        public bool isWhite;
        public byte searchDepthWeak = 4;
        public byte searchDepthStrong = 4;
        public byte searchDepthEndGame = 4;
        public int branchSize;
        public cgSimpleMove currentMove;
        public List<cgSquareScript> squares = new List<cgSquareScript>();
        public List<cgSimpleMove> legalMoves = new List<cgSimpleMove>();
        public List<int> branchMask = new List<int>();
        public List<int> impossibleMoves = new List<int>();

        public override void OnEpisodeBegin()
        {
            legalMoves = board._abstractBoard.findStrictLegalMoves(isWhite);
        }

        public override void Heuristic(in ActionBuffers actionsOut)
        {
            if (manager.mode == MLBoardMode.MLAgentVsMLAgent)
            {
                Debug.Log("Heuristic Action number is " + actionsOut.DiscreteActions[0]);
                currentMove = legalMoves[0];
            }
        }

        public override void CollectObservations(VectorSensor sensor)
        {
            // Identical for both colours: one-hot encode each legal-move index.
            for (int i = 0; i < legalMoves.Count; i++)
            {
                sensor.AddOneHotObservation(i, legalMoves.Count);
            }
        }

        public override void OnActionReceived(ActionBuffers actions)
        {
            if (manager.mode != MLBoardMode.MLAgentVsMLAgent) return;

            if (impossibleMoves.Count != branchSize)
            {
                if (actions.DiscreteActions[0] < legalMoves.Count)
                    currentMove = legalMoves[actions.DiscreteActions[0]];
            }
            else
            {
                Debug.Log("Board " + manager.agentManagerId + " All moves masked under OnActionReceived!");
                ResetPieceCounters();
                ResolveGameEnd();
            }
        }

        public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
        {
            branchMask.Clear();
            legalMoves = board._abstractBoard.findStrictLegalMoves(isWhite);
            int alpha = int.MinValue;
            int beta = int.MaxValue;

            if (legalMoves.Count == 0) return;

            // Score every legal move with alpha-beta so the list can be pruned.
            for (int i = 0; i < legalMoves.Count; i++)
            {
                cgSimpleMove possibleMove = legalMoves[i];
                byte depth = cgValueModifiers.AlphaBeta_Strong_Delineation < possibleMove.positionalVal
                    ? searchDepthStrong
                    : searchDepthWeak;
                if (legalMoves.Count < 10) depth = searchDepthEndGame;

                legalMoves[i].val = board.getEngine._alfaBeta(board._abstractBoard, legalMoves[i], depth, alpha, beta, false);
            }

            // Keep only the 60 best-scoring moves so they fit in the branch.
            if (legalMoves.Count > 60)
            {
                legalMoves = legalMoves.OrderByDescending(x => x.val).Take(60).ToList();
            }

            for (int i = 0; i < legalMoves.Count; i++)
            {
                branchMask.Add(i);
            }

            int movesAbove = branchSize - branchMask.Count;
            impossibleMoves.Clear();
            for (int i = 0; i < movesAbove; i++)
            {
                impossibleMoves.Add(branchMask.Count + i);
            }

            if (impossibleMoves.Count == branchSize)
            {
                Debug.Log("Board " + manager.agentManagerId + " All moves masked!");
                ResetPieceCounters();
                ResolveGameEnd();
            }
            else
            {
                actionMask.WriteMask(0, GetMovementSize());
            }
        }

        private IEnumerable<int> GetMovementSize()
        {
            int movesAbove = branchSize - branchMask.Count;
            impossibleMoves.Clear();
            for (int i = 0; i < movesAbove; i++)
            {
                impossibleMoves.Add(branchMask.Count + i);
            }
            return impossibleMoves.ToArray();
        }

        private void ResetPieceCounters()
        {
            for (int i = 0; i < board.chessPieces.Count; i++)
            {
                var piece = board.chessPieces[i].GetComponent<cgChessPieceScript>();
                piece.sqaureMoves.Clear();
                piece.numberOfMovesMade = 0;
            }
        }

        private void ResolveGameEnd()
        {
            bool isChecked = board._abstractBoard.isChecked(isWhite);

            if (isChecked)
            {
                if (isWhite)
                {
                    AddReward(-1f);
                    manager.Agent2.AddReward(1f);
                    board._whiteWins += 1;
                    board.whiteWin.text = "White Games won: " + board._whiteWins;
                    EndEpisode();
                    manager.Agent2.EndEpisode();
                }
                else
                {
                    AddReward(1f);
                    manager.Agent1.AddReward(1f);
                    board._blackWins += 1;
                    board.blackWin.text = "Black Games won: " + board._blackWins;
                    EndEpisode();
                    manager.Agent1.EndEpisode();
                }
            }
            else
            {
                board.ResetBoard();
                manager.Agent1.EndEpisode();
                manager.Agent2.EndEpisode();
            }
        }
    }
    This is my Chess Agent script
     

  6. markashburner
    Code (CSharp):
    // Shaped rewards when a move captures a piece (roughly classical piece values / 10).
    cgChessPieceScript captured = _getPieceOn(_abstractBoard.SquareNames[move.to]);
    if (captured != null && !(move is cgCastlingMove))
    {
        switch (captured.type)
        {
            case cgChessPieceScript.Type.BlackPawn:
                whiteAgent.AddReward(0.1f);
                blackAgent.AddReward(-0.1f);
                break;
            case cgChessPieceScript.Type.BlackKnight:
            case cgChessPieceScript.Type.BlackBishop:
                whiteAgent.AddReward(0.3f);
                blackAgent.AddReward(-0.3f);
                break;
            case cgChessPieceScript.Type.BlackRook:
                whiteAgent.AddReward(0.5f);
                blackAgent.AddReward(-0.5f);
                break;
            case cgChessPieceScript.Type.BlackQueen:
                whiteAgent.AddReward(0.9f);
                blackAgent.AddReward(-0.9f);
                break;
            case cgChessPieceScript.Type.WhitePawn:
                whiteAgent.AddReward(-0.1f);
                blackAgent.AddReward(0.1f);
                break;
            case cgChessPieceScript.Type.WhiteKnight:
            case cgChessPieceScript.Type.WhiteBishop:
                whiteAgent.AddReward(-0.3f);
                blackAgent.AddReward(0.3f);
                break;
            case cgChessPieceScript.Type.WhiteRook:
                whiteAgent.AddReward(-0.5f);
                blackAgent.AddReward(0.5f);
                break;
            case cgChessPieceScript.Type.WhiteQueen:
                whiteAgent.AddReward(-0.9f);
                blackAgent.AddReward(0.9f);
                break;
            case cgChessPieceScript.Type.BlackKing:
                Debug.Log("Black King piece was taken!");
                for (int i = 0; i < chessPieces.Count; i++)
                {
                    var piece = chessPieces[i].GetComponent<cgChessPieceScript>();
                    piece.sqaureMoves.Clear();
                    piece.numberOfMovesMade = 0;
                }
                _abstractBoard.revert();
                return;
            case cgChessPieceScript.Type.WhiteKing:
                Debug.Log("White King piece was taken!");
                for (int i = 0; i < chessPieces.Count; i++)
                {
                    var piece = chessPieces[i].GetComponent<cgChessPieceScript>();
                    piece.sqaureMoves.Clear();
                    piece.numberOfMovesMade = 0;
                }
                _abstractBoard.revert();
                return;
        }
        _setDeadPiece(captured);
    }
     
  7. markashburner
    Code (CSharp):
    public void _gameMLOver(bool whiteWins, bool blackWins, ChessMLAgent whiteAgent, ChessMLAgent blackAgent, bool movedTooManyTimes, bool isWhite)
    {
        string gameOverString = "Game Over. ";

        if (whiteWins)
        {
            whitePlayer.AddReward(1f);
            blackPlayer.AddReward(-1f);
            _whiteWins += 1;
            whiteWin.text = "White Games won: " + _whiteWins;

            whitePlayer.EndEpisode();
            blackPlayer.EndEpisode();
            ResetBoard();
            // gameOverString = "White Wins!";
        }
        else if (blackWins)
        {
            blackPlayer.AddReward(1f);
            whitePlayer.AddReward(-1f);
            _blackWins += 1;
            blackWin.text = "Black Games won: " + _blackWins;

            whitePlayer.EndEpisode();
            blackPlayer.EndEpisode();
            ResetBoard();
        }
        else
        {
            _draws += 1;
            draws.text = "Draws: " + _draws;

            if (!movedTooManyTimes)
            {
                gameOverString = "It's a draw!";
                whiteAgent.AddReward(-0.75f);
                blackAgent.AddReward(-0.75f);
                whiteAgent.EndEpisode();
                blackAgent.EndEpisode();
                ResetBoard();
            }
            else
            {
                // The side that repeated moves too often takes the full -1 penalty.
                if (isWhite)
                {
                    whiteAgent.SetReward(-1f);
                    blackAgent.AddReward(-0.25f);
                }
                else
                {
                    whiteAgent.AddReward(-0.25f);
                    blackAgent.SetReward(-1f);
                }
                whiteAgent.EndEpisode();
                blackAgent.EndEpisode();
                ResetBoard();
            }

            for (int i = 0; i < chessPieces.Count; i++)
            {
                if (chessPieces[i] != null)
                {
                    var piece = chessPieces[i].GetComponent<cgChessPieceScript>();
                    piece.sqaureMoves.Clear();
                    piece.numberOfMovesMade = 0;
                }
            }
        }
    }
     
  8. celion_unity
    This isn't quite what I said. With what you describe, an agent could finish with a higher reward than its opponent by capturing pieces, but still lose the game. The self-play system treats the agent with the higher reward as the winner, so you'd be "encouraging" it to grab a bunch of pieces instead of necessarily winning. With smaller rewards for capturing, it learns that capturing pieces is better than not capturing pieces, but winning is still the most important thing.

    I wouldn't recommend penalising repeated moves like that; if the agent actually repeats moves enough to trigger a draw, end the match and give both sides 0.
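    That could be as small as the following sketch (reusing the agent and board names from the code earlier in the thread):

    ```csharp
    // Draw by repetition (or any other draw): end the game, neither side gets credit.
    void OnDraw()
    {
        whiteAgent.SetReward(0f);
        blackAgent.SetReward(0f);
        whiteAgent.EndEpisode();
        blackAgent.EndEpisode();
        board.ResetBoard();
    }
    ```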

    The algorithms that we use don't know how to "look ahead" for things like this.

    It's not exactly a board game, but we do have an example of a Match-3 game in ML-Agents.

    I'm sorry if anything in our documentation gave this impression. Reinforcement learning (which is most of what ML-Agents is) is good for continuous environments where the "rules" for going from one step to the next aren't always known, which is the case in a lot of games and robotics. RL can also be used for board games (for example, AlphaGo, AlphaZero, and MuZero). But AlphaGo and AlphaZero need to be told the "rules" of the game so that they can simulate possible future states; MuZero doesn't need the rules, but we haven't tried to implement it, and I'm unsure how much more computing power it would take.

    If you're interested in learning more "traditional" Game AI (not machine learning based), something like Monte Carlo Tree Search (MCTS) is a very powerful technique, and (if I understand correctly) it's also part of the basis for AlphaGo and AlphaZero.
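    To make the shape of MCTS concrete, here is a compact UCT sketch on a toy take-away game (players alternately remove 1 or 2 stones; whoever takes the last stone wins). It is plain C# with no Unity dependency, and the game, class, and method names are all made up for illustration; a chess version would swap in real move generation and board state:

    ```csharp
    using System;
    using System.Collections.Generic;
    using System.Linq;

    // One node per game state reached in the search tree.
    class Node
    {
        public int Stones;                      // stones left after the move leading here
        public int MoveFromParent;              // 1 or 2 (0 for the root)
        public Node Parent;
        public List<Node> Children = new List<Node>();
        public int Visits;
        public double Wins;                     // from the perspective of the player who just moved

        public bool Terminal => Stones == 0;
        public bool FullyExpanded => Children.Count == Math.Min(2, Stones);
    }

    static class Mcts
    {
        // Returns the move (1 or 2) that UCT considers best from `stones`.
        public static int Search(int stones, int iterations, Random rng)
        {
            var root = new Node { Stones = stones };
            for (int it = 0; it < iterations; it++)
            {
                // 1. Selection: walk down by UCB1 until an expandable or terminal node.
                Node node = root;
                while (!node.Terminal && node.FullyExpanded)
                {
                    Node parent = node;
                    node = parent.Children.OrderByDescending(c =>
                        c.Wins / c.Visits + Math.Sqrt(2.0 * Math.Log(parent.Visits) / c.Visits)).First();
                }

                // 2. Expansion: add one untried move (1 first, then 2).
                if (!node.Terminal)
                {
                    int move = node.Children.Count + 1;
                    var child = new Node { Stones = node.Stones - move, MoveFromParent = move, Parent = node };
                    node.Children.Add(child);
                    node = child;
                }

                // 3. Simulation: random playout, tracking which side takes the last stone.
                int remaining = node.Stones;
                bool nodePlayerWins = node.Terminal;   // terminal here means node's player took the last stone
                bool moverIsNodePlayer = false;        // the opponent moves next
                while (remaining > 0)
                {
                    remaining -= 1 + rng.Next(Math.Min(2, remaining));
                    if (remaining == 0) nodePlayerWins = moverIsNodePlayer;
                    moverIsNodePlayer = !moverIsNodePlayer;
                }

                // 4. Backpropagation: flip the win credit at every level.
                bool win = nodePlayerWins;
                for (Node n = node; n.Parent != null; n = n.Parent)
                {
                    n.Visits++;
                    if (win) n.Wins++;
                    win = !win;
                }
                root.Visits++;
            }
            // The most-visited child is the recommended move.
            return root.Children.OrderByDescending(c => c.Visits).First().MoveFromParent;
        }
    }
    ```

    The four phases (selection by UCB1, expansion, random playout, backpropagation) are the same ones a chess MCTS would use; only the game rules and move generation change.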
     
  9. markashburner
    Thanks heaps for the reply. I'll look into teaching it the rules of chess, then.
     
  10. ThadJunior
      Joined: Jun 6, 2020
      Posts: 1
    Hi @markashburner ,

    I'm looking into using ML-Agents for a chess game. However, I get the feeling that 64 squares with up to 64 destinations per move, different behaviour for each piece, a huge number of board situations, and a small chance of winning compared to stalemate, threefold repetition, etc. make this a very difficult environment in which to get an AI to win via reinforcement learning.

    Could you share how your project is going?
    I’m very interested in your progress/achievements.