Question: Car racing Agent only learns how to turn left

Discussion in 'ML-Agents' started by NDani0209, Apr 19, 2024.

    NDani0209 (Joined: Mar 6, 2024; Posts: 1)

    Hi,

    I'm trying to train a car-driving agent to go around a simple track. My current goal is for the Agent to tell the difference between a right turn and a left turn (90° turns). The problem is that the Agent always learns to turn left much faster, and after that it doesn't even try to make the right turns.

    The interesting part is that if I train the Agent only on the track that contains only right turns, it learns them without a problem. But if I add the other track as well, the Agent eventually learns to turn left and forgets how to turn right.

    Here is a screenshot of the environment from above:
    [Image: Env.png]

    Each track has 8 car Agents.

    For observations I use RayPerceptionSensor3D components to detect the walls, plus 5 vector observations:
    • The magnitude of the velocity vector
    • The dot product of the car's forward vector and the next checkpoint's forward vector
    • The dot product of the car's cross vector (the cross product of forward and up, which points to the car's left in Unity) and the next checkpoint's forward vector
    • The previous two, but for the checkpoint after the next one
    My idea was that the dot products should let the Agent identify what kind of turn is coming. I use dot products instead of the raw vectors because a dot product does not depend on how the track is oriented in world space, so every right turn looks the same to the Agent. I thought this would be much easier to learn: if the dot product with the cross vector is negative, the Agent should turn right; if it is positive, it should turn left.
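A minimal plain-.NET sketch of the turn signal described above, to check the sign convention. It uses System.Numerics.Vector3 as a stand-in for UnityEngine.Vector3 (both use the same cross-product formula) and assumes Unity's axis convention: +X right, +Y up, +Z forward. It shows that Cross(forward, up), as computed in the Agent code, points to the car's left, so a checkpoint facing right gives a negative value, and that rotating the car and checkpoint together leaves both values unchanged. `TurnSignals` and `CheckpointAhead` are hypothetical helpers; `CheckpointAhead` assumes a looping track with a known checkpoint count, for the "checkpoint after the next one" lookahead.

```csharp
using System;
using System.Numerics;

// The two values the Agent adds per checkpoint: alignment of the checkpoint's forward
// vector with the car's forward axis and with Cross(forward, up).
static (float Along, float Side) TurnSignals(Vector3 carForward, Vector3 carUp, Vector3 checkpointForward)
{
    // For forward = +Z and up = +Y this is (-1, 0, 0): the car's left in Unity's convention.
    Vector3 crossVec = Vector3.Cross(carForward, carUp);
    return (Vector3.Dot(checkpointForward, carForward), Vector3.Dot(checkpointForward, crossVec));
}

// Hypothetical helper: index of the checkpoint `offset` places ahead on a looping track.
static int CheckpointAhead(int next, int offset, int count) => (next + offset) % count;

// Assumed Unity axes: +X right, +Y up, +Z forward.
Vector3 forward = Vector3.UnitZ;
Vector3 up = Vector3.UnitY;
Vector3 rightTurn = Vector3.UnitX;   // next checkpoint faces the car's right
Vector3 leftTurn = -Vector3.UnitX;   // next checkpoint faces the car's left

var right = TurnSignals(forward, up, rightTurn);
var left = TurnSignals(forward, up, leftTurn);
Console.WriteLine($"right turn: side = {right.Side}"); // negative
Console.WriteLine($"left turn:  side = {left.Side}");  // positive

// Rotate car and checkpoint together by 90 degrees about the up axis:
// the dot products do not change, which is why they were chosen as observations.
Quaternion yaw = Quaternion.CreateFromAxisAngle(up, MathF.PI / 2f);
var rotated = TurnSignals(Vector3.Transform(forward, yaw), Vector3.Transform(up, yaw),
                          Vector3.Transform(rightTurn, yaw));
Console.WriteLine($"rotated right turn: side = {rotated.Side:F3}");

// With 8 checkpoints, the one after checkpoint 7 wraps back to checkpoint 0.
Console.WriteLine(CheckpointAhead(7, 1, 8));
```

The wrap-around only matters if `Track.GetCheckPointTransform(nextCheckPoint + 1)` does not already handle the last checkpoint of the lap; TrackScript is not shown, so that is an open assumption.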

    The rest of the project is quite basic, but here is the source code of the Agent:

    Code (CSharp):
    using UnityEngine;
    using Unity.MLAgents;
    using Unity.MLAgents.Actuators;
    using Unity.MLAgents.Sensors;

    public class Player : Agent
    {
        private CarController carController;
        [SerializeField] private Rigidbody carRB;
        [SerializeField] private Transform carTransform;
        [SerializeField] private TrackScript Track;
        [SerializeField] private InputController inputController;

        // Index of the checkpoint the car has to reach next.
        private int nextCheckPoint = 0;

        public override void Initialize()
        {
            carController = GetComponent<CarController>();
        }

        public override void OnEpisodeBegin()
        {
            // Put the car back at the start of the track and stop it.
            transform.localPosition = new Vector3(0f, 0.5f, -5f);
            transform.localRotation = Quaternion.Euler(0f, 180f, 0f);
            carRB.velocity = Vector3.zero;
            carRB.angularVelocity = Vector3.zero; // also clear spin left over from the previous episode
            nextCheckPoint = 0;
        }

        public override void CollectObservations(VectorSensor sensor)
        {
            // Cross(forward, up) points to the car's LEFT in Unity's left-handed coordinates,
            // so a checkpoint facing to the right gives a negative dot product.
            Vector3 crossVec = Vector3.Cross(carTransform.forward, carTransform.up);

            // 1 observation: speed.
            sensor.AddObservation(carRB.velocity.magnitude);

            // 2 observations: next checkpoint relative to the car's forward and left axes.
            Vector3 checkPointForward = Track.GetCheckPointTransform(nextCheckPoint).forward;
            sensor.AddObservation(Vector3.Dot(checkPointForward, carTransform.forward));
            sensor.AddObservation(Vector3.Dot(checkPointForward, crossVec));

            // 2 observations: the checkpoint after that. Unless TrackScript wraps indices,
            // nextCheckPoint + 1 runs past the end when the next checkpoint is the last one.
            checkPointForward = Track.GetCheckPointTransform(nextCheckPoint + 1).forward;
            sensor.AddObservation(Vector3.Dot(checkPointForward, carTransform.forward));
            sensor.AddObservation(Vector3.Dot(checkPointForward, crossVec));
        }

        public override void OnActionReceived(ActionBuffers actions)
        {
            // Time penalty that sums to -10 over a full-length episode.
            // Guarded because MaxStep = 0 (no limit) would make -10f / MaxStep equal to -Infinity.
            if (MaxStep > 0)
            {
                AddReward(-10f / MaxStep);
            }

            // Penalize facing away from the next checkpoint (both forward vectors are unit length).
            Vector3 checkPointForward = Track.GetCheckPointTransform(nextCheckPoint).forward;
            if (Vector3.Dot(checkPointForward, carTransform.forward) < 0f)
            {
                AddReward(-0.1f);
            }

            carController.Throttle = actions.ContinuousActions[0];
            carController.Steer = actions.ContinuousActions[1];
        }

        public override void Heuristic(in ActionBuffers actionsOut)
        {
            ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
            continuousActions[0] = inputController.ThrottleInput;
            continuousActions[1] = inputController.SteerInput;
        }

        private void OnCollisionEnter(Collision collision)
        {
            // One-off penalty when the car hits something.
            AddReward(-1f);
        }

        private void OnCollisionStay(Collision collision)
        {
            // Repeated penalty, applied on every physics step while the car stays in contact.
            AddReward(-0.05f);
        }

        private void OnTriggerEnter(Collider other)
        {
            if (other.CompareTag("CheckPoint"))
            {
                // TrackScript (not shown) returns the new target checkpoint index.
                int newID = Track.CrossCheckPoint(nextCheckPoint, other.transform);
                if (newID != nextCheckPoint)
                {
                    nextCheckPoint = newID;
                    AddReward(2f);
                }
            }
        }
    }

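To see how the reward terms in the Agent code trade off against each other, here is a small plain-C# tally of one episode's return. `EpisodeReturn`, the Max Step of 5000 and the example episodes are hypothetical (the post does not give Max Step or the checkpoint count). It also assumes one agent step per physics step, so `OnCollisionStay` adds at most one penalty per step for a single wall contact.

```csharp
using System;

// Hypothetical tally of one episode's return, term by term, from the Player agent's rewards.
// Assumes one agent step per physics step.
static float EpisodeReturn(int maxStep, int steps, int checkpoints,
                           int wallHits, int wallContactSteps, int wrongWaySteps)
{
    float time = -10f / maxStep * steps;                      // AddReward(-10f / MaxStep) every step
    float wrongWay = -0.1f * wrongWaySteps;                   // facing away from the next checkpoint
    float walls = -1f * wallHits - 0.05f * wallContactSteps;  // OnCollisionEnter + OnCollisionStay
    float progress = 2f * checkpoints;                        // +2 per newly reached checkpoint
    return time + wrongWay + walls + progress;
}

const int MaxStep = 5000; // hypothetical: the post does not state the agent's Max Step

// A full-length episode always collects -10 from the time penalty, whatever MaxStep is.
Console.WriteLine(EpisodeReturn(MaxStep, MaxStep, 0, 0, 0, 0));

// Hypothetical clean run: 12 checkpoints in 1000 steps, no wall contact (about 22).
Console.WriteLine(EpisodeReturn(MaxStep, 1000, 12, 0, 0, 0));

// Hypothetical crash: 3 checkpoints, then pressed against a wall for the remaining 4700 steps.
// The agent code shown never calls EndEpisode, so the episode runs to MaxStep (about -240).
Console.WriteLine(EpisodeReturn(MaxStep, MaxStep, 3, 1, 4700, 0));
```

With these assumed numbers the `OnCollisionStay` term dominates once the car is stuck, while the time penalty is bounded at -10 per episode.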
    And here is the config file:

    behaviors:
      DriveCar:
        trainer_type: ppo
        hyperparameters:
          batch_size: 2048
          buffer_size: 40960
          learning_rate: 0.0003
          beta: 0.003
          epsilon: 0.2
          lambd: 0.95
          num_epoch: 4
          learning_rate_schedule: linear
        network_settings:
          hidden_units: 512
          num_layers: 4
          vis_encode_type: simple
        reward_signals:
          extrinsic:
            gamma: 0.995
            strength: 1.0
        keep_checkpoints: 10
        checkpoint_interval: 1000000
        max_steps: 50000000
        time_horizon: 512
        summary_freq: 50000
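Some arithmetic on this config, as a plain C# sketch. The hyperparameter values are copied from the config; the agent count of 16 (two tracks x 8 agents, one environment instance) is an assumption. In ML-Agents PPO, buffer_size experiences are collected per update, split into minibatches of batch_size, and the buffer is passed over num_epoch times. The 1 / (1 - gamma) "effective horizon" is a common rule of thumb, not something ML-Agents reports, and `LinearLearningRate` follows the documented behavior of the linear schedule (decay to 0 at max_steps) while ignoring the tiny floor value used internally.

```csharp
using System;

// Values copied from the DriveCar config.
const int BatchSize = 2048;
const int BufferSize = 40960;
const int NumEpoch = 4;
const float LearningRate = 0.0003f;
const float Gamma = 0.995f;
const long MaxSteps = 50_000_000;

// Assumption: 2 tracks x 8 agents in a single environment instance.
const int NumAgents = 16;

int minibatchesPerEpoch = BufferSize / BatchSize;            // 20
int gradientStepsPerUpdate = minibatchesPerEpoch * NumEpoch; // 80
int stepsPerAgentPerUpdate = BufferSize / NumAgents;         // 2560
float effectiveHorizon = 1f / (1f - Gamma);                  // roughly 200 steps

// "linear" schedule: learning_rate decays linearly, reaching 0 at max_steps.
static float LinearLearningRate(float lr0, long step, long maxSteps) =>
    lr0 * Math.Max(0f, 1f - (float)step / maxSteps);

Console.WriteLine($"minibatches per epoch:      {minibatchesPerEpoch}");
Console.WriteLine($"gradient steps per update:  {gradientStepsPerUpdate}");
Console.WriteLine($"steps per agent per update: {stepsPerAgentPerUpdate}");
Console.WriteLine($"effective horizon (steps):  {effectiveHorizon:F1}");
Console.WriteLine($"learning rate at 25M steps: {LinearLearningRate(LearningRate, 25_000_000, MaxSteps)}");
```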

    If you have any idea why this is happening, I would be grateful if you shared it with me.
    I checked, and all the checkpoints face the correct way, so that's not the problem.