RuntimeError when using imitation learning

Discussion in 'ML-Agents' started by antreas20197, Jun 23, 2021.

  1. antreas20197

    Hello. I am trying to use imitation learning. However, when I try to train with the recorded .demo files, the following PyTorch-related error appears. Thanks.

    RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement

    Console output:
     Version information:
      ml-agents: 0.27.0,
      ml-agents-envs: 0.27.0,
      Communicator API: 1.5.0,
      PyTorch: 1.9.0+cu111
    [INFO] Listening on port 5004. Start training by pressing the Play button in the Unity Editor.
    [INFO] Connected to Unity environment with package version 2.0.0-pre.3 and communication version 1.5.0
    [INFO] Connected new brain: ZARAgoal?team=1
    [WARNING] Deleting TensorBoard data events.out.tfevents.1624485266.AndreasPC.18212.0 that was left over from a previous run.
    [INFO] Hyperparameters for behavior name ZARAgoal:
            trainer_type:   ppo
            hyperparameters:
              batch_size:   128
              buffer_size:  2048
              learning_rate:        0.0003
              beta: 0.01
              epsilon:      0.2
              lambd:        0.95
              num_epoch:    3
              learning_rate_schedule:       linear
            network_settings:
              normalize:    False
              hidden_units: 256
              num_layers:   2
              vis_encode_type:      simple
              memory:       None
              goal_conditioning_type:       hyper
            reward_signals:
              extrinsic:
                gamma:      0.99
                strength:   1.0
                network_settings:
                  normalize:        False
                  hidden_units:     128
                  num_layers:       2
                  vis_encode_type:  simple
                  memory:   None
                  goal_conditioning_type:   hyper
              gail:
                gamma:      0.99
                strength:   0.01
                network_settings:
                  normalize:        False
                  hidden_units:     128
                  num_layers:       2
                  vis_encode_type:  simple
                  memory:   None
                  goal_conditioning_type:   hyper
                learning_rate:      0.0003
                encoding_size:      None
                use_actions:        False
                use_vail:   False
                demo_path:  Demos/ZARAdemos/
            init_path:      None
            keep_checkpoints:       5
            checkpoint_interval:    500000
            max_steps:      100000
            time_horizon:   64
            summary_freq:   60000
            threaded:       False
            self_play:      None
            behavioral_cloning:
              demo_path:    Demos/ZARAdemos/
              steps:        50000
              strength:     1.0
              samples_per_update:   0
              num_epoch:    None
              batch_size:   None
    d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\nn\init.py:388: UserWarning: Initializing zero-element tensors is a no-op
      warnings.warn("Initializing zero-element tensors is a no-op")
    d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\nn\init.py:426: UserWarning: Initializing zero-element tensors is a no-op
      warnings.warn("Initializing zero-element tensors is a no-op")
    Traceback (most recent call last):
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\trainer_controller.py", line 176, in start_learning
        n_steps = self.advance(env_manager)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents_envs\timers.py", line 305, in wrapped
        return func(*args, **kwargs)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\trainer_controller.py", line 234, in advance
        new_step_infos = env_manager.get_steps()
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\env_manager.py", line 124, in get_steps
        new_step_infos = self._step()
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\subprocess_env_manager.py", line 298, in _step
        self._queue_steps()
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\subprocess_env_manager.py", line 291, in _queue_steps
        env_action_info = self._take_step(env_worker.previous_step)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents_envs\timers.py", line 305, in wrapped
        return func(*args, **kwargs)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\subprocess_env_manager.py", line 429, in _take_step
        all_action_info[brain_name] = self.policies[brain_name].get_action(
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\policy\torch_policy.py", line 212, in get_action
        run_out = self.evaluate(decision_requests, global_agent_ids)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents_envs\timers.py", line 305, in wrapped
        return func(*args, **kwargs)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\policy\torch_policy.py", line 178, in evaluate
        action, log_probs, entropy, memories = self.sample_actions(
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents_envs\timers.py", line 305, in wrapped
        return func(*args, **kwargs)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\policy\torch_policy.py", line 140, in sample_actions
        actions, log_probs, entropies, memories = self.actor.get_action_and_stats(
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\torch\networks.py", line 626, in get_action_and_stats
        action, log_probs, entropies = self.action_model(encoding, masks)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\torch\action_model.py", line 194, in forward
        actions = self._sample_action(dists)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\torch\action_model.py", line 84, in _sample_action
        discrete_action.append(discrete_dist.sample())
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\torch\distributions.py", line 114, in sample
        return torch.multinomial(self.probs, 1)
    RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "C:\Users\antre\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 197, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "C:\Users\antre\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "D:\Desktop\Crowds-and-ML-Agents\venv\Scripts\mlagents-learn.exe\__main__.py", line 7, in <module>
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\learn.py", line 250, in main
        run_cli(parse_command_line())
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\learn.py", line 246, in run_cli
        run_training(run_seed, options)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\learn.py", line 125, in run_training
        tc.start_learning(env_manager)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents_envs\timers.py", line 305, in wrapped
        return func(*args, **kwargs)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\trainer_controller.py", line 201, in start_learning
        self._save_models()
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents_envs\timers.py", line 305, in wrapped
        return func(*args, **kwargs)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\trainer_controller.py", line 80, in _save_models
        self.trainers[brain_name].save_model()
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\trainer\rl_trainer.py", line 185, in save_model
        model_checkpoint = self._checkpoint()
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents_envs\timers.py", line 305, in wrapped
        return func(*args, **kwargs)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\trainer\rl_trainer.py", line 157, in _checkpoint
        export_path, auxillary_paths = self.model_saver.save_checkpoint(
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\model_saver\torch_model_saver.py", line 59, in save_checkpoint
        self.export(checkpoint_path, behavior_name)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\model_saver\torch_model_saver.py", line 64, in export
        self.exporter.export_policy_model(output_filepath)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\torch\model_serialization.py", line 159, in export_policy_model
        torch.onnx.export(
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\onnx\__init__.py", line 275, in export
        return utils.export(model, args, f, export_params, verbose, training,
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\onnx\utils.py", line 88, in export
        _export(model, args, f, export_params, verbose, training, input_names, output_names,
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\onnx\utils.py", line 689, in _export
        _model_to_graph(model, args, verbose, input_names,
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\onnx\utils.py", line 458, in _model_to_graph
        graph, params, torch_out, module = _create_jit_graph(model, args,
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\onnx\utils.py", line 422, in _create_jit_graph
        graph, torch_out = _trace_and_get_graph_from_model(model, args)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\onnx\utils.py", line 373, in _trace_and_get_graph_from_model
        torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\jit\_trace.py", line 1160, in _get_trace_graph
        outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\jit\_trace.py", line 127, in forward
        graph, out = torch._C._create_graph_by_tracing(
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\jit\_trace.py", line 118, in wrapper
        outs.append(self.inner(*trace_inputs))
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\torch\nn\modules\module.py", line 1039, in _slow_forward
        result = self.forward(*input, **kwargs)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\torch\networks.py", line 664, in forward
        ) = self.action_model.get_action_out(encoding, masks)
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\torch\action_model.py", line 171, in get_action_out
        discrete_out_list = [
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\torch\action_model.py", line 172, in <listcomp>
        discrete_dist.exported_model_output()
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\torch\distributions.py", line 136, in exported_model_output
        return self.sample()
      File "d:\desktop\crowds-and-ml-agents\venv\lib\site-packages\mlagents\trainers\torch\distributions.py", line 114, in sample
        return torch.multinomial(self.probs, 1)
    RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
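    Both tracebacks end in the same call, `torch.multinomial(self.probs, 1)` in ML-Agents' `distributions.py`; the second one is simply the failure repeating when ML-Agents tries to export the model to ONNX on shutdown. `torch.multinomial` samples without replacement by default, so it refuses to draw more samples than the distribution has categories. The `Initializing zero-element tensors is a no-op` warnings earlier in the log are consistent with a discrete action branch having been built with zero outputs. Below is a plain-Python sketch (no PyTorch; `multinomial_no_replacement` is a hypothetical helper written only to mirror that precondition, not the library's code) of why a branch of size 0 produces exactly this error when one action is requested.

```python
import random
from typing import List, Sequence


def multinomial_no_replacement(probs: Sequence[float], n_sample: int) -> List[int]:
    """Hypothetical stand-in for torch.multinomial(probs, n_sample) with replacement=False."""
    # Same precondition PyTorch enforces without replacement:
    # you cannot draw more distinct categories than the distribution has.
    if n_sample > len(probs):
        raise RuntimeError(
            "cannot sample n_sample > prob_dist.size(-1) samples without replacement"
        )
    indices = list(range(len(probs)))
    weights = list(probs)
    drawn = []
    for _ in range(n_sample):
        # Weighted pick among the categories that have not been drawn yet.
        k = random.choices(range(len(indices)), weights=weights)[0]
        drawn.append(indices.pop(k))
        weights.pop(k)
    return drawn


# A discrete branch of size 0 yields an empty probability vector, so asking
# for one sample (what the traceback shows ML-Agents doing) fails.
try:
    multinomial_no_replacement([], 1)
except RuntimeError as err:
    print("Branch size 0:", err)

# A branch of size 1 has exactly one category, so the draw always returns index 0.
print("Branch size 1:", multinomial_no_replacement([1.0], 1))
```

    Since only one sample per branch is requested, the check can only fail when a branch has no categories at all.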
     
  2. TreyK-47

    Unity Technologies

    I'll flag this with the team for some guidance. Which version of ML-Agents are you using?
     
  3. antreas20197

    Hello Trey. I finally found the problem.
    Although I am not using discrete actions, I had to set Branch 0 Size to 1 instead of 0.

    Thanks.
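    The reported fix fits the error: with Branch 0 Size set to 1, the discrete branch has a single category, so drawing one action from it always succeeds (it is always action 0), while a size of 0 leaves nothing to sample. A check like the one below, run against a behavior's action spec before training, would flag an empty branch up front. This is a minimal sketch: `ActionSpecSketch` and `empty_branches` are hypothetical names, and only the field names `continuous_size` and `discrete_branches` mirror ML-Agents' `ActionSpec` in `mlagents_envs.base_env`. The continuous action count of 2 is illustrative; the thread does not say how many continuous actions the agent uses.

```python
from typing import List, NamedTuple, Tuple


class ActionSpecSketch(NamedTuple):
    # Hypothetical stand-in; field names mirror mlagents_envs' ActionSpec.
    continuous_size: int
    discrete_branches: Tuple[int, ...]


def empty_branches(spec: ActionSpecSketch) -> List[int]:
    """Return the indices of discrete branches with fewer than one choice."""
    return [i for i, size in enumerate(spec.discrete_branches) if size < 1]


# Setup described in this thread: continuous actions plus one discrete branch of size 0
# (the continuous count of 2 is illustrative only).
broken = ActionSpecSketch(continuous_size=2, discrete_branches=(0,))
# The reported fix: Branch 0 Size set to 1, i.e. a branch with a single choice.
fixed = ActionSpecSketch(continuous_size=2, discrete_branches=(1,))

print(empty_branches(broken))  # [0]
print(empty_branches(fixed))   # []
```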
     