Search Unity

Question Self driving car (need help with training)

Discussion in 'ML-Agents' started by leozhang1, Aug 12, 2022.

  1. leozhang1

    leozhang1

    Joined:
    Jan 27, 2018
    Posts:
    16
    I am having trouble tuning the parameters to the proper values when training my self-driving cars. I'm not sure if I'm rewarding and punishing in the right places... Each car agent is supposed to go through a series of checkpoints on a circular race track (similar to this video: [embedded video]). However, the agents seem to be having trouble making sharp turns. After training, the overall "trained" neural network can barely pass the first turn.

    visualization of my track:
    upload_2022-8-12_8-50-31.png

    CarDriverAgent script:
    Code (CSharp):
    1. using System.Collections;
    2. using System.Collections.Generic;
    3. using UnityEngine;
    4. using Unity.MLAgents;
    5. using Unity.MLAgents.Actuators;
    6. using Unity.MLAgents.Sensors;
    /// <summary>
    /// ML-Agents agent that drives a <see cref="CarDriver"/> around a track of checkpoints.
    /// Rewards and penalties are configured through <see cref="RewardsInfo"/> in the inspector.
    /// </summary>
    public class CarDriverAgent : Agent
    {
        // Wall-impact penalties are skipped while this checkpoint is the agent's next target.
        private const string WallPenaltyExemptCheckpoint = "CheckpointSingle (67)";

        [SerializeField] private TrackCheckpoints trackCheckpoints;
        [SerializeField] private Transform spawnPosition;
        private CarDriver carDriver;
        private Quaternion recallRotation;
        private Vector3 recallPosition;

        /// <summary>
        /// Reward/penalty amounts passed to <c>AddReward</c> for each event.
        /// Penalties are expected to be entered as negative values.
        /// </summary>
        [System.Serializable]
        public struct RewardsInfo
        {
            // Checkpoint events raised by TrackCheckpoints.
            public float correctCheckpoint, wrongCheckpoint;
            // Track completion and wall contact (impact / continuous contact).
            public float hitLastCheckpoint, hitAWall, slidingAlongWall;
            // Applied per received action, depending on the chosen throttle branch.
            public float movingForward, movingBackwards, noMovement;
            // Applied when the car's forward vector deviates from the next checkpoint's forward.
            public float notFacingCheckpoint;
        }
        [SerializeField] private RewardsInfo rwd;

        /// <summary>
        /// Caches components, records the respawn pose and subscribes to checkpoint events.
        /// </summary>
        public override void Initialize()
        {
            carDriver = GetComponent<CarDriver>();
            // Quaternion is a struct, so a plain assignment already copies it.
            recallRotation = transform.rotation;
            // The random offset is sampled once here, so every episode respawns at the same spot.
            recallPosition = spawnPosition.position + new Vector3(Random.Range(-5f, +5f), 0, Random.Range(-5f, +5f));

            // Use Unity's overloaded == so that an unassigned/destroyed serialized reference
            // is detected; "is null" only tests the managed reference and can miss it.
            if (trackCheckpoints == null)
            {
                trackCheckpoints = GameObject.Find("CheckPoints").GetComponent<TrackCheckpoints>();
            }

            trackCheckpoints.OnCarCorrectCheckpoint += TrackCheckpoints_OnCarCorrectCheckpoint;
            trackCheckpoints.OnCarWrongCheckpoint += TrackCheckpoints_OnCarWrongCheckpoint;
            // Handlers run in subscription order: the completion reward must be added BEFORE
            // the episode is ended, otherwise it is credited to the following episode.
            // NOTE(review): these handlers receive no car argument, so every subscribed agent
            // reacts when the event fires — confirm that is intended.
            trackCheckpoints.OnAgentCompleteTrack += rewardAgent;
            trackCheckpoints.OnAgentCompleteTrack += resetAgent;
        }

        /// <summary>Rewards this agent when it is the car that passed the correct checkpoint.</summary>
        private void TrackCheckpoints_OnCarCorrectCheckpoint(Transform carTransform)
        {
            if (carTransform == transform)
            {
                AddReward(rwd.correctCheckpoint);
            }
        }

        /// <summary>Applies the wrong-checkpoint amount when this agent is the car that triggered it.</summary>
        private void TrackCheckpoints_OnCarWrongCheckpoint(Transform carTransform)
        {
            if (carTransform == transform)
            {
                AddReward(rwd.wrongCheckpoint);
            }
        }

        /// <summary>Ends the current episode.</summary>
        private void resetAgent()
        {
            EndEpisode();
        }

        /// <summary>Adds the track-completion reward.</summary>
        private void rewardAgent()
        {
            AddReward(rwd.hitLastCheckpoint);
        }

        /// <summary>
        /// Moves the car back to its recorded respawn pose, resets its checkpoint progress
        /// and stops it.
        /// </summary>
        public override void OnEpisodeBegin()
        {
            transform.position = recallPosition;
            // The original also assigned transform.forward = spawnPosition.forward here, but that
            // value was immediately overwritten by the rotation below, so the dead write is dropped.
            transform.rotation = recallRotation;
            trackCheckpoints.resetCheckPoint(transform);
            carDriver.StopCompletely();
        }

        /// <summary>
        /// Returns the dot product of the car's forward vector and the next checkpoint's
        /// forward vector (1 = same direction, -1 = opposite for unit vectors).
        /// </summary>
        private float GetDotWithNextCheckpoint()
        {
            Vector3 checkpointForward = trackCheckpoints.GetNextCheckpoint(transform).transform.forward;
            return Vector3.Dot(transform.forward, checkpointForward);
        }

        /// <summary>
        /// Observes the alignment with the next checkpoint and applies the
        /// not-facing-checkpoint amount when the alignment drops below 0.9.
        /// </summary>
        public override void CollectObservations(VectorSensor sensor)
        {
            float dot = GetDotWithNextCheckpoint();
            if (dot < 0.9f)
            {
                AddReward(rwd.notFacingCheckpoint);
            }
            sensor.AddObservation(dot);
        }

        /// <summary>
        /// Maps the two discrete branches to throttle (0 = none, 1 = forward, 2 = backward)
        /// and steering (0 = none, 1 = +1, 2 = -1), applies the per-action throttle reward,
        /// and forwards the inputs to the car.
        /// </summary>
        public override void OnActionReceived(ActionBuffers actions)
        {
            float forwardAmount = 0f, turnAmount = 0f;

            switch (actions.DiscreteActions[0])
            {
                case 0:
                    forwardAmount = 0f;
                    AddReward(rwd.noMovement);
                    break;
                case 1:
                    forwardAmount = +1f;
                    // encourage moving forward
                    AddReward(rwd.movingForward);
                    break;
                case 2:
                    forwardAmount = -1f;
                    AddReward(rwd.movingBackwards);
                    break;
            }

            switch (actions.DiscreteActions[1])
            {
                case 0:
                    turnAmount = 0f;
                    break;
                case 1:
                    turnAmount = +1f;
                    break;
                case 2:
                    turnAmount = -1f;
                    break;
            }

            carDriver.SetInputs(forwardAmount, turnAmount);
        }

        /// <summary>
        /// Keyboard control for manual testing: W/S or Up/Down for throttle,
        /// D/A or Right/Left for steering. Also logs the current checkpoint alignment.
        /// </summary>
        public override void Heuristic(in ActionBuffers actionsOut)
        {
            int forwardAction = 0;
            if (Input.GetKey(KeyCode.W) || Input.GetKey(KeyCode.UpArrow)) forwardAction = 1;
            if (Input.GetKey(KeyCode.S) || Input.GetKey(KeyCode.DownArrow)) forwardAction = 2;

            int turnAction = 0;
            if (Input.GetKey(KeyCode.D) || Input.GetKey(KeyCode.RightArrow)) turnAction = 1;
            if (Input.GetKey(KeyCode.A) || Input.GetKey(KeyCode.LeftArrow)) turnAction = 2;

            ActionSegment<int> discreteActions = actionsOut.DiscreteActions;
            discreteActions[0] = forwardAction;
            discreteActions[1] = turnAction;
            print($"dot with next checkpoint: {GetDotWithNextCheckpoint()}");
        }

        /// <summary>
        /// Penalizes hitting a wall, scaled by the squared impact speed, unless the
        /// exempt checkpoint is the next target.
        /// </summary>
        void OnCollisionEnter(Collision collision)
        {
            if (collision.gameObject.CompareTag("wall"))
            {
                if (trackCheckpoints.GetNextCheckpoint(transform).name != WallPenaltyExemptCheckpoint)
                {
                    // the harder you hit the wall, the more the punish
                    AddReward(rwd.hitAWall * collision.relativeVelocity.sqrMagnitude);
                }
            }
        }

        /// <summary>
        /// Applies the sliding penalty on every physics callback while the car stays in
        /// contact with a wall, to discourage driving along it.
        /// </summary>
        void OnCollisionStay(Collision collision)
        {
            if (collision.gameObject.CompareTag("wall"))
            {
                AddReward(rwd.slidingAlongWall);
            }
        }
    }
    CarAI.yaml:
    Code (YAML):
    # PPO trainer configuration for the "CarDriver" behavior.
    # NOTE(review): the behavior name must match the agent's Behavior Parameters — confirm.
    behaviors:
      CarDriver:
        trainer_type: ppo
        hyperparameters:
          batch_size: 256
          buffer_size: 10240
          learning_rate: 0.0003
          beta: 0.0005
          epsilon: 0.2
          lambd: 0.99
          num_epoch: 3
          learning_rate_schedule: linear
        network_settings:
          normalize: false
          hidden_units: 128
          num_layers: 2
        reward_signals:
          extrinsic:
            strength: 1
            gamma: 0.99
        max_steps: 7500000
        time_horizon: 64
        # Was 5000000: with max_steps of 7500000 that produced a single statistics
        # point for the whole run, making reward trends impossible to inspect.
        summary_freq: 50000
        threaded: true
    upload_2022-8-12_8-48-21.png

    upload_2022-8-12_8-49-36.png
     

    Attached Files: