Question: ML-Agents for running through 2D mazes

Discussion in 'ML-Agents' started by vadwave, May 9, 2021.

  1. vadwave

    Joined: Feb 6, 2014
    Posts: 4

    Hello everyone!

    I'm working on a university project that involves creating an ML agent to traverse 2D mazes. But even after 500,000 steps my agents haven't learned even basic navigation and just jitter around the starting point...

    I don't quite understand what the reason might be. My agent has two sensors: a long one facing forward and a short one facing backward. Points are deducted for staying in one place and for colliding with a wall, and reward points are given for reaching the end of the maze. I have pre-recorded 6 demo passes. I also use all the training methods at once (extrinsic RL, behavioral cloning, GAIL, curiosity). I'm using the first version of ML-Agents, i.e. 1.0.7.

    With a simpler spiral maze the agent learns quickly, but I still want it to solve a random maze.

    If that's too complicated and time-consuming, how can I at least get the agent to exactly reproduce a recorded demo?
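
    For example, would a config that leans almost entirely on imitation do that? A rough sketch of what I have in mind, reusing my demo file (the strengths here are guesses on my part, not tested values):

    Player:
        trainer: ppo
        batch_size: 256
        buffer_size: 20480
        max_steps: 5.0e5
        # keep behavioral cloning on for the whole run (steps: 0)
        behavioral_cloning:
            demo_path: Demos/PlayerDemo.demo
            strength: 1.0
            steps: 0
        reward_signals:
            # GAIL pushes the policy toward demo-like trajectories
            gail:
                strength: 1.0
                gamma: 0.99
                demo_path: Demos/PlayerDemo.demo
            # small extrinsic weight so the demos dominate
            extrinsic:
                strength: 0.1
                gamma: 0.99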

    I apologize in advance for my English and for any mistakes in how this post is put together!

    Player:
        trainer: ppo

        # Trainer configs common to PPO/SAC (excluding reward signals)
        batch_size: 256
        buffer_size: 20480
        hidden_units: 128
        learning_rate: 3.0e-4
        learning_rate_schedule: linear
        max_steps: 5.0e6
        normalize: false
        num_layers: 2
        time_horizon: 64
        summary_freq: 20000
        vis_encoder_type: simple
        init_path: null

        # PPO-specific configs
        beta: 5.0e-4
        epsilon: 0.2
        lambd: 0.99
        num_epoch: 3
        threaded: true

        # memory
        use_recurrent: false
        sequence_length: 64
        memory_size: 128

        # behavioral cloning
        behavioral_cloning:
            demo_path: Demos/PlayerDemo.demo
            strength: 0.5
            steps: 150000
            batch_size: 512
            num_epoch: 3
            samples_per_update: 0

        # reward signals (extrinsic, GAIL and curiosity are all nested here)
        reward_signals:
            extrinsic:
                strength: 1.0
                gamma: 0.99

            # GAIL
            gail:
                strength: 0.5
                gamma: 0.99
                demo_path: Demos/PlayerDemo.demo

            # curiosity module
            curiosity:
                strength: 0.02
                gamma: 0.99
                encoding_size: 256
                learning_rate: 3.0e-4

        # self-play
        self_play:
            window: 10
            play_against_latest_model_ratio: 0.5
            save_steps: 50000
            swap_steps: 50000
            team_change: 100000
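
    (For reference, a flat config like this is what gets passed to the trainer CLI. A typical invocation for the 1.0.x era; the file name and run id below are made up, and in this version the TensorBoard logs land in ./summaries:)

    Code (Bash):

        mlagents-learn config/trainer_config.yaml --run-id=maze_run
        # in a second terminal, watch the reward curves:
        tensorboard --logdir=summaries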
    Code (CSharp):

    using System;
    using System.Collections;
    using System.Collections.Generic;
    using Unity.MLAgents;
    using Unity.MLAgents.Sensors;
    using UnityEngine;
    using Rewards = Constants.Scores.Agent;

    public class Player : Agent, IDamageable, IDamageDealer, IMovable, IRotable, IEye, IPocket
    {
        [SerializeField] float health = 100f;
        [SerializeField] float moveSpeed = 5f;
        [SerializeField] float rotateSpeed = 3f;
        [SerializeField] Rigidbody2D rigBody;
        [SerializeField] Transform body;

        [Header("IEye")]
        [SerializeField] float viewDistance = 10f;
        [Range(0, 360)]
        [SerializeField] int viewAngle = 20;
        [SerializeField] bool enableVision = false;
        [Header("DEBUG")]
        [SerializeField] Transform lastPoint;
        [SerializeField] List<Transform> visibleTargets = new List<Transform>();

        const float timeDelay = 0.0f;
        private bool isWaiting = false;
        float scores = 0;
        int keys = 0;

        Coroutine corFind;
        Coroutine corCheckPos;

        LevelManager level;

        EnvironmentParameters resetParams;
        private float sumValue; // running total of all rewards in the current episode (for logging)

        public float Health => health;
        public float Speed => moveSpeed;
        public float SpeedRotate => rotateSpeed;
        public bool OpenEye => enableVision;
        public float Radius => viewDistance;
        public int Angle => viewAngle;
        public int Keys => keys;
        public float Scores => scores;

        public event Action<Transform, Transform> OnRespawn;
        public event Action OnEndedRespawn;
        public event Action OnEscaped;
        public event Action<float> OnAddedScore;

        #region Agent

        public override void Initialize()
        {
            resetParams = Academy.Instance.EnvironmentParameters;
            SetResetParameters();
        }

        public override void OnEpisodeBegin()
        {
            Debug.Log("Total Reward: " + sumValue.ToString());
            keys = 0;
            scores = 0;
            sumValue = 0;
            SetResetParameters();
            Respawn();
        }

        public override void Heuristic(float[] actionsOut)
        {
            InputControl(ref actionsOut);
        }

        public override void OnActionReceived(float[] vectorAction)
        {
            Move(vectorAction);
            Find(true);
            // Small existential penalty each step; sums to -1 over a full episode.
            AddReward(-1f / MaxStep);
            sumValue += (-1f / MaxStep);
        }

        public override void CollectObservations(VectorSensor sensor)
        {
            base.CollectObservations(sensor);
            if (sensor != null)
            {
                if (rigBody)
                {
                    sensor.AddObservation(rigBody.velocity.normalized.x);
                    sensor.AddObservation(rigBody.velocity.normalized.y);
                }
                sensor.AddObservation(body.rotation.normalized.z);
                sensor.AddObservation(body.rotation.normalized.w);
                // Guard against a missing level reference as well as a missing exit.
                if (level && level.exit)
                {
                    Vector2 dirToExit = (level.exit.position - rigBody.transform.position).normalized;
                    sensor.AddObservation(dirToExit.x);
                    sensor.AddObservation(dirToExit.y);
                }
            }
        }

        public void AddAgentReward(float value)
        {
            sumValue += value;
            AddReward(value);
        }

        void SetResetParameters()
        {
            if (resetParams != null)
            {
                moveSpeed = resetParams.GetWithDefault("moveSpeed", moveSpeed);
                rotateSpeed = resetParams.GetWithDefault("rotateSpeed", rotateSpeed);

                viewAngle = (int)resetParams.GetWithDefault("viewAngle", viewAngle);
                viewDistance = resetParams.GetWithDefault("viewDistance", viewDistance);
                if (level)
                {
                    level.SetParameters(resetParams);
                }
                else
                {
                    this.transform.root.GetComponent<LevelManager>().SetParameters(resetParams);
                }
            }
        }

        #endregion

        #region Moving

        public void Move(float[] vectorAction)
        {
            if (isWaiting) return;
            Vector3 dir;
            Quaternion quaternionAngle;
            GetDirection(vectorAction, out dir, out quaternionAngle);

            Vector2 pos = rigBody.transform.position + dir * Time.deltaTime;
            rigBody.MovePosition(pos);

            body.rotation = Quaternion.Lerp(body.rotation, quaternionAngle, rotateSpeed * Time.deltaTime);
        }

        void InputControl(ref float[] actionsOut)
        {
            actionsOut[0] = Input.GetAxis("Vertical");
            actionsOut[1] = Input.GetAxis("Horizontal");

            // Encode the desired facing as the z/w components of a quaternion.
            Vector3 mousePos = Utils.Instance.GetPosMousePosition();
            Vector2 direction = (mousePos - transform.position).normalized;
            float angle = Mathf.Atan2(direction.x, direction.y) * Mathf.Rad2Deg;
            Quaternion eulerAngle = Quaternion.Euler(0, 0, -angle);
            actionsOut[2] = eulerAngle.z;
            actionsOut[3] = eulerAngle.w;
        }

        void GetDirection(float[] vectorAction, out Vector3 dir, out Quaternion angleRotate)
        {
            dir = Vector3.zero;
            dir += rigBody.transform.up * vectorAction[0] * moveSpeed;
            dir += rigBody.transform.right * vectorAction[1] * moveSpeed;
            angleRotate = new Quaternion(0, 0, vectorAction[2], vectorAction[3]);
        }

        public void Move()
        {
        }

        public void Rotate(float angle, float speed, float startAngle)
        {
        }

        #endregion

        #region Vision

        public void Find(bool enable)
        {
            if (enable)
            {
                if (corFind == null)
                    corFind = StartCoroutine(IEFindTargetsInRadius(timeDelay));
            }
            else if (corFind != null) // guard: the coroutine may never have been started
            {
                StopCoroutine(corFind);
                corFind = null;
            }
            CheckTargets();
        }

        public void Alert(bool enable)
        {
        }

        void CheckTargets()
        {
            foreach (Transform target in visibleTargets)
            {
                if (target.GetComponent<SecurityCamera>())
                {
                    AddAgentReward(-Rewards.Check);
                    Debug.Log("Find Camera score: -" + Rewards.Check);
                }
                else if (target.GetComponent<Guard>())
                {
                    AddAgentReward(-Rewards.Check);
                    Debug.Log("Find Guard score: -" + Rewards.Check);
                }
                else if (target.GetComponent<CollectLogic>())
                {
                    AddAgentReward(Rewards.Check);
                    Debug.Log("Find CollectLogic score: +" + Rewards.Check);
                }
                else if (target.GetComponent<KeyLogic>())
                {
                    AddAgentReward(Rewards.Check);
                    Debug.Log("Find KeyLogic score: +" + Rewards.Check);
                }
            }
        }

        void OnVisiblePlayer()
        {
            AddAgentReward(-Rewards.Visible);
        }

        #endregion

        #region HealthAndAttack

        public void DealDamage(IDamageable damageable, int amount)
        {
            throw new System.NotImplementedException();
        }

        public void TakeDamage(int amount)
        {
            throw new System.NotImplementedException();
        }

        #endregion

        #region Coroutines

        IEnumerator IEFindTargetsInRadius(float delay)
        {
            while (true)
            {
                yield return new WaitForSeconds(delay);
                GameMath.FindVisibleTargets(body, visibleTargets, viewDistance, viewAngle, true);
            }
        }

        IEnumerator IEWaitingAfterRespawn(float delay)
        {
            isWaiting = true;
            yield return new WaitForSeconds(delay);
            isWaiting = false;
            OnEndedRespawn?.Invoke();
            yield return null;
            if (corCheckPos == null)
                corCheckPos = StartCoroutine(IECheckPosition(1f));
        }

        // Penalizes the agent for staying within a 1-unit box of its previous position.
        IEnumerator IECheckPosition(float delay)
        {
            while (true)
            {
                Vector3 lastPos = body.position;
                yield return new WaitForSeconds(delay);
                Vector3 currentPos = body.position;
                Vector2 minSize = new Vector2(lastPos.x - 1, lastPos.y - 1);
                Vector2 maxSize = new Vector2(lastPos.x + 1, lastPos.y + 1);
                if (((minSize.x <= currentPos.x) && (currentPos.x <= maxSize.x)) &&
                    ((minSize.y <= currentPos.y) && (currentPos.y <= maxSize.y)))
                {
                    AddAgentReward(-Rewards.Check);
                }
                yield return null;
            }
        }

        #endregion

        #region LevelLogic

        public void ExitLevel(bool success = true)
        {
            if (corCheckPos != null)
                StopCoroutine(corCheckPos);
            corCheckPos = null;
            OnEscaped?.Invoke();
            float tempReward = (success) ? Rewards.Win : -Rewards.Win;
            sumValue += tempReward;
            SetReward(tempReward);
            EndEpisode();
        }

        public void Respawn()
        {
            rigBody.velocity = Vector2.zero;
            rigBody.angularVelocity = 0;

            OnRespawn?.Invoke(rigBody.transform, body);

            StartCoroutine(IEWaitingAfterRespawn(1f));
        }

        public void SetLevel(LevelManager levelManager)
        {
            this.level = levelManager;
        }

        private void OnTriggerEnter2D(Collider2D collision)
        {
            if (collision.gameObject.CompareTag("Finish"))
            {
                ExitLevel();
            }
        }

        private void OnCollisionEnter2D(Collision2D collision)
        {
            if (collision.transform.CompareTag("Wall"))
            {
                AddAgentReward(-(Rewards.Check * 2));
            }
        }

        private void OnCollisionExit2D(Collision2D collision)
        {
            if (collision.transform.CompareTag("Wall"))
            {
                //AddAgentReward(Rewards.Check);
            }
        }

        #endregion

        #region ItemActions

        public bool UseKey()
        {
            if (Keys > 0)
            {
                keys--;
                AddAgentReward(Rewards.Key);
                return true;
            }
            return false;
        }

        public void AddKey()
        {
            keys++;
            AddAgentReward(Rewards.Key);
        }

        public void Collect()
        {
            scores++;
            AddAgentReward(Rewards.Collectable);
            OnAddedScore?.Invoke(scores);
        }

        #endregion
    }

    Also a link to the github repository:
     
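    (A side note on the sensors: instead of the hand-rolled Find()/vision coroutine, the package ships a built-in 2D ray sensor that feeds observations automatically. A minimal sketch, assuming the ML-Agents 1.0 API — the property names should be double-checked against the installed package, and the tags are the ones used in the code above:)

    Code (CSharp):

        using System.Collections.Generic;
        using Unity.MLAgents.Sensors;
        using UnityEngine;

        // Attaches the built-in 2D ray perception sensor from code;
        // the same component can simply be added in the Inspector instead.
        public class RaySensorSetup : MonoBehaviour
        {
            void Awake()
            {
                var rays = gameObject.AddComponent<RayPerceptionSensorComponent2D>();
                rays.RaysPerDirection = 3;   // rays fanned out to each side of straight ahead
                rays.MaxRayDegrees = 70f;    // angular spread of the fan
                rays.RayLength = 10f;        // match the agent's viewDistance
                rays.DetectableTags = new List<string> { "Wall", "Finish" };
            }
        }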
  2. murphlaw

    Joined: Apr 18, 2021
    Posts: 5

    Hello,

    I'm new to ML-Agents and have only followed the Unity Learn projects, but here is a screenshot of my TensorBoard graphs, which show the reward my agent got relative to the number of steps taken.

    [Attached image: Screenshot (37).png — TensorBoard reward-vs-steps curves]

    As you can see, my agent got its first rewards only after about 750k steps (orange, first run) or even 1.5M steps (blue, second run). Not sure if this is the case for you as well, but did you try running it for 5M steps once, to see if you get slightly better results?
     
  3. vadwave

    Joined: Feb 6, 2014
    Posts: 4

    Hello! Thanks for the answer!
    No, I haven't trained the agent for 5M steps. I stopped at 1M because the average reward kept decreasing for no apparent reason. I'll try running the full 5M steps and report back with the result!
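
    (To go from 1M to 5M steps I plan to continue the same run rather than restart — assuming the 1.0-era CLI, where --resume replaced the old --load flag; the file name and run id here are made up:)

    Code (Bash):

        mlagents-learn config/trainer_config.yaml --run-id=maze_run --resume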