Question Why is my Pac-Man agent unable to learn the movement?

Discussion in 'ML-Agents' started by Baum_1234, Mar 2, 2024.

    Hi, I'm currently working on an agent that plays Pac-Man; the game itself is based on this GitHub project. I'm using Unity 2022.3.20f1 (LTS) and ML-Agents Release 21. This is my current repository.

    I have defined simple rewards for eating a single pellet (0.01 for a normal pellet and 0.05 for a power pellet) and for eating all pellets (1.0). I also add a small negative reward (-0.001) on every step, so that the agent is pushed to clear the board in fewer steps. There is also a death reward (-1.0) when Pac-Man is eaten by a ghost, but it doesn't matter yet, since I'm training without ghosts for the time being.
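    To make the scale of these rewards concrete, here is a small sketch that adds up the return of one episode from event counts, using the same constants as the agent code. It is plain C# with no Unity dependency; `RewardTally` and `EpisodeReturn` are hypothetical names for illustration, not part of the project. It also assumes the step penalty applies on every step, which simplifies the agent's `!GameIsWon` check.

```csharp
using System;

// Hypothetical helper (not in the project): sums one episode's return
// from event counts, using the reward constants from PacManAgent.
public static class RewardTally
{
    public const float PelletReward = 0.01f;
    public const float PowerPelletReward = 0.05f;
    public const float NegativeRewardPerStep = -0.001f;
    public const float DeathReward = -1f;
    public const float WinReward = 1f;

    public static float EpisodeReturn(int pellets, int powerPellets, int steps, bool won, bool died)
    {
        // Pellet rewards plus the per-step penalty (assumed on every step)
        float total = pellets * PelletReward
                    + powerPellets * PowerPelletReward
                    + steps * NegativeRewardPerStep;

        // Terminal rewards
        if (won) total += WinReward;
        if (died) total += DeathReward;
        return total;
    }
}
```

    For example, an episode that hits the 4000-step limit without eating anything returns about -4.0, while eating 10 normal pellets in 100 steps roughly cancels out the step penalty.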

    The episode ends when Pac-Man dies, when all pellets have been eaten (i.e. the round is over), or when the agent has not cleared the round within 4000 steps. Eaten pellets are only restored after a death or a completed round, not when the step limit is reached.
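    Putting the numbers above together: over an episode that runs into the 4000-step limit, the per-step penalty alone adds up to -4.0, and one normal pellet (0.01) only offsets the penalty of 10 steps (0.001 each):

```latex
\begin{aligned}
R_{\text{steps}} &= 4000 \times (-0.001) = -4.0 \\
n_{\text{break-even}} &= \frac{0.01}{0.001} = 10 \ \text{steps per normal pellet}
\end{aligned}
```

    In other words, the accumulated step penalty of a timed-out episode is four times the size of the win reward (1.0).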

    I am currently using a camera sensor for the observations. Before that I tried raycasts and plain vector observations, all without success so far.
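    For comparison with the camera sensor, here is a minimal sketch of what a compact vector observation for this game could look like: four wall flags plus a one-hot encoding of the current heading, both in the same order as the agent's discrete actions (0 = up, 1 = down, 2 = left, 3 = right). The layout and the `ObservationSketch.Build` name are hypothetical assumptions, written in plain C# so it runs without Unity. In ML-Agents, values like these would be passed to `sensor.AddObservation(...)` inside `CollectObservations(VectorSensor sensor)`, and the vector observation size in Behavior Parameters would have to match (8 here).

```csharp
using System;

// Hypothetical observation layout (an assumption, not the project's):
// indices 0-3: 1 if a wall blocks up/down/left/right, else 0
// indices 4-7: one-hot current heading (0 = up, 1 = down, 2 = left, 3 = right)
public static class ObservationSketch
{
    public const int Size = 8;

    public static float[] Build(bool[] wallBlocked, int heading)
    {
        if (wallBlocked == null || wallBlocked.Length != 4)
            throw new ArgumentException("Expected four wall flags (up, down, left, right).");
        if (heading < 0 || heading > 3)
            throw new ArgumentOutOfRangeException(nameof(heading));

        var obs = new float[Size];
        for (int i = 0; i < 4; i++)
        {
            obs[i] = wallBlocked[i] ? 1f : 0f; // wall flags
        }
        obs[4 + heading] = 1f; // one-hot heading
        return obs;
    }
}
```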

    Now to my problem in more detail. I think the agent is not able to learn the movement properly. I have made several attempts and let the agent train for a long time. Once I trained it for 8 million steps without ghosts (and without the per-step penalty), which took about 16 hours. The result was a severely overfitted model that simply always took the same path; it did not understand how the game works or how to move.
    [Screenshot attachment: upload_2024-3-2_14-35-37.png]
    The other training run was also over 8 million steps and took about 17 hours, this time with the per-step penalty and with one ghost added. In that time the agent did not learn to avoid the ghost, and its movement decisions look completely random.
    [Screenshot attachment: upload_2024-3-2_14-40-51.png]
    I watched a video in which the creator also built a Pac-Man AI. He had the same problem, although my agent doesn't get stuck in corners as often as his. His solution was to switch the movement from global (top-down) directions to local directions from Pac-Man's own perspective. I tried to implement this but couldn't get it to work.
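    The local-direction idea can be sketched as a mapping from relative actions (forward, turn left, turn right, reverse) to the absolute directions that `Movement.SetDirection` expects. This is a hypothetical sketch, assuming the agent's four discrete actions are redefined as relative moves; `RelativeAction` and `LocalDirections.ToAbsolute` are illustrative names, not taken from the video or the project. Directions are integer (x, y) pairs so the code runs without Unity; in the project they would be `Vector2` values.

```csharp
using System;

// Hypothetical relative action set, from Pac-Man's own perspective.
public enum RelativeAction { Forward = 0, TurnLeft = 1, TurnRight = 2, Reverse = 3 }

public static class LocalDirections
{
    // Converts a relative action into an absolute grid direction (x, y),
    // given the current heading (x, y). Up is (0, 1), right is (1, 0).
    public static (int x, int y) ToAbsolute((int x, int y) heading, RelativeAction action)
    {
        switch (action)
        {
            case RelativeAction.Forward:
                return heading;
            case RelativeAction.TurnLeft:
                // Rotate 90 degrees counter-clockwise: (x, y) -> (-y, x)
                return (-heading.y, heading.x);
            case RelativeAction.TurnRight:
                // Rotate 90 degrees clockwise: (x, y) -> (y, -x)
                return (heading.y, -heading.x);
            case RelativeAction.Reverse:
                return (-heading.x, -heading.y);
            default:
                throw new ArgumentOutOfRangeException(nameof(action));
        }
    }
}
```

    For example, while Pac-Man is heading right, (1, 0), `TurnLeft` yields up, (0, 1), and `TurnRight` yields down, (0, -1). The same action index therefore maps to different screen directions depending on the current heading.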

    I also don't think training a 2D game like Pac-Man should take this long (over 10 hours). I train with 4 instances in parallel, each running at 20x speed.
    I have tried many other rewards and configurations, all without success. Does anyone have an idea what this could be? Is it really the movement, or something else?
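    Using the figures from the first run (about 8 million steps in about 16 hours on 4 instances), the effective training throughput works out to roughly:

```latex
\begin{aligned}
\frac{8\,000\,000\ \text{steps}}{16 \times 3600\ \text{s}} &\approx 139\ \text{steps/s (all instances)} \\
\frac{139\ \text{steps/s}}{4\ \text{instances}} &\approx 35\ \text{steps/s per instance}
\end{aligned}
```

    This is only a rough figure: the trainer's step count may not correspond one-to-one with simulation steps, for example when the Decision Requester's decision period is above 1.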

    Code (CSharp):
    using System;
    using Unity.MLAgents;
    using Unity.MLAgents.Actuators;
    using Unity.MLAgents.Sensors;
    using UnityEngine;

    [RequireComponent(typeof(Movement))]
    public class PacManAgent : Agent
    {
        [SerializeField] private AnimatedSprite deathSequence;
        private SpriteRenderer spriteRenderer;
        private Movement movement;

        private new Collider2D collider;

        private GameManager gamemanager;
        private int currentAction;

        // Constants
        private const float PelletReward = 0.01f;
        private const float PowerPelletReward = 0.05f;
        private const float NegativeRewardPerStep = -0.001f;
        private const float DeathReward = -1f;

        private const float WinReward = 1f;

        public override void Initialize()
        {
            spriteRenderer = GetComponent<SpriteRenderer>();
            movement = GetComponent<Movement>();
            collider = GetComponent<Collider2D>();
            gamemanager = FindObjectOfType<GameManager>();
            currentAction = 3;
        }

        private void OnTriggerEnter2D(Collider2D collision)
        {
            if (collision.CompareTag("Pellet"))
            {
                // timeSinceLastPellet = 0f;
                AddReward(PelletReward);
            }

            if (collision.CompareTag("PowerPellet"))
            {
                // timeSinceLastPellet = 0f;
                AddReward(PowerPelletReward);
            }
        }

        // For the gamemanager
        public void GiveWinReward()
        {
            AddReward(WinReward);
        }

        // For the gamemanager
        public void GiveDeathReward()
        {
            AddReward(DeathReward);
        }

        public override void OnActionReceived(ActionBuffers actions)
        {
            // Convert discrete actions to movement directions
            int movementAction = actions.DiscreteActions[0];
            Vector2 direction = Vector2.zero;

            switch (movementAction)
            {
                case 0:
                    direction = Vector2.up;
                    break;
                case 1:
                    direction = Vector2.down;
                    break;
                case 2:
                    direction = Vector2.left;
                    break;
                case 3:
                    direction = Vector2.right;
                    break;
            }

            if (!gamemanager.GameIsWon)
            {
                AddReward(NegativeRewardPerStep);
            }

            movement.SetDirection(direction);

            float angle = Mathf.Atan2(movement.direction.y, movement.direction.x);
            transform.rotation = Quaternion.AngleAxis(angle * Mathf.Rad2Deg, Vector3.forward);
        }

        public override void Heuristic(in ActionBuffers actionsOut)
        {
            // Allows manual control for testing purposes
            ActionSegment<int> discreteActions = actionsOut.DiscreteActions;

            if (Input.GetKey(KeyCode.W) || Input.GetKey(KeyCode.UpArrow))
            {
                currentAction = 0;
            }
            else if (Input.GetKey(KeyCode.S) || Input.GetKey(KeyCode.DownArrow))
            {
                currentAction = 1;
            }
            else if (Input.GetKey(KeyCode.A) || Input.GetKey(KeyCode.LeftArrow))
            {
                currentAction = 2;
            }
            else if (Input.GetKey(KeyCode.D) || Input.GetKey(KeyCode.RightArrow))
            {
                currentAction = 3;
            }

            discreteActions[0] = currentAction;
        }

        public override void OnEpisodeBegin()
        {
            enabled = true;
            spriteRenderer.enabled = true;
            collider.enabled = true;
            deathSequence.enabled = false;
            movement.ResetState();
            gamemanager.GameIsWon = false;
            gameObject.SetActive(true);
        }

        public void DeathSequence()
        {
            enabled = false;
            spriteRenderer.enabled = false;
            collider.enabled = false;
            movement.enabled = false;
            deathSequence.enabled = true;
            deathSequence.Restart();
        }
    }

    Code (CSharp):
    using System.Net;
    using UnityEngine;
    using UnityEngine.UI;
    using UnityEngine.UIElements;

    public class GameManager : MonoBehaviour
    {
        public static GameManager Instance { get; private set; }

        [SerializeField] private Ghost[] ghosts;

        [SerializeField] private PacManAgent pacmanagent;

        [SerializeField] private Transform pellets;

        [SerializeField] private Text gameOverText;

        [SerializeField] private Text scoreText;

        [SerializeField] private Text livesText;

        public int totalLives;

        private int ghostMultiplier = 1;
        private int lives;
        private int score = 0;
        public bool GameIsWon = false;

        public int Lives => lives;
        public int Score => score;

        private void Awake()
        {
            if (Instance != null)
            {
                DestroyImmediate(gameObject);
            }
            else
            {
                Instance = this;
                DontDestroyOnLoad(gameObject);
            }
        }

        private void Start()
        {
            pacmanagent = pacmanagent.GetComponent<PacManAgent>();
            SetLives(totalLives);
            NewGame();
        }

        private void Update()
        {
            if (lives <= 0)
            {
                NewGame();
            }
        }

        private void NewGame()
        {
            SetScore(0);
            SetLives(totalLives);
            NewRound();
        }

        private void NewRound()
        {
            gameOverText.enabled = false;

            foreach (Transform pellet in pellets)
            {
                pellet.gameObject.SetActive(true);
            }

            ResetState();
        }

        private void ResetState()
        {
            for (int i = 0; i < ghosts.Length; i++)
            {
                ghosts[i].ResetState();
            }
            // pacmanagent.startTime = Time.time;
            // Note: Agent.EndEpisode() already calls OnEpisodeBegin() internally,
            // so after EndEpisode() this direct call runs it a second time.
            pacmanagent.OnEpisodeBegin();
        }

        private void GameOver()
        {
            gameOverText.enabled = true;

            for (int i = 0; i < ghosts.Length; i++)
            {
                ghosts[i].gameObject.SetActive(false);
            }
            pacmanagent.gameObject.SetActive(false);
        }

        private void SetLives(int lives)
        {
            this.lives = lives;
            livesText.text = "x" + lives.ToString();
        }

        private void SetScore(int score)
        {
            this.score = score;
            scoreText.text = score.ToString().PadLeft(2, '0');
        }

        public void PacmanEaten()
        {
            pacmanagent.GiveDeathReward();
            pacmanagent.EndEpisode();
            pacmanagent.DeathSequence();

            SetLives(lives - 1);

            if (lives > 0)
            {
                ResetState();
            }
            else
            {
                GameOver();
            }
        }

        public void GhostEaten(Ghost ghost)
        {
            int points = ghost.points * ghostMultiplier;
            SetScore(score + points);

            ghostMultiplier++;
        }

        public void PelletEaten(Pellet pellet)
        {
            pellet.gameObject.SetActive(false);

            SetScore(score + pellet.points);

            if (!HasRemainingPellets())
            {
                pacmanagent.GiveWinReward();
                pacmanagent.EndEpisode();
                pacmanagent.gameObject.SetActive(false);
                GameIsWon = true;
                NewRound();
            }
        }

        public void PowerPelletEaten(PowerPellet pellet)
        {
            for (int i = 0; i < ghosts.Length; i++)
            {
                ghosts[i].frightened.Enable(pellet.duration);
            }

            PelletEaten(pellet);
            CancelInvoke(nameof(ResetGhostMultiplier));
            Invoke(nameof(ResetGhostMultiplier), pellet.duration);
        }

        public bool HasRemainingPellets()
        {
            foreach (Transform pellet in pellets)
            {
                if (pellet.gameObject.activeSelf)
                {
                    return true;
                }
            }

            return false;
        }

        public Transform GetPellets()
        {
            return this.pellets;
        }

        private void ResetGhostMultiplier()
        {
            ghostMultiplier = 1;
        }
    }

    Code (CSharp):
    using UnityEngine;

    [RequireComponent(typeof(Rigidbody2D))]
    public class Movement : MonoBehaviour
    {
        public float speed = 8f;
        public float speedMultiplier = 1f;
        public Vector2 initialDirection;
        public LayerMask obstacleLayer;

        public new Rigidbody2D rigidbody { get; private set; }
        public Vector2 direction { get; private set; }
        public Vector2 nextDirection { get; private set; }
        public Vector3 startingPosition { get; private set; }

        private void Awake()
        {
            rigidbody = GetComponent<Rigidbody2D>();
            startingPosition = transform.position;
        }

        private void Start()
        {
            ResetState();
        }

        public void ResetState()
        {
            speedMultiplier = 1f;
            direction = initialDirection;
            nextDirection = Vector2.zero;
            transform.position = startingPosition;
            rigidbody.isKinematic = false;
            enabled = true;
        }

        private void Update()
        {
            // Try to move in the next direction while it's queued to make movements
            // more responsive
            if (nextDirection != Vector2.zero)
            {
                SetDirection(nextDirection);
            }
        }

        private void FixedUpdate()
        {
            Vector2 position = rigidbody.position;
            Vector2 translation = direction * speed * speedMultiplier * Time.fixedDeltaTime;

            rigidbody.MovePosition(position + translation);
        }

        public void SetDirection(Vector2 direction, bool forced = false)
        {
            // Only set the direction if the tile in that direction is available
            // otherwise we set it as the next direction so it'll automatically be
            // set when it does become available
            if (forced || !Occupied(direction))
            {
                this.direction = direction;
                nextDirection = Vector2.zero;
            }
            else
            {
                nextDirection = direction;
            }
        }

        private bool Occupied(Vector2 direction)
        {
            // If no collider is hit then there is no obstacle in that direction
            RaycastHit2D hit = Physics2D.BoxCast(
                transform.position,
                Vector2.one * 0.75f,
                0f,
                direction,
                1.5f,
                obstacleLayer
            );
            return hit.collider != null;
        }
    }
     
    Last edited: Mar 2, 2024