Question: Agent not learning, but mean reward is working

Discussion in 'ML-Agents' started by diwar496, Jun 26, 2023.

  1. diwar496

    Joined:
    Apr 3, 2015
    Posts:
    2
    Hi, I'm trying to use ML-Agents to train an agent (blue cube) to do the following:
    - reach the goal (green cube). Reward: +1f
    - avoid the enemy (red cube). Reward: -1f (I also tried -0.025f, -0.5, ...)

    The problem is that the mean reward is going up (as you can see in the screenshots), but when I assign the generated .onnx brain to the agent, it seems not to have learned: it either goes straight to the green cube without avoiding the red one, or heads directly for the red cube.

    What am I doing wrong? Here's the agent code and the screenshots:

    Code (CSharp):
    using UnityEngine;
    using Unity.MLAgents;
    using Unity.MLAgents.Actuators;
    using Unity.MLAgents.Sensors;
    using static UnityEngine.GraphicsBuffer;
    using UnityEngine.ProBuilder.Shapes;
    using Unity.VisualScripting;
    using Unity.MLAgents.Integrations.Match3;

    public class Player : Agent
    {
        //Agent
        public float moveSpeed = 5f;
        public float rotationSpeed = 150f;
        private Rigidbody rb;
        private Vector3 startingPos;
        private Quaternion startingRot;

        //Enemy and Goal
        public GameObject enemy;
        public GameObject goal;

        //Timer
        public float restartTimerLimit = 10f;
        private float restartTimer;

        public override void Initialize()
        {
            rb = GetComponent<Rigidbody>();
            startingPos = transform.position + new Vector3(0, 0, Random.Range(-2f, 2f));
            startingRot = transform.rotation;
        }
        public override void OnEpisodeBegin()
        {
            transform.position = startingPos + new Vector3(0, 0, Random.Range(-2f, 2f));
            transform.rotation = startingRot;
            rb.velocity = Vector3.zero;
            restartTimer = 0;
        }
        public override void CollectObservations(VectorSensor sensor)
        {
            //Agent position
            sensor.AddObservation(transform.localPosition);

            //Goal position
            sensor.AddObservation(goal.transform.localPosition);
        }
        public override void OnActionReceived(ActionBuffers actions)
        {
            var input = actions.DiscreteActions;

            if (input[0] == 0)
            {
                //don't move
            }
            else
            {
                MoveCube(input[0]);
            }
            if (input[1] == 0)
            {
                //don't rotate
            }
            else if (input[1] == -1 || input[1] == 1)
            {
                RotateCube(input[1]);
            }
        }
        public override void Heuristic(in ActionBuffers actionsOut)
        {
            var action = actionsOut.DiscreteActions;
            action[0] = 0; //don't move
            action[1] = 0; //don't rotate

            if (Input.GetKey(KeyCode.W))
                action[0] = 1;

            if (Input.GetKey(KeyCode.A))
                action[1] = -1;
            else if (Input.GetKey(KeyCode.D))
                action[1] = 1;
        }
        private void FixedUpdate()
        {
            restartTimer += Time.deltaTime;
            if (restartTimer > restartTimerLimit) { restartTimer = 0; EndEpisode(); }
        }
        void MoveCube(float moveVertical)
        {
            Vector3 movement = transform.forward * moveVertical * moveSpeed * Time.deltaTime;
            rb.MovePosition(rb.position + movement);
        }

        void RotateCube(float rotateHorizontal)
        {
            Quaternion deltaRotation = Quaternion.Euler(Vector3.up * rotateHorizontal * rotationSpeed * Time.deltaTime);
            rb.MoveRotation(rb.rotation * deltaRotation);
        }
        private void OnCollisionEnter(Collision collision)
        {
            if (collision.gameObject.CompareTag("enemy"))
            {
                RemovePoints(1f);
            }
            else if (collision.gameObject.CompareTag("goal"))
            {
                GivePoints(1f);
            }
        }
        private void GivePoints(float pointsToGive)
        {
            AddReward(pointsToGive);
            EndEpisode();
        }
        private void RemovePoints(float pointsToRemove)
        {
            AddReward(-pointsToRemove);
            EndEpisode();
        }
    }
    Attached screenshots: 1.png, 2.png, 3.png
     
  2. Energymover

    Joined:
    Mar 28, 2023
    Posts:
    33
  3. diwar496

    Joined:
    Apr 3, 2015
    Posts:
    2
    Thank you for the reply.
    • Everything is on the default layer, and the enemy and goal are tagged.
    • Tried that, and nothing changed.
    • Tried stacked raycasts for short-term memory, and nothing changed.
    I really don't know what else to try...
     
  4. Luke-Houlihan

    Joined:
    Jun 26, 2007
    Posts:
    303
    I don't see the logic for the enemy cube; I'm guessing from your results that it is randomly placed. It seems the random placement doesn't intersect the agent's direct path to the goal very often, so the agent learns to ignore it entirely.

    A quick change would be to sample random spawns until you find one that intersects the agent's line of sight to the goal, which would always place the enemy somewhere between the agent and the goal. I would expect that to teach the agent to avoid the enemy as well as seek the goal.
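
    Something along these lines, as a rough, untested sketch (the helper name and the arena bounds are placeholders; you'd call it from OnEpisodeBegin after the agent has been reset):

    Code (CSharp):
    // Hypothetical helper: resample a random spawn until it lies close to the
    // straight line from the agent to the goal, then move the enemy there.
    private void PlaceEnemyOnPathToGoal(float maxOffsetFromPath = 0.5f, int maxTries = 20)
    {
        Vector3 toGoal = goal.transform.position - transform.position;

        for (int i = 0; i < maxTries; i++)
        {
            // Candidate spawn somewhere in the arena (the +/-4 bounds are an assumption).
            Vector3 candidate = new Vector3(Random.Range(-4f, 4f),
                                            enemy.transform.position.y,
                                            Random.Range(-4f, 4f));

            // Distance from the candidate to the agent->goal segment.
            Vector3 toCandidate = candidate - transform.position;
            float t = Mathf.Clamp01(Vector3.Dot(toCandidate, toGoal) / toGoal.sqrMagnitude);
            Vector3 closestPointOnPath = transform.position + t * toGoal;

            if (Vector3.Distance(candidate, closestPointOnPath) <= maxOffsetFromPath)
            {
                enemy.transform.position = candidate;
                return;
            }
        }
        // If no candidate was close enough, the enemy keeps its previous position.
    }

    That way the -1 penalty actually shows up on the agent's path during training, instead of only on the rare episodes where the random spawn happens to be in the way.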
     
  5. Luke-Houlihan

    Joined:
    Jun 26, 2007
    Posts:
    303
    I also see that you are encoding the goal position directly in the observations but not the enemy's position. As a sanity check, you can encode the enemy position directly, the same way you do the goal position.

    If that change makes a difference, it means your network isn't large enough to handle the complexity of the raycast observations on top of the two competing tasks you've set for the agent (move toward the goal, move away from the enemy).
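
    The sanity check is just one extra AddObservation call (rough sketch below); you'd also need to bump the vector observation Space Size in Behavior Parameters from 6 to 9, since each Vector3 observation is three floats.

    Code (CSharp):
    public override void CollectObservations(VectorSensor sensor)
    {
        //Agent position
        sensor.AddObservation(transform.localPosition);

        //Goal position
        sensor.AddObservation(goal.transform.localPosition);

        //Enemy position, added as the sanity check described above
        sensor.AddObservation(enemy.transform.localPosition);
    }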
     
  6. smallg2023

    Joined:
    Sep 2, 2018
    Posts:
    130
    Not your issue, but discrete actions give 0, 1, 2, 3, etc., not -1, 0, 1, so your agent will never be able to rotate left.
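
    One way to handle it, assuming you give the rotation branch a size of 3 in your Behavior Parameters, is to map the indices yourself, something like:

    Code (CSharp):
    // In OnActionReceived: rotation branch of size 3, where 0 = don't rotate, 1 = left, 2 = right.
    int rotateAction = actions.DiscreteActions[1];
    if (rotateAction == 1)
        RotateCube(-1f);
    else if (rotateAction == 2)
        RotateCube(1f);

    // And the matching Heuristic mapping:
    if (Input.GetKey(KeyCode.A))
        action[1] = 1;      // rotate left
    else if (Input.GetKey(KeyCode.D))
        action[1] = 2;      // rotate right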