Search Unity

  1. Megacity Metro Demo now available. Download now.
    Dismiss Notice
  2. Unity support for visionOS is now available. Learn more in our blog post.
    Dismiss Notice

Question: Training Does Not Progress Past the First Episode

Discussion in 'ML-Agents' started by alexshen1008, Feb 13, 2023.

  1. alexshen1008

    alexshen1008

    Joined:
    Feb 13, 2023
    Posts:
    1
    Hi there,

    I'm currently learning ML-Agents and want to experiment with it by building a scene in which an agent object rotates to match the orientation of a target object.

    The problem is that once I start playing the scene for training, Unity freezes at Application.EnterPlayMode, as shown below.

    upload_2023-2-13_15-4-49.png

    So I logged how many times OnActionReceived and OnEpisodeBegin are called, and found that no matter how long I leave training running, OnActionReceived is always called exactly 34 times and training does not progress.
    upload_2023-2-13_15-8-33.png

    An exception is also caught (shown below), but I'm not sure whether it is caused by my pressing Ctrl+C to terminate the Python-side training:

    upload_2023-2-13_15-13-26.png

    Heuristic mode, which gives the agent some hard-coded rotations, runs without any issue.

    Sorry if I haven't described this clearly; I'm new to both Unity and ML-Agents. Thanks in advance for any help with this.


    ml-agents: 0.29.0,
    ml-agents-envs: 0.29.0,
    communicator API: 1.5.0,
    PyTorch: 1.7.1+cu110

    Here's all the code:

    Code (CSharp):
    1. using System.Collections;
    2. using System.Collections.Generic;
    3. using UnityEngine;
    4. using Unity.MLAgents;
    5. using Unity.MLAgents.Actuators;
    6. using Unity.MLAgents.Sensors;
    7. using System.Linq;
    8.  
    /// <summary>
    /// ML-Agents agent that rotates this GameObject about <c>ARefTransform</c> so that six
    /// pairs of landmark transforms line up. Observations are the offset vectors between four
    /// of the landmark pairs plus the x/y/z components of the world rotation. The three
    /// continuous actions are per-axis rotation amounts. The reward is +0.01 per step when the
    /// summed landmark distance shrank, and -0.01 otherwise.
    /// </summary>
    /// <remarks>
    /// NOTE(review): nothing in this class calls <c>EndEpisode()</c>, so an episode only ends
    /// through the agent's MaxStep setting (if one is configured) or an external reset.
    /// Confirm this is intended.
    /// </remarks>
    public class registerAgent : Agent
    {
        // Number of continuous actions OnActionReceived consumes (one per rotation axis).
        private const int ContinuousActionCount = 3;

        // Paired landmarks: each upper-case transform is compared with the lower-case
        // transform of the same name.
        // NOTE(review): the reset translation in OnEpisodeBegin suggests the upper-case
        // landmarks are children of this agent and the lower-case ones belong to the target.
        // Confirm in the scene.
        [SerializeField] private Transform ARefTransform;
        [SerializeField] private Transform DMTransform;
        [SerializeField] private Transform DLTransform;
        [SerializeField] private Transform PMTransform;
        [SerializeField] private Transform PLTransform;
        [SerializeField] private Transform DFCTransform;
        [SerializeField] private Transform arefTransform;
        [SerializeField] private Transform dmTransform;
        [SerializeField] private Transform dlTransform;
        [SerializeField] private Transform pmTransform;
        [SerializeField] private Transform plTransform;
        [SerializeField] private Transform dfcTransform;

        // Not referenced by this script; presumably reserved for bounds / success feedback.
        [SerializeField] private GameObject boundingBox;
        [SerializeField] private MeshRenderer indicatorRenderer;
        [SerializeField] private Material successMaterial;
        [SerializeField] private Material failMaterial;

        // Degrees of rotation applied per unit of action on each axis
        // (formerly a hard-coded local with the same value).
        [SerializeField] private float rotationSpeed = 2f;

        // Per-pair distances from the most recent UpdateOffset call.
        public float offsetARef;
        public float offsetDM;
        public float offsetDL;
        public float offsetPM;
        public float offsetPL;
        public float offsetDFC;

        // Per-pair offset vectors (lower-case position minus upper-case position).
        public Vector3 offsetVectorARef;
        public Vector3 offsetVectorDM;
        public Vector3 offsetVectorDL;
        public Vector3 offsetVectorPM;
        public Vector3 offsetVectorPL;
        public Vector3 offsetVectorDFC;

        public List<float> offsetsPreviousFrame;
        public List<float> offsetsCurrentFrame;
        public List<float> initialOffsets;

        // Sum of per-pair distance reductions from the last ModifyReward call.
        public float totalDeltaOffset;
        public int episodeCount;
        public Quaternion initialRotation;
        public bool gettingCloser;

        // When true, per-step diagnostics are written to the console.
        public bool DEBUG = true;
        public int actionsCount;

        /// <summary>Records the starting local rotation used to reset each episode.</summary>
        public override void Initialize()
        {
            episodeCount = 0;
            initialRotation = transform.localRotation;
        }

        /// <summary>
        /// Resets the agent's pose, translates it by the ARef pair's offset, and snapshots the
        /// starting landmark distances.
        /// </summary>
        public override void OnEpisodeBegin()
        {
            episodeCount += 1;
            actionsCount = 0;
            Log("New Episode Triggered :" + episodeCount.ToString());

            transform.localPosition = Vector3.zero;
            transform.localRotation = initialRotation;
            transform.position += arefTransform.position - ARefTransform.position;

            UpdateOffset();
            // Copy so later edits to the current-frame list can never alter the snapshot.
            initialOffsets = new List<float>(offsetsCurrentFrame);
        }

        /// <summary>
        /// Moves the current distances into <see cref="offsetsPreviousFrame"/>, then recomputes
        /// every pair's distance and offset vector.
        /// </summary>
        public void UpdateOffset()
        {
            offsetsPreviousFrame = new List<float> { offsetARef, offsetDM, offsetDL, offsetPM, offsetPL, offsetDFC };

            offsetARef = Vector3.Distance(arefTransform.position, ARefTransform.position);
            offsetDM = Vector3.Distance(dmTransform.position, DMTransform.position);
            offsetDL = Vector3.Distance(dlTransform.position, DLTransform.position);
            offsetPM = Vector3.Distance(pmTransform.position, PMTransform.position);
            offsetPL = Vector3.Distance(plTransform.position, PLTransform.position);
            offsetDFC = Vector3.Distance(dfcTransform.position, DFCTransform.position);
            offsetsCurrentFrame = new List<float> { offsetARef, offsetDM, offsetDL, offsetPM, offsetPL, offsetDFC };

            offsetVectorARef = arefTransform.position - ARefTransform.position;
            offsetVectorDM = dmTransform.position - DMTransform.position;
            offsetVectorDL = dlTransform.position - DLTransform.position;
            offsetVectorPM = pmTransform.position - PMTransform.position;
            offsetVectorPL = plTransform.position - PLTransform.position;
            offsetVectorDFC = dfcTransform.position - DFCTransform.position;
        }

        /// <summary>
        /// Adds 15 floats: four offset vectors (12) plus the x, y and z components of the
        /// world rotation (3).
        /// NOTE(review): the Behavior Parameters vector observation size presumably has to be
        /// 15. The quaternion's w component is not observed. Confirm both are intended.
        /// </summary>
        public override void CollectObservations(VectorSensor sensor)
        {
            UpdateOffset();

            sensor.AddObservation(offsetVectorDM);
            sensor.AddObservation(offsetVectorDL);
            sensor.AddObservation(offsetVectorPM);
            sensor.AddObservation(offsetVectorPL);
            sensor.AddObservation(transform.rotation.x);
            sensor.AddObservation(transform.rotation.y);
            sensor.AddObservation(transform.rotation.z);
        }

        /// <summary>
        /// Rotates the agent about <c>ARefTransform</c> by up to <see cref="rotationSpeed"/>
        /// degrees per world axis, then rewards or penalizes the change in landmark distances.
        /// </summary>
        /// <param name="actions">Expects at least three continuous actions (x, y, z), each clamped to [-1, 1].</param>
        public override void OnActionReceived(ActionBuffers actions)
        {
            actionsCount += 1;
            Log("Action Triggered :" + actionsCount.ToString());

            if (actions.Equals(ActionBuffers.Empty))
            {
                Log("Action Shoot Empty");
            }

            // Indexing ContinuousActions below would throw when fewer than three values are
            // supplied (e.g. an empty buffer or a mis-configured action spec), so stop here.
            ActionSegment<float> continuous = actions.ContinuousActions;
            if (continuous.Length < ContinuousActionCount)
            {
                Debug.LogError("registerAgent expects " + ContinuousActionCount + " continuous actions but received "
                               + continuous.Length + "; check the Behavior Parameters action spec.");
                return;
            }

            float rotationX = Mathf.Clamp(continuous[0], -1f, 1f);
            float rotationY = Mathf.Clamp(continuous[1], -1f, 1f);
            float rotationZ = Mathf.Clamp(continuous[2], -1f, 1f);

            Vector3 pivot = ARefTransform.position;
            transform.RotateAround(pivot, new Vector3(1, 0, 0), rotationSpeed * rotationX);
            transform.RotateAround(pivot, new Vector3(0, 1, 0), rotationSpeed * rotationY);
            transform.RotateAround(pivot, new Vector3(0, 0, 1), rotationSpeed * rotationZ);

            UpdateOffset();

            ModifyReward(offsetsPreviousFrame, offsetsCurrentFrame);
        }

        /// <summary>
        /// Sums (previous - current) over each paired distance into <see cref="totalDeltaOffset"/>.
        /// Adds +0.01 reward when the sum is positive (the landmarks got closer) and -0.01
        /// otherwise, including when nothing moved.
        /// </summary>
        /// <param name="offsetsPreviousFrame">Distances before the action; null is treated as empty.</param>
        /// <param name="offsetsCurrentFrame">Distances after the action; null is treated as empty.</param>
        public void ModifyReward(List<float> offsetsPreviousFrame, List<float> offsetsCurrentFrame)
        {
            // Assign the field (not a shadowing local) so the inspector value stays current.
            totalDeltaOffset = 0f;
            if (offsetsPreviousFrame != null && offsetsCurrentFrame != null)
            {
                int pairCount = Mathf.Min(offsetsPreviousFrame.Count, offsetsCurrentFrame.Count);
                for (int i = 0; i < pairCount; i++)
                {
                    totalDeltaOffset += offsetsPreviousFrame[i] - offsetsCurrentFrame[i];
                }
            }

            Log("DeltaOffsetSum: " + totalDeltaOffset.ToString());

            gettingCloser = totalDeltaOffset > 0;
            AddReward(gettingCloser ? +0.01f : -0.01f);

            Log("Cumulative Reward: " + GetCumulativeReward().ToString());
        }

        /// <summary>Manual-control fallback: rotates by a constant +0.5 on every axis.</summary>
        public override void Heuristic(in ActionBuffers actionsOut)
        {
            ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
            continuousActions[0] = 0.5f;
            continuousActions[1] = 0.5f;
            continuousActions[2] = 0.5f;
        }

        // Writes a diagnostic line only when DEBUG is enabled; logging on every step can
        // noticeably slow training and flood the console.
        private void Log(string message)
        {
            if (DEBUG)
            {
                Debug.Log(message);
            }
        }
    }
     

    Attached Files:

    Last edited: Feb 13, 2023