Resolved Using imitation learning in Karting Microgame

Discussion in 'ML-Agents' started by Wolf00007, May 5, 2021.

  1. Wolf00007

    Joined: Jan 26, 2019
    Posts: 24

    Hello,
    I'm new to ML-Agents and have been experimenting with the Karting Microgame tutorials. I would like to add imitation learning to my game, and I was wondering how to make my Agent Kart take movement input from the keyboard. Below is the KartAgent script used by the agent karts.

    One of the tutorials says I should override the Heuristic function in this script so that I can use the arrow keys to record my kart's actions. But how do I do that with this script? I am struggling to find what I should modify to make that work. Any tips would be much appreciated!


    Code (CSharp):
    using KartGame.KartSystems;
    using Unity.MLAgents;
    using Unity.MLAgents.Sensors;
    using UnityEngine;
    using Random = UnityEngine.Random;

    namespace KartGame.AI
    {
        /// <summary>
        /// Sensors hold information such as the position and rotation of the origin of the raycast and its hit threshold
        /// to consider a "crash".
        /// </summary>
        [System.Serializable]
        public struct Sensor
        {
            public Transform Transform;
            public float RayDistance;
            public float HitValidationDistance;
        }

        /// <summary>
        /// We only want certain behaviours when the agent runs.
        /// Training allows certain functions such as OnAgentReset() to be called and executed, while Inferencing
        /// assumes that the agent will run continuously and not reset.
        /// </summary>
        public enum AgentMode
        {
            Training,
            Inferencing
        }

        /// <summary>
        /// The KartAgent will drive the inputs for the KartController.
        /// </summary>
        public class KartAgent : Agent, IInput
        {
            #region Training Modes
            [Tooltip("Are we training the agent or is the agent production ready?")]
            public AgentMode Mode = AgentMode.Training;
            [Tooltip("What is the initial checkpoint the agent will go to? This value is only for inferencing.")]
            public ushort InitCheckpointIndex;

            #endregion

            #region Senses
            [Header("Observation Params")]
            [Tooltip("What objects should the raycasts hit and detect?")]
            public LayerMask Mask;
            [Tooltip("Sensors contain ray information to sense out the world, you can have as many sensors as you need.")]
            public Sensor[] Sensors;
            [Header("Checkpoints"), Tooltip("What are the series of checkpoints for the agent to seek and pass through?")]
            public Collider[] Colliders;
            [Tooltip("What layer are the checkpoints on? This should be an exclusive layer for the agent to use.")]
            public LayerMask CheckpointMask;

            [Space]
            [Tooltip("Would the agent need a custom transform to be able to raycast and hit the track? " +
                "If not assigned, then the root transform will be used.")]
            public Transform AgentSensorTransform;
            #endregion

            #region Rewards
            [Header("Rewards"), Tooltip("What penalty is given when the agent crashes?")]
            public float HitPenalty = -1f;
            [Tooltip("How much reward is given when the agent successfully passes the checkpoints?")]
            public float PassCheckpointReward;
            [Tooltip("Should typically be a small value, but we reward the agent for moving in the right direction.")]
            public float TowardsCheckpointReward;
            [Tooltip("Typically if the agent moves faster, we want to reward it for finishing the track quickly.")]
            public float SpeedReward;
            [Tooltip("Reward the agent when it keeps accelerating")]
            public float AccelerationReward;
            #endregion

            #region ResetParams
            [Header("Inference Reset Params")]
            [Tooltip("What is the unique mask that the agent should detect when it falls out of the track?")]
            public LayerMask OutOfBoundsMask;
            [Tooltip("What are the layers we want to detect for the track and the ground?")]
            public LayerMask TrackMask;
            [Tooltip("How far should the ray be when casted? For larger karts - this value should be larger too.")]
            public float GroundCastDistance;
            #endregion

            #region Debugging
            [Header("Debug Option")] [Tooltip("Should we visualize the rays that the agent draws?")]
            public bool ShowRaycasts;
            #endregion

            ArcadeKart m_Kart;
            bool m_Acceleration;
            bool m_Brake;
            float m_Steering;
            int m_CheckpointIndex;

            bool m_EndEpisode;
            float m_LastAccumulatedReward;

            void Awake()
            {
                m_Kart = GetComponent<ArcadeKart>();
                if (AgentSensorTransform == null) AgentSensorTransform = transform;
            }

            void Start()
            {
                // If the agent is training, then at the start of the simulation, pick a random checkpoint to train the agent.
                OnEpisodeBegin();

                if (Mode == AgentMode.Inferencing) m_CheckpointIndex = InitCheckpointIndex;
            }

            void Update()
            {
                if (m_EndEpisode)
                {
                    m_EndEpisode = false;
                    AddReward(m_LastAccumulatedReward);
                    EndEpisode();
                    OnEpisodeBegin();
                }
            }

            void LateUpdate()
            {
                switch (Mode)
                {
                    case AgentMode.Inferencing:
                        if (ShowRaycasts)
                            Debug.DrawRay(transform.position, Vector3.down * GroundCastDistance, Color.cyan);

                        // We want to place the agent back on the track if the agent happens to launch itself outside of the track.
                        if (Physics.Raycast(transform.position + Vector3.up, Vector3.down, out var hit, GroundCastDistance, TrackMask)
                            && ((1 << hit.collider.gameObject.layer) & OutOfBoundsMask) > 0)
                        {
                            // Reset the agent back to its last known agent checkpoint
                            var checkpoint = Colliders[m_CheckpointIndex].transform;
                            transform.localRotation = checkpoint.rotation;
                            transform.position = checkpoint.position;
                            m_Kart.Rigidbody.velocity = default;
                            m_Steering = 0f;
                            m_Acceleration = m_Brake = false;
                        }

                        break;
                }
            }

            void OnTriggerEnter(Collider other)
            {
                var maskedValue = 1 << other.gameObject.layer;
                var triggered = maskedValue & CheckpointMask;

                FindCheckpointIndex(other, out var index);

                // Ensure that the agent touched the checkpoint and the new index is greater than the m_CheckpointIndex.
                if ((triggered > 0 && index > m_CheckpointIndex) || (index == 0 && m_CheckpointIndex == Colliders.Length - 1))
                {
                    AddReward(PassCheckpointReward);
                    m_CheckpointIndex = index;
                }
            }

            void FindCheckpointIndex(Collider checkPoint, out int index)
            {
                for (int i = 0; i < Colliders.Length; i++)
                {
                    if (Colliders[i].GetInstanceID() == checkPoint.GetInstanceID())
                    {
                        index = i;
                        return;
                    }
                }
                index = -1;
            }

            float Sign(float value)
            {
                if (value > 0)
                {
                    return 1;
                }
                if (value < 0)
                {
                    return -1;
                }
                return 0;
            }

            public override void CollectObservations(VectorSensor sensor)
            {
                sensor.AddObservation(m_Kart.LocalSpeed());

                // Add an observation for direction of the agent to the next checkpoint.
                var next = (m_CheckpointIndex + 1) % Colliders.Length;
                var nextCollider = Colliders[next];
                if (nextCollider == null)
                    return;

                var direction = (nextCollider.transform.position - m_Kart.transform.position).normalized;
                sensor.AddObservation(Vector3.Dot(m_Kart.Rigidbody.velocity.normalized, direction));

                if (ShowRaycasts)
                    Debug.DrawLine(AgentSensorTransform.position, nextCollider.transform.position, Color.magenta);

                m_LastAccumulatedReward = 0.0f;
                m_EndEpisode = false;
                for (var i = 0; i < Sensors.Length; i++)
                {
                    var current = Sensors[i];
                    var xform = current.Transform;
                    var hit = Physics.Raycast(AgentSensorTransform.position, xform.forward, out var hitInfo,
                        current.RayDistance, Mask, QueryTriggerInteraction.Ignore);

                    if (ShowRaycasts)
                    {
                        Debug.DrawRay(AgentSensorTransform.position, xform.forward * current.RayDistance, Color.green);
                        Debug.DrawRay(AgentSensorTransform.position, xform.forward * current.HitValidationDistance,
                            Color.red);

                        if (hit && hitInfo.distance < current.HitValidationDistance)
                        {
                            Debug.DrawRay(hitInfo.point, Vector3.up * 3.0f, Color.blue);
                        }
                    }

                    if (hit)
                    {
                        if (hitInfo.distance < current.HitValidationDistance)
                        {
                            m_LastAccumulatedReward += HitPenalty;
                            m_EndEpisode = true;
                        }
                    }

                    sensor.AddObservation(hit ? hitInfo.distance : current.RayDistance);
                }

                sensor.AddObservation(m_Acceleration);
            }

            public override void OnActionReceived(float[] vectorAction)
            {
                base.OnActionReceived(vectorAction);
                InterpretDiscreteActions(vectorAction);

                // Find the next checkpoint when registering the current checkpoint that the agent has passed.
                var next = (m_CheckpointIndex + 1) % Colliders.Length;
                var nextCollider = Colliders[next];
                var direction = (nextCollider.transform.position - m_Kart.transform.position).normalized;
                var reward = Vector3.Dot(m_Kart.Rigidbody.velocity.normalized, direction);

                if (ShowRaycasts) Debug.DrawRay(AgentSensorTransform.position, m_Kart.Rigidbody.velocity, Color.blue);

                // Add rewards if the agent is heading in the right direction
                AddReward(reward * TowardsCheckpointReward);
                AddReward((m_Acceleration && !m_Brake ? 1.0f : 0.0f) * AccelerationReward);
                AddReward(m_Kart.LocalSpeed() * SpeedReward);
            }

            public override void OnEpisodeBegin()
            {
                switch (Mode)
                {
                    case AgentMode.Training:
                        m_CheckpointIndex = Random.Range(0, Colliders.Length - 1);
                        var collider = Colliders[m_CheckpointIndex];
                        transform.localRotation = collider.transform.rotation;
                        transform.position = collider.transform.position;
                        m_Kart.Rigidbody.velocity = default;
                        m_Acceleration = false;
                        m_Brake = false;
                        m_Steering = 0f;
                        break;
                    default:
                        break;
                }
            }

            // Decodes the two discrete branches: actions[0] is steering, actions[1] is accelerate/brake.
            void InterpretDiscreteActions(float[] actions)
            {
                m_Steering = actions[0] - 1f;
                m_Acceleration = actions[1] >= 1.0f;
                m_Brake = actions[1] < 1.0f;
            }

            public InputData GenerateInput()
            {
                return new InputData
                {
                    Accelerate = m_Acceleration,
                    Brake = m_Brake,
                    TurnInput = m_Steering
                };
            }
        }
    }
     
  2. christophergoy

    Joined: Sep 16, 2015
    Posts: 735

    Hi @Wolf00007,
    You are correct that you need to override the Heuristic method in this script.
    The Heuristic method is passed a float[] that you write to. That array is then passed to OnActionReceived, and your agent uses it as input.

    So just override Heuristic and fill the array with keyboard input values that match the actions read in OnActionReceived.

    Let me know if you have any other questions.
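
To make that flow concrete, here is a minimal sketch in plain C# with no Unity dependency. `ToyAgent`, `Step`, `HorizontalKey`, and `AccelerateKey` are illustrative assumptions, not ML-Agents API; in the real package it is the framework, not your own code, that calls `Heuristic` and then hands the array to `OnActionReceived`.

```csharp
using System;

// Toy stand-in for an agent. This is NOT the ML-Agents Agent class.
public class ToyAgent
{
    // Values decoded from the action array.
    public float Steering;
    public bool Accelerating;

    // Hypothetical stand-ins for the current keyboard state.
    public float HorizontalKey;
    public bool AccelerateKey;

    // Player control: write one value per action slot.
    public void Heuristic(float[] actionsOut)
    {
        actionsOut[0] = HorizontalKey;
        actionsOut[1] = AccelerateKey ? 1f : 0f;
    }

    // Reads the same slots back, whoever filled them.
    public void OnActionReceived(float[] actions)
    {
        Steering = actions[0];
        Accelerating = actions[1] >= 1f;
    }

    // Mimics one decision step when a player, not a model, supplies the actions.
    public void Step(int actionCount)
    {
        var actions = new float[actionCount];
        Heuristic(actions);
        OnActionReceived(actions);
    }
}
```

The point of the sketch is only that whatever `Heuristic` writes must use the same slots and the same value encoding that `OnActionReceived` expects to read.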
     
  3. Wolf00007

    I have added values like below:

    Code (CSharp):
    public override void Heuristic(float[] actionsOut)
    {
        actionsOut[0] = Input.GetAxis("Horizontal");
        actionsOut[1] = Input.GetButton("Accelerate");
        actionsOut[2] = Input.GetButton("Brake");
    }
    But I'm getting this error for both the Accelerate and Brake actions:
    [Attached image: upload_2021-5-5_21-28-37.png]

    Which is expected, because these two actions are of type bool in the code:

    Code (CSharp):
    bool m_Acceleration;
    bool m_Brake;
    float m_Steering;
    So I'm not sure how to add accelerating and braking in this script.
    Thanks for your help, by the way. If you need more info from my side, let me know.
     
  4. christophergoy

    If GetButton returns a bool, you can do something like this:
    Code (CSharp):
    public override void Heuristic(float[] actionsOut)
    {
        actionsOut[0] = Input.GetAxis("Horizontal");
        actionsOut[1] = Input.GetButton("Accelerate") ? 1 : 0;
        actionsOut[2] = Input.GetButton("Brake") ? 1 : 0;
    }
    And if you are setting those variables elsewhere, you need to convert back.
    It looks like this is done for you in the InterpretDiscreteActions method.
     
    Last edited: May 5, 2021
  5. Wolf00007

    I also tried it like this:

    Code (CSharp):
    public override void Heuristic(float[] actionsOut)
    {
        actionsOut[0] = Input.GetAxis("Horizontal");
        actionsOut[1] = Input.GetAxis("Vertical");
    }
    And it kind of worked, because I can move the Agent, but the Agent still wants to move by itself! I did set the Behavior Type to Heuristic Only and the agent does not have a neural network assigned, so I'm not sure what I'm missing here...
     
  6. Wolf00007

    I just did that and the error was gone, but after hitting Play I got hundreds of errors like this:
    [Attached image: upload_2021-5-5_21-44-20.png]
     
  7. christophergoy

    Ah, so it looks like accelerate was using a float value between -1 and 1. You need to add another discrete action in your Behavior Parameters so that you have 3 discrete actions instead of 2.
     
  8. christophergoy

    This approach works as well. It may have something to do with the default value of the vertical axis. Since it was using floats before, it may be getting messed up now that you are using discrete actions (ints).
     
  9. Wolf00007

    I already have two branches with 3 actions per branch:
    [Attached image: upload_2021-5-5_22-29-31.png]

    Again, as this is a microgame, it was already set up like this, so I'm not sure why it has two branches.
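
For reference, a discrete action space with two branches of size 3 means the action array has two slots, each holding an index of 0, 1, or 2. The sketch below is plain C# with no Unity types; `ActionSpaceCheck` and `Fits` are hypothetical helper names, not part of ML-Agents. It illustrates why a write to `actionsOut[2]`, or a value such as -1, does not fit this layout.

```csharp
using System;

public static class ActionSpaceCheck
{
    // True when there is exactly one action per branch and every action
    // is a whole-number index in the range [0, branchSize - 1].
    public static bool Fits(float[] actions, int[] branchSizes)
    {
        if (actions.Length != branchSizes.Length)
            return false;

        for (var i = 0; i < actions.Length; i++)
        {
            var value = actions[i];
            var isWhole = value == Math.Floor(value);
            if (!isWhole || value < 0 || value >= branchSizes[i])
                return false;
        }
        return true;
    }
}
```

With `branchSizes = { 3, 3 }`, the array `{ 1, 2 }` fits, while `{ 0, -1 }` (negative index) and `{ 0, 1, 1 }` (three slots for two branches) do not.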
     

  10. christophergoy

    Ah, OK. So that means for the neural network output:

    Steering looks like:
    -1 = turn left
    0 = go straight
    1 = turn right

    Accelerate/brake looks like:
    -1 = brake
    0 = do nothing
    1 = accelerate

    So you need to map your keyboard actions to these values accordingly.

    Code (CSharp):
    public override void Heuristic(float[] actionsOut)
    {
        actionsOut[0] = Input.GetAxis("Horizontal");
        var brakeVal = Input.GetButton("Brake") ? -1 : 0;
        actionsOut[1] = Input.GetButton("Accelerate") ? 1 : brakeVal;
    }
     
  11. christophergoy

    Sorry for the misunderstanding on my part. Let me know if that helps.
     
  12. Wolf00007

    That's okay, thank you for the quick responses :)

    I tried the mapping above and the errors are gone, but once again the agent is still trying to move by itself. It seems to be going backwards and steering left only, for some reason. I can countersteer, but that's it. Do you have any other ideas?
     
  13. christophergoy

    Sorry, I was wrong again.
    Code (CSharp):
    void InterpretDiscreteActions(float[] actions)
    {
        m_Steering = actions[0] - 1f;
        m_Acceleration = actions[1] >= 1.0f;
        m_Brake = actions[1] < 1.0f;
    }
    Steering looks like it's:
    0 is turn left
    1 is to go straight
    2 is to turn right

    Accelerate looks like:
    0 for brake
    1 for do nothing
    2 for accelerate

    which would change the code I sent you to look like this:
    Code (CSharp):
    public override void Heuristic(float[] actionsOut)
    {
        var brake = 0f;
        var idle = 1f;
        var accelerate = 2f;
        actionsOut[0] = Input.GetAxis("Horizontal");
        var brakeVal = Input.GetButton("Brake") ? brake : idle;
        actionsOut[1] = Input.GetButton("Accelerate") ? accelerate : brakeVal;
    }
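
One detail the Heuristic snippet leaves open is that `Input.GetAxis("Horizontal")` returns a value between -1 and 1, while `InterpretDiscreteActions` subtracts 1 from `actions[0]` and so expects a branch index of 0, 1, or 2. A minimal sketch of that encode/decode pair in plain C# follows. `KartActionCodec`, `EncodeSteering`, and the 0.5 threshold are assumptions for illustration; `DecodeSteering` copies the subtraction from `InterpretDiscreteActions`.

```csharp
using System;

public static class KartActionCodec
{
    // Map an analog axis value in [-1, 1] to a steering branch index:
    // 0 = left, 1 = straight, 2 = right. The 0.5 threshold is an assumed dead zone.
    public static float EncodeSteering(float axis)
    {
        if (axis > 0.5f) return 2f;
        if (axis < -0.5f) return 0f;
        return 1f;
    }

    // Same arithmetic as "m_Steering = actions[0] - 1f" in the kart script.
    public static float DecodeSteering(float action) => action - 1f;
}
```

With no key pressed the axis is 0, which encodes to index 1 and decodes back to a steering value of 0. Writing the raw axis value instead makes a neutral input decode to -1, which would be consistent with a kart that steers left on its own.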
     
  14. Wolf00007

    Now the car goes forward by itself, still steering left o_O

    I noticed that the class currently derives from both Agent and IInput (I'm not sure what IInput is, as Visual Studio does not recognize it):
    Code (CSharp):
    public class KartAgent : Agent, IInput
    But when I removed IInput, the agent could no longer move by itself, while I could still drive it through the Heuristic. I was able to record some data with the Demonstration Recorder and it looks good (I think), but this seems wrong, as I had to disable the agent's movement (and isn't that what Heuristic should do?).
     
    Last edited: May 6, 2021
  15. Wolf00007

    Okay, I changed the following:
    Code (CSharp):
    void InterpretDiscreteActions(float[] actions)
    {
        m_Steering = actions[0];
        m_Brake = actions[1] < 1.0f;
        m_Acceleration = actions[1] > 1.0f;
    }
    without removing IInput, and it works as well. But again, it does not work when I then switch the behavior to Inference rather than Heuristic. Is it okay to change the code for recording with the Demonstration Recorder and then change it back when training the Agent? It looks like I can't get it to work for both Inference and Heuristic without modifying the code when switching between the two... :/
     
  16. christophergoy

    You shouldn't change the code just to record demos, since the demos have to be recorded with the same code that you train and run inference with. I think if you spend a bit more time tinkering you'll eventually get it working.
     
  17. Wolf00007

    I think I found the issue. Agents do not have "idle" as an action: they either go forward or brake/go backwards. This is why, with the Heuristic behavior, they go either forward or backwards by themselves (which one depends on the values we use for "actionsOut").

    This is why I added "+1" to the steering input, to balance out the "-1" in the InterpretDiscreteActions function. I also changed the rest to look like this:
    Code (CSharp):
    public override void Heuristic(float[] actionsOut)
    {
        var brake = 0f;
        var accelerate = 1f;
        actionsOut[0] = Input.GetAxis("Horizontal") + 1;
        actionsOut[1] = Input.GetButton("Brake") ? brake : accelerate;
    }
    The car still goes forward by itself, but I don't see how I could change this without changing the way Agents move (that is, modifying the rest of the code). But I was able to record the demos like that and everything seems to work just fine.

    However, maybe I'm wrong and this is incorrect, so if anyone has any other ideas, please let me know :)
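
The difference between the two versions of `InterpretDiscreteActions` that appear in this thread can be checked with a short sketch. This is plain C# with no Unity dependency; `ThrottleDecoders` and its method names are made up for the example, while the comparisons are copied from the original script and from the modified version posted earlier in the thread.

```csharp
using System;

public static class ThrottleDecoders
{
    // Original script: accelerate when value >= 1, brake when value < 1.
    public static (bool accelerate, bool brake) Original(float value) =>
        (value >= 1.0f, value < 1.0f);

    // Modified version from the thread: value == 1 leaves both flags false.
    public static (bool accelerate, bool brake) Modified(float value) =>
        (value > 1.0f, value < 1.0f);
}
```

Under the original decoder, every branch value (0, 1, or 2) maps to either accelerating or braking, which matches the observation that there is no idle action. The modified decoder adds an idle state at 1, but a model trained against the original mapping would presumably not behave the same way with it.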