Search Unity

Discussion Make a Rigidbody rotate

Discussion in 'ML-Agents' started by GamerLordMat, Jul 3, 2022.

  1. GamerLordMat

    GamerLordMat

    Joined:
    Oct 10, 2019
    Posts:
    185
    Hello,

    so I am asking basically the same question I have already asked some time ago.
    How to make a Rigidbody Agent rotate to a specified Rotation, with a starting angular velocity?

    Important is that the agent is not symetric by no means, it is a ball with a heavy tip (Fixed Joint), discentered etc..

    It sounds like a simple problem, but I had not been able to solve it after 2 years of trying.
    Please try it youself as a challenge, this is basic stuff you need for turrets, walking beasts etc.

    My code is a mess, but if you want to plug and play it see below.
    The main take away is that I solved it without a starting angular velocity but with it just never worked. Now it's already been training for 10 hours, still bad results.

    If someone has an idea, please go ahead :D

    My code:
    Code (CSharp):
    1. using System.Collections;
    2. using System.Collections.Generic;
    3. using UnityEngine;
    4. using Unity.MLAgents;
    5. using Unity.MLAgents.Actuators;
    6. using Unity.MLAgents.Sensors;
    7. using Unity.MLAgents;
    8.  
    9. public class Agent3DPointer : Agent
    10. {
    11.     public GameObject visual;
    12.     Rigidbody rb;
    13.     public Rigidbody pointer;
    14.     public Quaternion targetRotation;
    15.     public GameObject goalObject;
    16.     public float strength;
    17.     public override void OnActionReceived(ActionBuffers actions)
    18.  
    19.     {
    20.  
    21.         float f1 = actions.ContinuousActions[0];
    22.         float f2 = actions.ContinuousActions[1];
    23.         float f3 = actions.ContinuousActions[2];
    24.  
    25.         rb.AddTorque(transform.right * strength * f1);
    26.         rb.AddTorque(transform.up * strength * f2);
    27.         rb.AddTorque(transform.forward * strength * f3);
    28.         //rb.AddRelativeTorque(Vector3.right * manager.legStrenght * f1);
    29.         //rb.AddRelativeTorque(Vector3.up * manager.legStrenght * f2);
    30.         //rb.AddRelativeTorque(Vector3.forward * manager.legStrenght * f3);
    31.     }
    32.     public float toDegree = 57.2f;
    33.     public float akkumulation;
    34.     public bool touches;
    35.     public override void CollectObservations(VectorSensor sensor)
    36.     {
    37.  
    38.  
    39.         //sensor.AddObservation(transform.localRotation.normalized);
    40.         Quaternion q2 = targetRotation.normalized;
    41.         //sensor.AddObservation(q2);
    42.  
    43.         //sensor.AddObservation(a * clockwise);
    44.         //sensor.AddObservation((a) );
    45.         bool b = false;
    46.         if (clockwise == 1)
    47.         {
    48.             b = true;
    49.         }
    50.         else if (clockwise == -1)
    51.         {
    52.             b = false;
    53.         }
    54.  
    55.      
    56.         sensor.AddObservation(transform.rotation.normalized);
    57.         sensor.AddObservation((Quaternion.Inverse(transform.rotation) * rb.angularVelocity)/rb.maxAngularVelocity);
    58.         sensor.AddObservation(rb.angularVelocity.magnitude / rb.maxAngularVelocity);
    59.         sensor.AddObservation(targetRotation.normalized);
    60.         sensor.AddObservation((targetRotation.normalized * Quaternion.Inverse( transform.rotation.normalized).normalized).normalized);
    61.      
    62.         //for turret
    63.         //sensor.AddObservation(touches);
    64.  
    65.  
    66.         int there = -1;
    67.         float dot = (Quaternion.Dot(targetRotation, transform.rotation) + 1) * 0.5f;
    68.         if (1 - dot < delta)
    69.         {
    70.             //AddReward(10000f*d);
    71.             there = 1;
    72.         }
    73.         //sensor.AddObservation(there);
    74.  
    75.  
    76.     }
    77.     public int clockwise;
    78.     float closestDIfference(float a, float target)
    79.     {
    80.         bool isNegativ = false;
    81.         float difference = target - a + 360;
    82.         float res1 = 0, res2 = 0;
    83.         if (difference < 0)
    84.         {
    85.             isNegativ = true;
    86.             res1 = difference + 360;
    87.             res2 = -difference;
    88.         }
    89.  
    90.         if (difference > 0)
    91.         {
    92.             res1 = 360 - difference;
    93.             res2 = difference;
    94.         }
    95.  
    96.         if (res1 > res2 && !isNegativ)
    97.         {
    98.             clockwise = -1;
    99.             return res2;
    100.         }
    101.         if (res1 > res2 && isNegativ)
    102.         {
    103.             clockwise = 1;
    104.             return res2;
    105.         }
    106.         if (res1 < res2 && !isNegativ)
    107.  
    108.         {
    109.             clockwise = 1;
    110.             return res1;
    111.  
    112.         }
    113.         if (res1 < res2 && isNegativ)
    114.         {
    115.             clockwise = -1;
    116.             return res1;
    117.         }
    118.         if (Mathf.Equals(res1, res2))
    119.         {
    120.             return res1;
    121.         }
    122.         return 0;
    123.     }
    124.     public float newAbstand;
    125.     public float delta = 5f;
    126.  
    127.     Quaternion lastRotation;
    128.     Quaternion newRotation;
    129.    
    130.     float nowDot;
    131.     float nowAngularSpeed;
    132.     private void FixedUpdate()
    133.     {
    134.         nowDot = (Quaternion.Dot(targetRotation.normalized, transform.rotation.normalized)+1)*0.5f;
    135.         //nowAngularSpeed = (rb.angularVelocity/rb.maxAngularVelocity).magnitude;
    136.         nowAngularSpeed = rb.angularVelocity.magnitude / rb.maxAngularVelocity;
    137.         nowAngularSpeed = Mathf.Pow(nowAngularSpeed,3);
    138.         float rew1 = Mathf.Clamp01(nowDot);
    139.         //float rew2 = 1 - nowAngularSpeed / (Mathf.Sqrt(3));
    140.         float rew2 = Mathf.Clamp01 (1 - nowAngularSpeed);
    141.         //AddReward(rew1-1);
    142.         //AddReward(rew2-1);
    143.         float exp = 1 / (1 + Mathf.Exp(-rew1));
    144.         exp = Mathf.Clamp01(exp);
    145.         AddReward(rew1);
    146.         AddReward(rew2);
    147.         //AddReward(Mathf.Clamp01(rew1 *rew2));
    148.  
    149.         //Debug.Log("------------------------");
    150.         //Debug.Log(( rew1*rew2));
    151.         //Debug.Log((rew2));
    152.         //Debug.Log((rew1));
    153.         //AddReward(rew2);
    154.         //AddReward((nowDot - lastdot));
    155.         //AddReward(nowDot);
    156.         //AddReward(nowDot);
    157.         if (nowDot > 1- manager.delta)
    158.         {
    159.             AddReward(1);
    160.         }
    161.        
    162.         //AddReward((lastAngularSpeed - nowAngularSpeed)/rb.maxAngularVelocity);
    163.         /*
    164.         float reward = 1;
    165.         //Quaternion mult with Vector is transform.up
    166.         matchVectorSensor1.UpdateVectors(targetRotation * Vector3.up, transform.up);
    167.  
    168.         bool b1 = false, b2 = false, b3 = false;
    169.         if (matchVectorSensor1.inDelta())
    170.         {
    171.             b1 = true;
    172.             AddReward(1);
    173.         }
    174.         matchVectorSensor2.UpdateVectors(targetRotation * Vector3.right, transform.right);
    175.         if (matchVectorSensor2.inDelta())
    176.         {
    177.             b2 = true;
    178.             AddReward(1);
    179.  
    180.         }
    181.         matchVectorSensor3.UpdateVectors(targetRotation * Vector3.forward, transform.forward);
    182.         if (matchVectorSensor3.inDelta())
    183.         {
    184.             b3 = true;
    185.             AddReward(1);
    186.  
    187.         }
    188.         if (b1 && b2 && b3)
    189.         {
    190.             AddReward(1);
    191.         }
    192.         else
    193.         {
    194.             AddReward(-0.1f);
    195.         }
    196.         //AddReward(matchVectorSensor1.dot() * matchVectorSensor2.dot() * matchVectorSensor3.dot());
    197.         AddReward(1 - Mathf.Abs(matchVectorSensor1.radDistanceNorm()));
    198.         AddReward(1 - Mathf.Abs(matchVectorSensor2.radDistanceNorm()));
    199.         AddReward(1 - Mathf.Abs(matchVectorSensor3.radDistanceNorm()));
    200.  
    201.         AddReward(-(rb.angularVelocity.magnitude) / rb.maxAngularVelocity);
    202.         lastAbstand = newAbstand;
    203.         // lastRotation = newRotation;
    204.        
    205.         */
    206.         lastAngularSpeed = nowAngularSpeed;
    207.         lastdot = nowDot;
    208.     }
    209.     public float lastAbstand;
    210.     public float goalSpeed;
    211.     float sigmoid(float value, float max)
    212.     {
    213.         float v = Mathf.Pow(1 - Mathf.Pow(value / max, 2), 2);
    214.         return v;
    215.     }
    216.     float lastdot;
    217.     float lastAngularSpeed=0;
    218.     public override void OnEpisodeBegin()
    219.     {
    220.         lastRotation = Quaternion.identity;
    221.         goalSpeed = Random.Range(-20, 20);
    222.        
    223.         int i = (int)Academy.Instance.EnvironmentParameters.GetWithDefault("lecture", 0.1f);
    224.         switch (i)
    225.         {
    226.             case 0:
    227.                 //lastAbstand = closestDIfference(targetRotation, hJoint.angle);
    228.                 rb.velocity = Vector3.zero;
    229.                 rb.angularVelocity = Random.rotation.eulerAngles * rb.maxAngularVelocity;
    230.                 //rb.angularVelocity = Vector3.zero;
    231.                 delta = manager.delta;
    232.                 //rb.rotation = Random.rotation;
    233.                 targetRotation = Random.rotation;
    234.                 //matchVectorSensor1.constructor(targetRotation * Vector3.up, 1.5f, transform, manager.delta);
    235.                 //matchVectorSensor2.constructor(targetRotation * Vector3.right, 1.5f, transform, manager.delta);
    236.                 //matchVectorSensor3.constructor(targetRotation * Vector3.forward, 1.5f, transform, manager.delta);
    237.  
    238.                 break;
    239.             case 1:
    240.  
    241.                 rb.velocity = Vector3.zero;
    242.                 //rb.angularVelocity = new Vector3(Random.Range(manager.min, manager.max), 0, 0);
    243.                 rb.angularVelocity = Random.rotation.eulerAngles.normalized * rb.maxAngularVelocity;
    244.                 //rb.angularVelocity = new Vector3(Random.Range(manager.min, manager.max), 0, 0);
    245.                 delta = manager.delta;
    246.                 //rb.rotation.SetEulerAngles(Random.Range(-180, 180), 0, 0);
    247.                 targetRotation = Random.rotation;
    248.                 lastRotation = Quaternion.FromToRotation(transform.rotation.eulerAngles, targetRotation.eulerAngles);
    249.  
    250.  
    251.                 break;
    252.             case 2:
    253.                 rb.velocity = Vector3.zero;
    254.                 rb.angularVelocity = new Vector3(Random.Range(0, 0), 0, 0);
    255.                 delta = 2;
    256.                 //rb.rotation.SetEulerAngles(Random.Range(-180, 180), 0, 0);
    257.                 targetRotation = Quaternion.Euler(Random.Range(-180, 180), 0, 0);
    258.                 break;
    259.             case 3:
    260.                 rb.velocity = Vector3.zero;
    261.                 rb.angularVelocity = new Vector3(Random.Range(0, 0), 0, 0);
    262.                 delta = 2;
    263.                 //rb.rotation.SetEulerAngles(Random.Range(-180, 180), 0, 0);
    264.                 targetRotation = Quaternion.Euler(Random.Range(-180, 180), 0, 0);
    265.                 break;
    266.         }
    267.         //Pointer
    268.         pointer.velocity = Vector3.zero;
    269.         //pointer.rotation = transform.rotation;
    270.         pointer.angularVelocity = rb.angularVelocity;
    271.         //----------------------------------------------------------
    272.         lastdot = Quaternion.Dot(targetRotation.normalized, transform.rotation.normalized);
    273.         lastAngularSpeed = rb.angularVelocity.magnitude;
    274.         visual.transform.localPosition = targetRotation * Vector3.up * 1.5f;
    275.         //targetRotation = -160;
    276.     }
    277.     /*
    278.     MatchVectorSensor matchVectorSensor1;
    279.     MatchVectorSensor matchVectorSensor2;
    280.     MatchVectorSensor matchVectorSensor3;
    281.     */
    282.     // Start is called before the first frame update
    283.     void Start()
    284.     {
    285.         //matchVectorSensor1 = GetComponent<MatchVectorSensor>();
    286.         //matchVectorSensor2 = gameObject.AddComponent<MatchVectorSensor>();
    287.         //matchVectorSensor3 = gameObject.AddComponent<MatchVectorSensor>();
    288.         // hJoint = GetComponent<HingeJoint>();
    289.         rb = GetComponent<Rigidbody>();
    290.         rb.maxAngularVelocity = 50;
    291.         manager = FindObjectOfType<MM>();
    292.     }
    293.  
    294.     // Update is called once per frame
    295.  
    296. }
    297.  
    Hyper parameters


    Code (CSharp):
    1. default_settings: null
    2. behaviors:
    3.   Pointer:
    4.     trainer_type: ppo
    5.     hyperparameters:
    6.       batch_size: 128
    7.       buffer_size: 1280
    8.       learning_rate: 0.0003
    9.       beta: 0.005
    10.       epsilon: 0.2
    11.       lambd: 0.95
    12.       num_epoch: 3
    13.       learning_rate_schedule: linear
    14.     network_settings:
    15.       normalize: true
    16.       hidden_units: 128
    17.       num_layers: 2
    18.       vis_encode_type: simple
    19.       memory: null
    20.       goal_conditioning_type: hyper
    21.     reward_signals:
    22.       extrinsic:
    23.         gamma: 0.995
    24.         strength: 1.0
    25.         network_settings:
    26.           normalize: false
    27.           hidden_units: 128
    28.           num_layers: 2
    29.           vis_encode_type: simple
    30.           memory: null
    31.           goal_conditioning_type: hyper
    32.     init_path: null
    33.     keep_checkpoints: 5
    34.     checkpoint_interval: 500000
    35.     max_steps: 500000000
    36.     time_horizon: 1000
    37.     summary_freq: 30000
    38.     threaded: true
    39.     self_play: null
    40.     behavioral_cloning: null
    41. env_settings:
    42.   env_path: null
    43.   env_args: null
    44.   base_port: 5005
    45.   num_envs: 1
    46.   seed: -1
    47.  
    48.  
    49. # Add this section
    50.  
     
  2. mbaske

    mbaske

    Joined:
    Dec 31, 2017
    Posts:
    473
    Try a PID controller, no need for ML.
     
    GamerLordMat likes this.
  3. GamerLordMat

    GamerLordMat

    Joined:
    Oct 10, 2019
    Posts:
    185
    Hello mbaske,

    thanks for answering.
    Yeah, that is what I used in the end after having all kind of experiments and it worked very precise.
    But is still bothered me that such a simple thing cant be achieved with DL (I also tried to give the PID observations to a MLagent (), no luck).

    If you have a drone and just dont know how to model the PID Controller for it the drone wont perform well. A pitty.

    But I think that is the point, if you know the exact Math ML is not necessary and performs worse.