Question Car Controller with Wheel Collider using ML Agents and Inference Learning not Learning at all.

Discussion in 'ML-Agents' started by PHOENIX05102000, Jun 17, 2023.

  1. PHOENIX05102000

    PHOENIX05102000

    Joined:
    Oct 14, 2022
    Posts:
    16
    Hi
    I am trying to train a car with ML-Agents and Inference Learning to drive around a track using a basic checkpoint system. The track has walls on both sides to keep the AI on the track. The car uses Wheel Colliders for realistic physics. Even after 100M steps, the car has still not learned to drive around the track. Any help would be greatly appreciated.

    Below is the code for the Car Controller, Car AI, Track Checkpoint System, and individual Checkpoints:

    CAR CONTROLLER:

    Code (CSharp):
    using UnityEngine;

    public class CarController : MonoBehaviour
    {
        private const string HORIZONTAL_MOTION = "Horizontal";
        private const string VERTICAL_MOTION = "Vertical";

        // Steering (horizontal) and throttle (vertical) inputs in [-1, 1].
        // Written either by the keyboard in GetMovementInput() or by CarControllerAI.
        public float horizontalInput;
        public float verticalInput;

        // Assumption: this switch was not in the original script. Leave it on for a
        // player-driven car; turn it off in the Inspector for an agent-driven car so
        // FixedUpdate() does not overwrite the agent's inputs with keyboard values.
        [SerializeField] private bool useKeyboardInput = true;

        private float currentBrakingForce;
        private float currentSteeringAngle;

        [SerializeField] private float vehicleMotorForce;
        [SerializeField] private float vehicleBrakingForce;
        [SerializeField] private float vehicleSteeringAngle;

        [SerializeField] private WheelCollider FrontRightWheelCollider;
        [SerializeField] private WheelCollider FrontLeftWheelCollider;
        [SerializeField] private WheelCollider RearRightWheelCollider;
        [SerializeField] private WheelCollider RearLeftWheelCollider;

        [SerializeField] private Transform FrontRightWheelTransform;
        [SerializeField] private Transform FrontLeftWheelTransform;
        [SerializeField] private Transform RearRightWheelTransform;
        [SerializeField] private Transform RearLeftWheelTransform;

        private void FixedUpdate()
        {
            GetMovementInput();
            VehicleMotorHandling();
            VehicleSteeringHandling();
            VehicleWheelAnimationUpdate();
        }

        public void GetMovementInput()
        {
            // Skip the keyboard when something else (e.g. an Agent) supplies the inputs.
            if (!useKeyboardInput) return;

            horizontalInput = Input.GetAxis(HORIZONTAL_MOTION);
            verticalInput = Input.GetAxis(VERTICAL_MOTION);
        }

        private void VehicleMotorHandling()
        {
            // Front-wheel drive: only the front wheels receive motor torque.
            FrontLeftWheelCollider.motorTorque = verticalInput * vehicleMotorForce;
            FrontRightWheelCollider.motorTorque = verticalInput * vehicleMotorForce;

            // Brake on all four wheels while Space is held.
            currentBrakingForce = Input.GetKey(KeyCode.Space) ? vehicleBrakingForce : 0f;
            ApplyBrakingForceToVehicle();
        }

        private void ApplyBrakingForceToVehicle()
        {
            FrontLeftWheelCollider.brakeTorque = currentBrakingForce;
            FrontRightWheelCollider.brakeTorque = currentBrakingForce;
            RearLeftWheelCollider.brakeTorque = currentBrakingForce;
            RearRightWheelCollider.brakeTorque = currentBrakingForce;
        }

        public void StopVehicleCompletely()
        {
            // Zero the rigidbody's motion so a new episode starts from rest.
            Rigidbody body = GetComponent<Rigidbody>();
            body.velocity = Vector3.zero;
            body.angularVelocity = Vector3.zero;
        }

        private void VehicleSteeringHandling()
        {
            currentSteeringAngle = vehicleSteeringAngle * horizontalInput;
            FrontLeftWheelCollider.steerAngle = currentSteeringAngle;
            FrontRightWheelCollider.steerAngle = currentSteeringAngle;
        }

        private void VehicleWheelAnimationUpdate()
        {
            UpdateWheelOrientation(FrontLeftWheelCollider, FrontLeftWheelTransform);
            UpdateWheelOrientation(FrontRightWheelCollider, FrontRightWheelTransform);
            UpdateWheelOrientation(RearLeftWheelCollider, RearLeftWheelTransform);
            UpdateWheelOrientation(RearRightWheelCollider, RearRightWheelTransform);
        }

        // Copies the simulated wheel pose onto the visible wheel mesh.
        private void UpdateWheelOrientation(WheelCollider wheelCollider, Transform wheelTransform)
        {
            wheelCollider.GetWorldPose(out Vector3 wheelPosition, out Quaternion wheelRotation);
            wheelTransform.position = wheelPosition;
            wheelTransform.rotation = wheelRotation;
        }
    }

    CAR AI:

    Code (CSharp):
    using UnityEngine;
    using Unity.MLAgents;
    using Unity.MLAgents.Actuators;
    using Unity.MLAgents.Sensors;

    public class CarControllerAI : Agent
    {
        [SerializeField] private TrackCheckpointManager trackCheckpointManager;
        [SerializeField] private Transform carSpawnPoint;

        [SerializeField] private CarController carController;

        private void Awake()
        {
            carController = GetComponent<CarController>();
        }

        private void Start()
        {
            trackCheckpointManager.OnPlayerCorrectCheckpoint += CheckpointTracker_OnCorrectCheckpoint;
            trackCheckpointManager.OnPlayerWrongCheckpoint += CheckpointTracker_OnWrongCheckpoint;
        }

        private void CheckpointTracker_OnCorrectCheckpoint(object Sender, TrackCheckpointManager.CarCheckPointEventArgs e)
        {
            // The manager raises the event for every car; react only to our own.
            if (e.carTransform == transform)
            {
                AddReward(1f);
            }
        }

        private void CheckpointTracker_OnWrongCheckpoint(object Sender, TrackCheckpointManager.CarCheckPointEventArgs e)
        {
            if (e.carTransform == transform)
            {
                AddReward(-1f);
            }
        }

        public override void OnEpisodeBegin()
        {
            // Respawn near the spawn point with a small random offset.
            transform.position = carSpawnPoint.position + new Vector3(Random.Range(-5f, 5f), 0f, Random.Range(-5f, 5f));
            // Face the same way as the spawn point.
            transform.forward = carSpawnPoint.forward;
            trackCheckpointManager.ResetCarChecpoint(transform);
            carController.StopVehicleCompletely();
        }

        public override void CollectObservations(VectorSensor sensor)
        {
            // Single observation: alignment between the car's heading and the next
            // checkpoint's forward direction (1 = aligned, -1 = opposite).
            Vector3 nextCheckpointForward = trackCheckpointManager.GetNextCheckpoint(transform).transform.forward;
            float directionDot = Vector3.Dot(transform.forward, nextCheckpointForward);
            sensor.AddObservation(directionDot);
        }

        public override void OnActionReceived(ActionBuffers actions)
        {
            float forwardMovementAmount = 0f;
            float steeringAngleAmount = 0f;

            // Branch 0: throttle (0 = coast, 1 = forward, 2 = reverse).
            switch (actions.DiscreteActions[0])
            {
                case 0: forwardMovementAmount = 0f; break;
                case 1: forwardMovementAmount = 1f; break;
                case 2: forwardMovementAmount = -1f; break;
            }

            // Branch 1: steering (0 = straight, 1 = right, 2 = left).
            switch (actions.DiscreteActions[1])
            {
                case 0: steeringAngleAmount = 0f; break;
                case 1: steeringAngleAmount = 1f; break;
                case 2: steeringAngleAmount = -1f; break;
            }

            // CarController drives the motor from verticalInput and the steering from
            // horizontalInput, so throttle goes to vertical and steering to horizontal.
            // GetMovementInput() is deliberately not called here: it reads the keyboard
            // axes and would replace these values. For the same reason, CarController
            // must not overwrite them from the keyboard in FixedUpdate(); otherwise the
            // policy's actions never reach the wheels.
            carController.verticalInput = forwardMovementAmount;
            carController.horizontalInput = steeringAngleAmount;
        }

        public override void Heuristic(in ActionBuffers actionsOut)
        {
            int forwardMovementAction = 0;
            if (Input.GetKey(KeyCode.W)) forwardMovementAction = 1;
            if (Input.GetKey(KeyCode.S)) forwardMovementAction = 2;

            int steeringMovementAction = 0;
            if (Input.GetKey(KeyCode.D)) steeringMovementAction = 1;
            if (Input.GetKey(KeyCode.A)) steeringMovementAction = 2;

            ActionSegment<int> discreteActions = actionsOut.DiscreteActions;
            discreteActions[0] = forwardMovementAction;
            discreteActions[1] = steeringMovementAction;
        }

        private void OnCollisionEnter(Collision collision)
        {
            // One-off penalty for hitting a wall.
            if (collision.gameObject.CompareTag("Wall"))
            {
                AddReward(-0.5f);
            }
        }

        private void OnCollisionStay(Collision collision)
        {
            // Repeated penalty, applied on every physics step spent touching a wall.
            if (collision.gameObject.CompareTag("Wall"))
            {
                AddReward(-0.1f);
            }
        }
    }

    TRACK CHECKPOINT SYSTEM:

    Code (CSharp):
    using System;
    using System.Collections.Generic;
    using UnityEngine;

    public class TrackCheckpointManager : MonoBehaviour
    {
        public event EventHandler<CarCheckPointEventArgs> OnPlayerCorrectCheckpoint;
        public event EventHandler<CarCheckPointEventArgs> OnPlayerWrongCheckpoint;

        [SerializeField] private List<Transform> RacersTransformList;
        private List<Checkpoint> checkpointList;
        private List<int> nextCheckpointIndexList;

        private void Awake()
        {
            Transform checkpointsRoot = transform.Find("Checkpoints");

            // Checkpoints are ordered by their position in the hierarchy.
            checkpointList = new List<Checkpoint>();
            foreach (Transform checkpointTransform in checkpointsRoot)
            {
                Checkpoint checkpoint = checkpointTransform.GetComponent<Checkpoint>();
                checkpoint.SetCurrentCheckpoint(this);
                checkpointList.Add(checkpoint);
            }

            // Built once, after the checkpoint loop: every racer starts by heading
            // for checkpoint 0.
            nextCheckpointIndexList = new List<int>();
            foreach (Transform racerTransform in RacersTransformList)
            {
                nextCheckpointIndexList.Add(0);
            }
        }

        public void CarPassedCheckpoint(Checkpoint checkpoint, Transform racerTransform)
        {
            int racerIndex = RacersTransformList.IndexOf(racerTransform);
            int nextCheckpointIndex = nextCheckpointIndexList[racerIndex];

            CarCheckPointEventArgs e = new CarCheckPointEventArgs
            {
                carTransform = racerTransform
            };

            if (checkpointList.IndexOf(checkpoint) == nextCheckpointIndex)
            {
                // Correct checkpoint: advance, wrapping around at the end of the lap.
                nextCheckpointIndexList[racerIndex] = (nextCheckpointIndex + 1) % checkpointList.Count;
                OnPlayerCorrectCheckpoint?.Invoke(this, e);
            }
            else
            {
                // Wrong checkpoint: the expected index is left unchanged.
                OnPlayerWrongCheckpoint?.Invoke(this, e);
            }
        }

        public class CarCheckPointEventArgs : EventArgs
        {
            public Transform carTransform { get; set; }
        }

        public void ResetCarChecpoint(Transform racerTransform)
        {
            nextCheckpointIndexList[RacersTransformList.IndexOf(racerTransform)] = 0;
        }

        public Transform GetNextCheckpoint(Transform racerTransform)
        {
            int nextCheckpointIndex = nextCheckpointIndexList[RacersTransformList.IndexOf(racerTransform)];

            // checkpointList was filled in hierarchy order in Awake(), so indexing it
            // returns the same transform as walking the "Checkpoints" children.
            return checkpointList[nextCheckpointIndex].transform;
        }
    }

    CHECKPOINT:

    Code (CSharp):
    using UnityEngine;

    public class Checkpoint : MonoBehaviour
    {
        private TrackCheckpointManager trackCheckpointManager;

        private void OnTriggerEnter(Collider other)
        {
            // Only objects carrying a CarController count as racers.
            if (other.TryGetComponent<CarController>(out _))
            {
                trackCheckpointManager.CarPassedCheckpoint(this, other.transform);
            }
        }

        public void SetCurrentCheckpoint(TrackCheckpointManager trackCheckpointManager)
        {
            this.trackCheckpointManager = trackCheckpointManager;
        }
    }

    Training Configuration File (yaml):

    behaviors:
      Car AI Hatchback:
        trainer_type: ppo
        hyperparameters:
          batch_size: 256
          buffer_size: 10240
          learning_rate: 0.0003
          beta: 0.0005
          epsilon: 0.2
          lambd: 0.95
          num_epoch: 3
          shared_critic: false
          learning_rate_schedule: linear
          beta_schedule: linear
          epsilon_schedule: linear
        network_settings:
          normalize: false
          hidden_units: 128
          num_layers: 2
          vis_encode_type: simple
          memory: null
          goal_conditioning_type: hyper
          deterministic: false
        reward_signals:
          extrinsic:
            gamma: 0.99
            strength: 0.2
            network_settings:
              normalize: false
              hidden_units: 128
              num_layers: 2
              vis_encode_type: simple
              memory: null
              goal_conditioning_type: hyper
              deterministic: false
          gail:
            strength: 0.8
            demo_path: D:/Unity Projects/Xtreme Racing/Car AI Demos/Car-Hatchback-2.demo
        behavioral_cloning:
          strength: 0.8
          demo_path: D:/Unity Projects/Xtreme Racing/Car AI Demos/Car-Hatchback-2.demo
        init_path: null
        keep_checkpoints: 5
        checkpoint_interval: 500000
        max_steps: 100000000
        time_horizon: 64
        summary_freq: 50000
        threaded: false
        self_play: null

    DEMO SUMMARY:

    upload_2023-6-17_16-23-1.png

    CAR PARAMETERS:

    upload_2023-6-17_16-23-51.png

    upload_2023-6-17_16-24-10.png

    upload_2023-6-17_16-24-29.png

    upload_2023-6-17_16-24-49.png

    Any help would be greatly appreciated.
     
  2. MelvMay

    MelvMay

    Unity Technologies

    Joined:
    May 24, 2013
    Posts:
    10,468
    Sounds like a post better suited to the ML-Agents sub-forum, not physics. I can move your post for you if you wish?
     
  3. PHOENIX05102000

    PHOENIX05102000

    Joined:
    Oct 14, 2022
    Posts:
    16
  4. Energymover

    Energymover

    Joined:
    Mar 28, 2023
    Posts:
    29
    Ironically, given how much information you provided, the first thing that comes to mind is: did you pass "--time-scale=1" in your training command? That might help, since you are using physics.

    As a side note, I got mine working pretty well, and pretty easily, using the "Realistic Car Controller" from the Asset Store. I started with the free version, then moved to the light version, and in the end upgraded to the Pro version. All three versions worked excellently with ML-Agents; I even have my steering wheel hooked up to it for recording behavioral cloning demonstrations.
     
  5. PHOENIX05102000

    PHOENIX05102000

    Joined:
    Oct 14, 2022
    Posts:
    16
    No, I did not. I thought the time scale defaults to 1 and only needs to be specified if I want to increase it.
    I will try that.
    Thanks.

    One more thing: how many steps did it take for your car to learn to drive around the track?
     
    Last edited: Jun 19, 2023
  6. PHOENIX05102000

    PHOENIX05102000

    Joined:
    Oct 14, 2022
    Posts:
    16
    Hey @Energymover, could you get back to me regarding the post above?
     
  7. Energymover

    Energymover

    Joined:
    Mar 28, 2023
    Posts:
    29
    It depends how I train them.

    If I just turn them loose on the track, they can get around after somewhere around 4+ million steps. If I use behavioral cloning with around 12 demonstrations, they get around in about 40 thousand.

    Things I noticed:
    • If I use BC, it works better with only 1 car on the track.
    • Either way I train them, it trains faster if I start with only two positive rewards and no penalties. I reward angular velocity towards the current checkpoint and crossing a checkpoint. Later I add the wall collision penalty.
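
    The "two positive rewards, no penalties" idea above could be sketched roughly as follows. This is only an illustration, not Energymover's actual setup: the helper name `CheckpointShaping`, the `rewardScale` constant, and the reading of "velocity towards the checkpoint" as the car's speed along the direction to the checkpoint are all assumptions.

```csharp
using UnityEngine;

// Hypothetical helper (not from the thread): a small, positive-only shaping
// reward for moving toward the next checkpoint.
public static class CheckpointShaping
{
    // rewardScale is an assumed tuning constant; pick it so the per-step shaping
    // stays small compared with the +1 reward for crossing a checkpoint.
    public static float ProgressReward(Vector3 carPosition, Vector3 carVelocity,
                                       Vector3 checkpointPosition, float rewardScale = 0.001f)
    {
        // Unit vector from the car to the checkpoint.
        Vector3 toCheckpoint = (checkpointPosition - carPosition).normalized;

        // Component of the car's velocity along that direction.
        float approachSpeed = Vector3.Dot(carVelocity, toCheckpoint);

        // Clamp at zero: moving away earns nothing, but is not penalised either.
        return Mathf.Max(0f, approachSpeed) * rewardScale;
    }
}
```

    Inside an Agent's OnActionReceived, the result could be passed to AddReward using the car's Rigidbody velocity and the position of the transform returned by GetNextCheckpoint(transform). The wall penalties would then be added back only once the car reliably makes progress, as described above.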
    If we could find a time to meet online, I could do a screen share and give you a run-through and demonstration of my setup. I learned oodles of stuff doing it, more than I can type. If you send me a private message with a way to contact you, we could set up a time some evening or weekend.

    I think the default time-scale is 20:
    https://forum.unity.com/threads/training-changes-the-timescale-by-itself.1102507/

    https://github.com/Unity-Technologi...docs/Python-API.md#engineconfigurationchannel
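
    One way to check which time scale is actually in effect during training is to log it from inside the scene. The component below is a minimal sketch; the class name `TimeScaleLogger` is made up for illustration, and it relies only on Unity's standard `Time.timeScale` and `Time.fixedDeltaTime` properties.

```csharp
using UnityEngine;

// Hypothetical debugging component (not from the thread): attach it to any
// GameObject in the training scene to see the time scale in use.
public class TimeScaleLogger : MonoBehaviour
{
    private float lastLoggedTimeScale = -1f;

    private void Update()
    {
        // Log only when the value changes, so a time scale applied by the
        // trainer after startup still shows up in the console.
        if (!Mathf.Approximately(Time.timeScale, lastLoggedTimeScale))
        {
            lastLoggedTimeScale = Time.timeScale;
            Debug.Log($"Time.timeScale = {Time.timeScale}, Time.fixedDeltaTime = {Time.fixedDeltaTime}");
        }
    }
}
```

    If the console shows a value well above 1 while a Wheel Collider car is training, the physics is being stepped much faster than real time, which is the situation the "--time-scale=1" suggestion is meant to rule out.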
     
    Last edited: Jun 21, 2023
  8. PHOENIX05102000

    PHOENIX05102000

    Joined:
    Oct 14, 2022
    Posts:
    16
    OK, but I changed the vehicles from using Wheel Colliders to plain transform-based movement. Now the cars can get around the track in just 200 thousand steps with just one demonstration, using both GAIL and BC. I am training with 5 AI vehicles.
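
    The replacement controller was not posted, so the following is only a guess at what "normal transform" movement might look like. The class name `SimpleTransformCar` and the `moveSpeed` and `turnSpeed` values are assumptions.

```csharp
using UnityEngine;

// Hypothetical sketch (the poster's real script was not shared): moves the car
// by editing its Transform directly instead of simulating wheels.
public class SimpleTransformCar : MonoBehaviour
{
    [SerializeField] private float moveSpeed = 10f; // assumed, units per second
    [SerializeField] private float turnSpeed = 90f; // assumed, degrees per second

    // Both inputs are expected in [-1, 1] and can be set by an Agent.
    public float verticalInput;   // throttle
    public float horizontalInput; // steering

    private void FixedUpdate()
    {
        float dt = Time.fixedDeltaTime;

        // Turn around the vertical axis, then move along the car's own forward axis.
        transform.Rotate(0f, horizontalInput * turnSpeed * dt, 0f);
        transform.Translate(Vector3.forward * verticalInput * moveSpeed * dt);
    }
}
```

    With movement like this, an action changes the car's pose immediately and predictably, with no tyre slip, suspension, or torque build-up in between. That may be part of why training converged so much faster than with Wheel Colliders.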
     
  9. Energymover

    Energymover

    Joined:
    Mar 28, 2023
    Posts:
    29
    Nice. Sounds like you got it going on now, grats.
     
  10. PHOENIX05102000

    PHOENIX05102000

    Joined:
    Oct 14, 2022
    Posts:
    16
    Thanks for the help.