
Question: Agent keeps running out of bounds, subsequently not learning [Video included]

Discussion in 'ML-Agents' started by sunnyCallum, Dec 14, 2023.

  1. sunnyCallum

    Joined:
    Nov 6, 2021
    Posts:
    8
    Hi guys, I was hoping someone could help me with my ML-Agents training, which doesn't seem to be working. A little while ago this training was working completely fine; I can't remember what I changed, but I came back to it after a week and it had broken, and I am baffled as to what happened without me touching it. The agent seems to run out of bounds for no reason, even though it gets punished for doing so.

    Below is a video of what happens to the agent when I begin training. I have tried recreating the project from scratch on my MacBook in the hope that the issue was with my Windows PC, but I am still stuck. I have tried changing multiple things within the script, such as the rewards, as well as the layer ordering of the GameObjects, but none of that worked and I am out of ideas.

    I should add that I followed the Getting Started guide in the ML-Agents GitHub repo and altered it to suit my needs, so this script is pretty much the same as the one in that guide aside from some extra logic.

    If anyone could help me, I would greatly appreciate it. I am using this for a final-year university project, but I am slowly losing the motivation to continue.



    Code (CSharp):
    using System.Collections;
    using System.Collections.Generic;
    using UnityEngine;

    using Unity.MLAgents;
    using Unity.MLAgents.Sensors;
    using Unity.MLAgents.Actuators;

    public class Cat_Agent : Agent
    {
        // GameObjects

        public Collider2D trainingArea; // The training area
        public Transform Target; // The target
        public Collider2D targetCollider; // The collider of the target

        // Physics
        Rigidbody2D rBody; // The rigidbody of the agent
        public float forceMultiplier = 10f; // The force multiplier for the agent

        // Parameters
        [SerializeField] public float funNeed; // The fun need of the agent
        private float regenRate = 5f; // The rate at which the fun need regenerates

        // Start is called before the first frame update
        void Start()
        {
            rBody = GetComponent<Rigidbody2D>(); // Get the rigidbody of the agent
        }

        // Called when the agent is reset
        public override void OnEpisodeBegin()
        {
            // Reset the target's position
            SetRandomTarget();

            // Reset the agent's position
            SetRandomAgent();

            // Reset the agent's velocity
            rBody.velocity = Vector2.zero;

            // Reset the fun need
            funNeed = 25f;
        }

        // Called when the agent requests a decision
        public override void CollectObservations(VectorSensor sensor)
        {
            // Target and Agent positions
            sensor.AddObservation(Target.localPosition);
            sensor.AddObservation(this.transform.localPosition);

            // Agent velocity
            sensor.AddObservation(rBody.velocity.x); // Horizontal velocity
            sensor.AddObservation(rBody.velocity.y); // Vertical velocity

            // Fun need
            sensor.AddObservation(funNeed / 100f); // Normalise to range of 0, 1
        }

        // Called when the agent requests an action
        public override void OnActionReceived(ActionBuffers actionBuffers)
        {
            Vector2 controlSignal = Vector2.zero; // The control signal for the agent
            float distanceToTarget = Vector2.Distance(this.transform.localPosition, Target.localPosition); // The distance to the target

            // Get the action from the action buffer
            controlSignal.x = actionBuffers.ContinuousActions[0]; // Horizontal force
            controlSignal.y = actionBuffers.ContinuousActions[1]; // Vertical force

            // Apply the control signal as force to move the agent (as in the Getting Started guide)
            rBody.AddForce(controlSignal * forceMultiplier);

            // Check if the agent is outside of the training area
            if (!trainingArea.bounds.Contains(this.transform.localPosition))
            {
                // Punish the agent
                AddReward(-1f);
                EndEpisode();
            }

            // Reward the agent based on the fun need
            // This is to incentivise the agent to keep the fun need high in order to gain more rewards
            AddReward(funNeed / 100f);

            // Reward the agent based on the distance to the target
            // This is to incentivise the agent to get closer to the target
            AddReward(1f / distanceToTarget);

            // Check if the agent is within the target's collider
            // TODO: Update this so the agent "uses" the object instead of sitting inside of it
            if (targetCollider.bounds.Contains(this.transform.localPosition))
            {
                funNeed += regenRate * Time.deltaTime; // Regenerate the fun need
                AddReward(0.1f); // Reward the agent
                funNeed = Mathf.Clamp(funNeed, 0f, 100f); // Clamp the fun need
            }

            // When the fun need is satisfied, end the episode
            if (funNeed >= 100f)
            {
                AddReward(2f);
                EndEpisode();
            }
        }

        public override void Heuristic(in ActionBuffers actionsOut)
        {
            // Get the action buffer
            var continuousActions = actionsOut.ContinuousActions;

            // Set the action buffer (action 0 = horizontal, action 1 = vertical, matching OnActionReceived)
            continuousActions[0] = Input.GetAxisRaw("Horizontal");
            continuousActions[1] = Input.GetAxisRaw("Vertical");
        }

        public void SetRandomTarget() // Set a random position for the target
        {
            // Set the target's position to a random position within the training area
            Vector2 position = new Vector2(Random.Range(trainingArea.bounds.min.x, trainingArea.bounds.max.x), Random.Range(trainingArea.bounds.min.y, trainingArea.bounds.max.y));
            Target.position = position;
        }

        public void SetRandomAgent() // Set a random position for the agent
        {
            // Set the agent's position to a random position within the training area
            Vector2 position = new Vector2(Random.Range(trainingArea.bounds.min.x, trainingArea.bounds.max.x), Random.Range(trainingArea.bounds.min.y, trainingArea.bounds.max.y));
            transform.position = position;
        }
    }
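    One thing I am genuinely unsure about, and it may be unrelated: Collider2D.bounds is in world space, while I am checking it against transform.localPosition, so the two only line up if the training area's parent sits at the origin with no scaling or rotation. Below is a small sketch of the same out-of-bounds check done entirely in world space, just for comparison; I have not confirmed that this is the cause.

    Code (CSharp):
    // Sketch only: out-of-bounds check with both sides in world space.
    // Assumes trainingArea is the same Collider2D referenced in the script above.
    if (!trainingArea.bounds.Contains(transform.position))
    {
        AddReward(-1f); // Same punishment as before
        EndEpisode();
    }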
    Code (csharp):
    behaviors:
      CatAgent:
        trainer_type: ppo
        hyperparameters:
          batch_size: 10
          buffer_size: 100
          learning_rate: 3.0e-4
          beta: 5.0e-4
          epsilon: 0.2
          lambd: 0.99
          num_epoch: 3
          learning_rate_schedule: linear
          beta_schedule: constant
          epsilon_schedule: linear
        network_settings:
          normalize: false
          hidden_units: 128
          num_layers: 2
        reward_signals:
          extrinsic:
            gamma: 0.99
            strength: 1.0
        max_steps: 500000
        time_horizon: 64
        summary_freq: 10000
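    For reference, I start training with the standard CLI and then press Play in the Editor, something along the lines of mlagents-learn config/CatAgent.yaml --run-id=CatAgent (the config path and run id here are just placeholders for whatever I actually use).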
     
    Last edited: Dec 14, 2023