Search Unity

Question Agent Rapidly gets worse after doing well

Discussion in 'ML-Agents' started by AjSpeed, May 4, 2023.

  1. AjSpeed

    AjSpeed

    Joined:
    Mar 18, 2021
    Posts:
    2
    I have an AI agent for a 2D platformer that I have been training for a while now. Every iteration I do, I encounter the same problem: the agent has great success for an episode, then completely reverts to a much worse version. I have added the code, hyperparameters, and TensorBoard graph here. If you need any other information, just let me know.
    image_2023-05-03_205613032.png image_2023-05-03_205755040.png
    Code (CSharp):
    1. using System.Collections;
    2. using System.Collections.Generic;
    3. using UnityEngine;
    4. using Unity.MLAgents;
    5. using Unity.MLAgents.Sensors;
    6. using Unity.MLAgents.Actuators;
    7. using UnityEngine.Tilemaps;
    /// <summary>
    /// ML-Agents agent for a colour-matching 2D platformer.
    /// The player can only jump off tilemaps whose RGB colour matches its own sprite colour.
    /// Colliding with a tilemap on layer 6 or 7 whose colour differs from the sprite colour
    /// ends the episode with a small penalty.
    /// Reaching <see cref="endLevel"/> while the sprite colour matches the goal's sprite colour
    /// ends the episode with a reward.
    /// </summary>
    public class PlayerAgent : Agent
    {
        // Tuning constants (same values as the original inline literals).
        private const float MoveSpeed = 11f;
        private const float JumpImpulse = 7.5f;
        private const float FallGravityMultiplier = 3f;      // extra gravity applied while falling
        private const float ShortHopGravityMultiplier = 4f;  // extra gravity while rising without jump held
        private const float GoalRadius = 1.25f;
        private const float GoalReward = 3f;
        private const float WrongColorPenalty = -0.15f;
        private const float GroundCheckDistance = 0.08f;

        Rigidbody2D rb;
        CircleCollider2D cc;
        SpriteRenderer ps;

        GameObject[] coins;
        // Cached coin positions; currently NOT part of the observation vector
        // (adding them would change the observation size).
        List<Vector3> coinPos = new List<Vector3>();
        GameObject[] changers;
        // Cached changer positions, observed every step in CollectObservations.
        List<Vector3> changerPos = new List<Vector3>();

        public Transform endLevel;
        public Transform startPos;

        // Distance from the player to endLevel after the latest action (also observed).
        float distanceToEnd;
        float oldPos;
        float newPos;
        // Distance from startPos to endLevel, computed at episode start.
        float dis;
        // Closest distance to endLevel reached so far in the current episode.
        float lastDisToEnd;

        // Colour written to the sprite every FixedUpdate; presumably changed by the
        // "Changer" objects from another script — confirm.
        public Color playerColor;

        // Goal-reach counter, loaded from PlayerPrefs "EndReached" at episode start.
        public int endReached;
        // Last jump action (1 = jump held, anything else = not held).
        int jump;
        // Last movement action written by Heuristic (1 = left, 2 = right, 3 = none).
        int movement;
        // Current horizontal direction: -1 left, 0 none, 1 right.
        int dirX;

        private void Awake()
        {
            rb = GetComponent<Rigidbody2D>();
            cc = GetComponent<CircleCollider2D>();
            ps = GetComponent<SpriteRenderer>();

            coins = GameObject.FindGameObjectsWithTag("Coin");
            foreach (GameObject coin in coins)
            {
                coinPos.Add(coin.transform.localPosition);
            }

            // NOTE(review): these are LOCAL positions, while the player's position is observed
            // in world space; they only agree if the changers' parents sit at the origin.
            changers = GameObject.FindGameObjectsWithTag("Changer");
            foreach (GameObject changer in changers)
            {
                changerPos.Add(changer.transform.localPosition);
            }
        }

        /// <summary>
        /// Resets the level and the player for a new episode.
        /// Re-activates all coins, resets the player colour to white, and respawns the player
        /// near <see cref="startPos"/> with a random horizontal offset in [-1.5, 10).
        /// </summary>
        public override void OnEpisodeBegin()
        {
            foreach (GameObject coin in coins)
            {
                coin.SetActive(true);
                coin.GetComponent<CoinScirpt>().activelyCollected = false;
            }

            ps.color = Color.white;
            playerColor = ps.color;
            oldPos = 0;

            // Clear action state so the first step cannot inherit the previous
            // episode's horizontal direction (dirX is not part of the observations).
            dirX = 0;
            jump = 0;

            transform.position = startPos.position + new Vector3(Random.Range(-1.5f, 10f), 0f);
            rb.velocity = Vector3.zero;

            dis = Vector3.Distance(startPos.position, endLevel.transform.position);
            lastDisToEnd = dis;
            // Refresh now so the first observation of the episode is not the
            // previous episode's final distance.
            distanceToEnd = Vector2.Distance(transform.position, endLevel.position);
            endReached = PlayerPrefs.GetInt("EndReached");
        }

        private void FixedUpdate()
        {
            // ps is the cached SpriteRenderer; no need to look it up every physics step.
            ps.color = playerColor;
        }

        /// <summary>
        /// Observation vector, in this order (size = 11 + 2 * number of changers):
        /// velocity x, y; position x, y; grounded flag; endLevel x, y; sprite colour r, g, b;
        /// each changer's x, y; distance to endLevel.
        /// Positions and velocities are raw world units and can lie far outside [-1, 1].
        /// Enable <c>normalize: true</c> in the trainer config, or scale these values here.
        /// </summary>
        public override void CollectObservations(VectorSensor sensor)
        {
            sensor.AddObservation(rb.velocity.x);
            sensor.AddObservation(rb.velocity.y);

            sensor.AddObservation(transform.position.x);
            sensor.AddObservation(transform.position.y);

            sensor.AddObservation(IsGrounded());

            sensor.AddObservation(endLevel.position.x);
            sensor.AddObservation(endLevel.position.y);

            sensor.AddObservation(ps.color.r);
            sensor.AddObservation(ps.color.g);
            sensor.AddObservation(ps.color.b);

            for (int i = 0; i < changerPos.Count; i++)
            {
                sensor.AddObservation(changerPos[i].x);
                sensor.AddObservation(changerPos[i].y);
            }

            sensor.AddObservation(distanceToEnd);
        }

        /// <summary>
        /// Applies one decision.
        /// Branch 0 (movement): 1 = left, 2 = right, anything else = no horizontal input.
        /// Branch 1 (jump): 1 = jump / keep holding jump, anything else = not held.
        /// Ends the episode with <see cref="GoalReward"/> when the goal is reached in the matching colour.
        /// </summary>
        public override void OnActionReceived(ActionBuffers actions)
        {
            int moveAction = actions.DiscreteActions[0];
            jump = actions.DiscreteActions[1];

            // Every action value maps to an explicit direction; previously action 0
            // silently kept the last direction, an unobserved hidden state.
            if (moveAction == 1) { dirX = -1; }
            else if (moveAction == 2) { dirX = 1; }
            else { dirX = 0; }
            rb.velocity = new Vector2(dirX * MoveSpeed, rb.velocity.y);

            if (IsGrounded() && jump == 1)
            {
                // Zero vertical speed first so every jump has the same height.
                rb.velocity = new Vector2(rb.velocity.x, 0);
                rb.AddForce(new Vector2(0, JumpImpulse), ForceMode2D.Impulse);
            }

            if (rb.velocity.y < 0)
            {
                // Falling: add FallGravityMultiplier x gravity of extra downward velocity.
                rb.velocity += Vector2.up * Physics2D.gravity.y * FallGravityMultiplier * Time.fixedDeltaTime;
            }
            else if (rb.velocity.y > 0 && jump != 1)
            {
                // Rising but jump released: pull down harder for a variable-height (short) hop.
                rb.velocity += Vector2.up * Physics2D.gravity.y * ShortHopGravityMultiplier * Time.fixedDeltaTime;
            }

            distanceToEnd = Vector2.Distance(transform.position, endLevel.position);
            if (distanceToEnd <= GoalRadius && ps.color == endLevel.GetComponent<SpriteRenderer>().color)
            {
                AddReward(GoalReward);
                PlayerPrefs.SetInt("EndReached", PlayerPrefs.GetInt("EndReached") + 1);
                EndEpisode();
                // EndEpisode() has already reset the agent via OnEpisodeBegin(). Stop here
                // so the stale distance below cannot overwrite the new episode's lastDisToEnd.
                return;
            }

            if (distanceToEnd < lastDisToEnd)
            {
                // New closest approach this episode. A progress-shaping reward could be
                // added here; it is currently disabled.
                lastDisToEnd = distanceToEnd;
            }
        }

        /// <summary>
        /// Keyboard control for testing: A/D move left/right, Space jumps.
        /// </summary>
        public override void Heuristic(in ActionBuffers actionsOut)
        {
            var discreteActionsOut = actionsOut.DiscreteActions;

            // Movement branch: 1 = left (A), 2 = right (D), 3 = no input.
            if (Input.GetKey(KeyCode.A)) { movement = 1; }
            else if (Input.GetKey(KeyCode.D)) { movement = 2; }
            else { movement = 3; }
            discreteActionsOut[0] = movement;

            // Jump branch: 1 while Space is held and the player is either grounded (to start a
            // jump) or still rising (so OnActionReceived skips the short-hop gravity); otherwise 2.
            bool spaceHeld = Input.GetKey(KeyCode.Space);
            jump = spaceHeld && (IsGrounded() || rb.velocity.y > 0.01f) ? 1 : 2;
            discreteActionsOut[1] = jump;
        }

        /// <summary>
        /// Records the player's current x position in <c>newPos</c> and then <c>oldPos</c>.
        /// An idle penalty (comparing the two) was tried here and is currently disabled.
        /// </summary>
        public void PosCheck()
        {
            newPos = transform.position.x;
            oldPos = newPos;
        }

        /// <summary>
        /// Returns true if a short downward box cast from the player's collider hits a tilemap
        /// that is not tagged "TransTilemaps" and whose RGB colour exactly equals the player's
        /// sprite colour (alpha is ignored).
        /// </summary>
        public bool IsGrounded()
        {
            // NOTE(review): BoxCastAll's size argument is the FULL box size, but
            // cc.bounds.extents is half the collider's size. Confirm this is intended.
            RaycastHit2D[] hits = Physics2D.BoxCastAll(cc.bounds.center, cc.bounds.extents, 0f, Vector2.down, GroundCheckDistance);

            // Any matching surface counts. The old loop let the LAST hit decide, which made
            // the result depend on hit ordering (including the player's own collider).
            foreach (RaycastHit2D hit in hits)
            {
                if (hit.collider == null) { continue; }

                Tilemap tm = hit.collider.GetComponent<Tilemap>();
                if (tm != null && !tm.gameObject.CompareTag("TransTilemaps") && SameRgb(tm.color, ps.color))
                {
                    return true;
                }
            }
            return false;
        }

        // Exact RGB equality, ignoring alpha (same comparison IsGrounded always used).
        private static bool SameRgb(Color a, Color b)
        {
            return a.r == b.r && a.g == b.g && a.b == b.b;
        }

        /// <summary>
        /// Ends the episode with <see cref="WrongColorPenalty"/> when the player collides with a
        /// tilemap on layer 6 or 7 whose colour differs from the sprite colour.
        /// </summary>
        private void OnCollisionEnter2D(Collision2D collision)
        {
            Tilemap tm = collision.gameObject.GetComponent<Tilemap>();
            if (tm == null) { return; }

            // Presumably layers 6 and 7 are the coloured-tilemap layers — confirm in the Tags & Layers settings.
            int layer = collision.gameObject.layer;
            if (layer != 6 && layer != 7) { return; }

            // Unity's Color != compares all four channels (including alpha) approximately.
            if (tm.color != ps.color)
            {
                AddReward(WrongColorPenalty);
                EndEpisode();
            }
        }
    }
     
  2. All_American

    All_American

    Joined:
    Oct 14, 2011
    Posts:
    1,528
    Try increasing `hidden_units` to 256 and `num_layers` to 3, and set `normalize` to true.

    Also, I see you’re only training for 2–3 million steps. It’s going to do that… it will explore everything you’re giving it the ability to observe, to make sure it covers all possible outcomes and learns them all, so that it also knows what not to do. I think….
     
  3. Luke-Houlihan

    Luke-Houlihan

    Joined:
    Jun 26, 2007
    Posts:
    303
    Seconding @All_American on normalizing: large drawdowns in reward are usually caused by exploding gradients (e.g. from observation values outside [-1, 1]). You need to normalize all observations yourself or set `normalize` to true in the config. It may also be a too-aggressive learning rate; you can try lowering it and training for longer.