Resolved SAC - PyTorch IndexError

Discussion in 'ML-Agents' started by darthmorf, Mar 19, 2023.

  1. darthmorf

    Joined: Jun 21, 2017
    Posts: 11
    Hi all,

    I've been able to train an agent using PPO with no problems. As the next stage in my research, I'm now attempting to train the same agent using SAC instead. However, I consistently run into an issue that crashes PyTorch with an IndexError. It doesn't happen at the same point each time, but it always happens within a few seconds of training starting (at 20x game speed, of course). I've been able to train the Basic example project with SAC just fine, so I assume something is wrong in my config or agent setup, but I'm at a loss as to where to start, since all I've changed is the algorithm.

    If anyone has any ideas on where to start troubleshooting, please let me know! I'm happy to provide further detail where needed.

    [Attached screenshot: upload_2023-3-19_12-58-31.png]

    Code (Text):
    c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\torch\cuda\__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  ..\c10\cuda\CUDAFunctions.cpp:100.)
      return torch._C._cuda_getDeviceCount() > 0

    [ML-Agents ASCII art banner]

    Version information:
      ml-agents: 0.29.0,
      ml-agents-envs: 0.29.0,
      Communicator API: 1.5.0,
      PyTorch: 1.7.0+cu110
    [INFO] Listening on port 5004. Start training by pressing the Play button in the Unity Editor.
    [INFO] Connected to Unity environment with package version 2.0.1 and communication version 1.5.0
    [INFO] Connected new brain: Racecar?team=0
    [WARNING] Deleting TensorBoard data events.out.tfevents.1679229999.Sams-Desktop.37004.0 that was left over from a previous run.
    [WARNING] Deleting TensorBoard data events.out.tfevents.1679229999.Sams-Desktop.37004.0.meta that was left over from a previous run.
    [INFO] Hyperparameters for behavior name Racecar:
            trainer_type:   sac
            hyperparameters:
              learning_rate:        0.0003
              learning_rate_schedule:       constant
              batch_size:   1024
              buffer_size:  10240
              buffer_init_steps:    0
              tau:  0.005
              steps_per_update:     10.0
              save_replay_buffer:   False
              init_entcoef: 0.01
              reward_signal_steps_per_update:       10.0
            network_settings:
              normalize:    False
              hidden_units: 128
              num_layers:   2
              vis_encode_type:      simple
              memory:       None
              goal_conditioning_type:       hyper
              deterministic:        False
            reward_signals:
              extrinsic:
                gamma:      0.99
                strength:   1.0
                network_settings:
                  normalize:        False
                  hidden_units:     128
                  num_layers:       2
                  vis_encode_type:  simple
                  memory:   None
                  goal_conditioning_type:   hyper
                  deterministic:    False
            init_path:      None
            keep_checkpoints:       5
            checkpoint_interval:    500000
            max_steps:      3000000
            time_horizon:   64
            summary_freq:   50000
            threaded:       False
            self_play:      None
            behavioral_cloning:     None
    [INFO] Exported ./Assets/Training-Results\test\Racecar\Racecar-1059.onnx
    [INFO] Copied ./Assets/Training-Results\test\Racecar\Racecar-1059.onnx to ./Assets/Training-Results\test\Racecar.onnx.
    Traceback (most recent call last):
      File "c:\users\sam\appdata\local\programs\python\python37\lib\runpy.py", line 193, in _run_module_as_main
        "__main__", mod_spec)
      File "c:\users\sam\appdata\local\programs\python\python37\lib\runpy.py", line 85, in _run_code
        exec(code, run_globals)
      File "C:\Users\Sam\AppData\Local\Programs\Python\Python37\Scripts\mlagents-learn.exe\__main__.py", line 7, in <module>
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents\trainers\learn.py", line 260, in main
        run_cli(parse_command_line())
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents\trainers\learn.py", line 256, in run_cli
        run_training(run_seed, options, num_areas)
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents\trainers\learn.py", line 132, in run_training
        tc.start_learning(env_manager)
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents_envs\timers.py", line 305, in wrapped
        return func(*args, **kwargs)
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents\trainers\trainer_controller.py", line 176, in start_learning
        n_steps = self.advance(env_manager)
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents_envs\timers.py", line 305, in wrapped
        return func(*args, **kwargs)
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents\trainers\trainer_controller.py", line 251, in advance
        trainer.advance()
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents\trainers\trainer\rl_trainer.py", line 315, in advance
        if self._update_policy():
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents_envs\timers.py", line 305, in wrapped
        return func(*args, **kwargs)
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents\trainers\sac\trainer.py", line 205, in _update_policy
        policy_was_updated = self._update_sac_policy()
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents\trainers\sac\trainer.py", line 272, in _update_sac_policy
        update_stats = self.optimizer.update(sampled_minibatch, n_sequences)
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents_envs\timers.py", line 305, in wrapped
        return func(*args, **kwargs)
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents\trainers\sac\optimizer_torch.py", line 552, in update
        q1_stream = self._condense_q_streams(q1_out, disc_actions)
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents\trainers\sac\optimizer_torch.py", line 448, in _condense_q_streams
        item, self._action_spec.discrete_branches
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents\trainers\torch\utils.py", line 269, in break_into_branches
        for i in range(len(action_size))
      File "c:\users\sam\appdata\local\programs\python\python37\lib\site-packages\mlagents\trainers\torch\utils.py", line 269, in <listcomp>
        for i in range(len(action_size))
    IndexError: too many indices for tensor of dimension 1
    Press any key to continue . . .
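
    For reference, the IndexError at the bottom of the trace is raised inside ML-Agents' break_into_branches helper, which, going by the traceback, slices the Q-network output along a second dimension. A minimal sketch of the failing indexing pattern in plain PyTorch (variable names here are mine, not ML-Agents'):

    Code (Python):
    import torch

    # Slicing a 1-D tensor as if it had a batch dimension reproduces the
    # exact error message from the traceback above.
    q_out = torch.ones(8)  # 1-D tensor, as in the crash
    try:
        branch = q_out[:, 0:2]  # break_into_branches-style slice over dim 1
    except IndexError as e:
        print(e)  # "too many indices for tensor of dimension 1"

    # The same slice is fine once a batch dimension is present:
    q_batch = torch.ones(4, 8)  # shape (batch, logits)
    branch = q_batch[:, 0:2]    # OK: shape (4, 2)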

    Code (YAML):
    behaviors:
      Racecar:
        trainer_type: sac
        hyperparameters:
          learning_rate: 0.0003
          learning_rate_schedule: constant
          batch_size: 1024
          buffer_size: 10240
          buffer_init_steps: 0
          tau: 0.005
          steps_per_update: 10.0
          save_replay_buffer: false
          init_entcoef: 0.01
          reward_signal_steps_per_update: 10.0
        network_settings:
          normalize: false
          hidden_units: 128
          num_layers: 2
          vis_encode_type: simple
        reward_signals:
          extrinsic:
            gamma: 0.99
            strength: 1.0
        keep_checkpoints: 5
        max_steps: 3000000
        time_horizon: 64
        summary_freq: 50000
    checkpoint_settings:
      run_id: racecar
      initialize_from: null
      load_model: false
      resume: false
      force: false
      train_model: false
      inference: false
      results_dir: ./Assets/Training-Results
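
    For anyone comparing this file against the hyperparameter dump in the log above, a quick way to print exactly what the YAML parses to (a minimal sketch; it assumes the config is saved as config.yaml, which is my naming, and requires PyYAML):

    Code (Python):
    import yaml

    # Load the trainer config and echo the SAC hyperparameters so they can be
    # checked line by line against what mlagents-learn prints at startup.
    with open("config.yaml") as f:
        config = yaml.safe_load(f)

    racecar = config["behaviors"]["Racecar"]
    print("trainer_type:", racecar["trainer_type"])
    for key, value in racecar["hyperparameters"].items():
        print(f"  {key}: {value}")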

    Code (CSharp):
    using UnityEngine;
    using UnityEngine.InputSystem;
    using Unity.MLAgents;
    using Unity.MLAgents.Sensors;
    using Unity.MLAgents.Actuators;

    public class KartAgent : Agent
    {
        // Config Params
        [SerializeField] KartController kartController;
        [SerializeField] TerrainColliderDetector[] terrainColliders;
        [SerializeField] GameObject checkpointParent;
        [SerializeField] bool handBrakeEnabled = false;
        [SerializeField] bool reverseEnabled = false;
        [SerializeField] float steeringRange = 0.3f;
        [SerializeField] bool manualControl = false;

        [Header("Rewards")]
        [SerializeField] float stepReward = 0.001f;
        [SerializeField] float failReward = -1f;
        [SerializeField] float checkpointReward = 0.5f;
        [SerializeField] float timeOut = 30.0f;
        [SerializeField] [Range(1f, 20f)] float timeScale = 1f;

        // State
        bool failed = false;
        int checkpointIndex = 0;
        float elapsedTime = 0;
        RaceCheckpoint[] checkpoints;

        public override void Initialize()
        {
            // ResetScene();
            terrainColliders = FindObjectsOfType<TerrainColliderDetector>();
            checkpoints = checkpointParent.GetComponentsInChildren<RaceCheckpoint>(true);
        }

        public override void CollectObservations(VectorSensor sensor)
        {
            sensor.AddObservation(kartController.GetRigidbody().velocity.magnitude);
            sensor.AddObservation(Vector3.Distance(transform.position, checkpoints[checkpointIndex].transform.position));
        }

        public override void OnActionReceived(ActionBuffers actions)
        {
            Time.timeScale = timeScale; // This shouldn't be needed, but is nice for demos

            if (!manualControl)
            {
                kartController.SetSpeed(Mathf.Abs(actions.ContinuousActions[0]));
                kartController.SetTurn(actions.ContinuousActions[1]);
            }

            elapsedTime += Time.deltaTime;

            // Fail the episode if the kart has touched any off-track collider
            foreach (TerrainColliderDetector terrainCollider in terrainColliders)
            {
                if (terrainCollider.GetAgentCollided())
                {
                    failed = true;
                    break;
                }
            }

            CheckCheckpoints();

            // Reward speed, penalise steering
            AddReward(kartController.GetRigidbody().velocity.magnitude * stepReward);
            AddReward(-Mathf.Abs(actions.ContinuousActions[1]) * stepReward);

            if (failed || Keyboard.current.rKey.isPressed)
            {
                Failure();
            }

            if (elapsedTime > timeOut)
            {
                ResetScene();
            }

            ShowReward();
        }

        void CheckCheckpoints()
        {
            if (checkpoints[checkpointIndex].KartHitCheckpoint())
            {
                Debug.Log($"Checkpoint {checkpointIndex + 1} hit!");

                AddReward(checkpointReward);

                checkpoints[checkpointIndex].Reset();
                checkpoints[checkpointIndex].gameObject.SetActive(false);

                checkpointIndex = (checkpointIndex + 1) % checkpoints.Length;
                checkpoints[checkpointIndex].gameObject.SetActive(true);
            }
        }

        void Failure()
        {
            AddReward(failReward);
            ShowReward();
            ResetScene();
        }

        public override void OnEpisodeBegin()
        {
            // ResetScene();
        }

        void ResetScene()
        {
            failed = false;
            elapsedTime = 0;

            foreach (RaceCheckpoint checkpoint in checkpoints)
            {
                checkpoint.gameObject.SetActive(false);
            }

            checkpointIndex = 0;
            checkpoints[checkpointIndex].gameObject.SetActive(true);

            kartController.Reset_();

            foreach (TerrainColliderDetector terrainColliderDetector in terrainColliders)
            {
                terrainColliderDetector.Reset_();
            }

            EndEpisode();
        }

        public override void Heuristic(in ActionBuffers actionsOut)
        {
            base.Heuristic(actionsOut);
        }

        private void ShowReward()
        {
            Debug.Log($"Current Reward: {GetCumulativeReward()}");
        }
    }
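
    One detail worth noting from the traceback: the crash runs through _condense_q_streams and discrete_branches, i.e. the SAC trainer's discrete-action path, even though the agent above only ever writes ContinuousActions. A quick way to inspect the action spec the trainer actually receives from the Behavior Parameters component is the mlagents_envs low-level API (a minimal sketch; run the script, then press Play in the Editor):

    Code (Python):
    from mlagents_envs.environment import UnityEnvironment

    # file_name=None connects to the Unity Editor instead of a built player;
    # the call blocks until Play is pressed.
    env = UnityEnvironment(file_name=None)
    env.reset()

    # For a purely continuous agent, discrete_branches should be empty.
    for name, spec in env.behavior_specs.items():
        print(name)
        print("  continuous actions:", spec.action_spec.continuous_size)
        print("  discrete branches: ", spec.action_spec.discrete_branches)

    env.close()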
     
  2. darthmorf

    Joined: Jun 21, 2017
    Posts: 11