Question: I'm stuck on my MA-POCA project, can anyone help?

Discussion in 'ML-Agents' started by ice_creamer, Jul 14, 2023.

  1. ice_creamer

    ice_creamer

    Joined:
    Jul 28, 2022
    Posts:
    33
    I have been training this project for two weeks, and I have learned a lot about the various settings, but I still cannot get a good result.
    During training I changed the reward, the observations, trainer.yaml, and so on. Now I am using a reward that others have used.

    The goal of my project is to round up a target, with the agents spread 120 degrees apart.
    Agent group: 3 boats; 18 vector observations plus a 3D ray sensor; a single reward combining distance and angle; 2 discrete action branches (forward/backward, turn left/right).
    Environment: an ocean with simulated wind, waves, etc.
    The specific info is as follows:
    trainer.yaml:
    (attached screenshot: upload_2023-7-14_16-28-10.png)
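    Since the trainer.yaml is only attached as a screenshot, here is what a typical MA-POCA trainer config looks like, for readers who cannot see the image. Every value below is an illustrative assumption taken from common ML-Agents defaults, not the poster's actual settings; the behavior name "Cooperate" is guessed from the agent class name in the code.

    ```yaml
    behaviors:
      Cooperate:
        trainer_type: poca        # MA-POCA trainer
        hyperparameters:
          batch_size: 2048
          buffer_size: 20480
          learning_rate: 3.0e-4
          beta: 5.0e-3
          epsilon: 0.2
          lambd: 0.95
          num_epoch: 3
          learning_rate_schedule: linear
        network_settings:
          normalize: true         # normalizes vector observations
          hidden_units: 256
          num_layers: 2
        reward_signals:
          extrinsic:
            gamma: 0.99
            strength: 1.0
        max_steps: 20000000
        time_horizon: 1000
        summary_freq: 60000
    ```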
    agent control:
    Code (CSharp):

    public override void Initialize()
    {
        agentRb = GetComponent<Rigidbody>();
        usv = GameObject.FindGameObjectsWithTag("agent");
        target = GameObject.FindGameObjectWithTag("target");
        ec = GetComponentInParent<EnvControl>();
        m_ResetParams = Academy.Instance.EnvironmentParameters;
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        var aP = transform.position - target.transform.position;
        posiDiff = new Vector2(aP.x, aP.z);
        //Debug.Log("posidiff:" + posiDiff);
        sensor.AddObservation(-posiDiff);                                // 2
        sensor.AddObservation(posiDiff.magnitude - range);               // 1
        var v = transform.InverseTransformDirection(agentRb.velocity);
        sensor.AddObservation(v.z / 10);                                 // 1
        var av = transform.InverseTransformDirection(agentRb.angularVelocity);
        sensor.AddObservation(av.y / 10);                                // 1

        Vector2 control = new Vector2(forward, rotate);
        sensor.AddObservation(control);                                  // 2
        sensor.AddObservation(transform.InverseTransformVector(force));  // 3
        sensor.AddObservation(transform.InverseTransformVector(torque)); // 3
        sensor.AddObservation(transform.position);                       // 3

        // the angle between the adjacent agents and this one
        if (transform.position == usv[0].transform.position)
        {
            aTot1 = target.transform.position - transform.position;
            var d1 = new Vector2(aTot1.x, aTot1.z);
            dis1 = d1.magnitude;
            aTot2 = target.transform.position - usv[1].transform.position;
            var d2 = new Vector2(aTot2.x, aTot2.z);
            dis2 = d2.magnitude;
            aTot3 = target.transform.position - usv[2].transform.position;
            var d3 = new Vector2(aTot3.x, aTot3.z);
            dis3 = d3.magnitude;
        }
        if (transform.position == usv[1].transform.position)
        {
            aTot1 = target.transform.position - transform.position;
            var d1 = new Vector2(aTot1.x, aTot1.z);
            dis1 = d1.magnitude;
            aTot2 = target.transform.position - usv[0].transform.position;
            var d2 = new Vector2(aTot2.x, aTot2.z);
            dis2 = d2.magnitude;
            aTot3 = target.transform.position - usv[2].transform.position;
            var d3 = new Vector2(aTot3.x, aTot3.z);
            dis3 = d3.magnitude;
        }
        if (transform.position == usv[2].transform.position)
        {
            aTot1 = target.transform.position - transform.position;
            var d1 = new Vector2(aTot1.x, aTot1.z);
            dis1 = d1.magnitude;
            aTot2 = target.transform.position - usv[0].transform.position;
            var d2 = new Vector2(aTot2.x, aTot2.z);
            dis2 = d2.magnitude;
            aTot3 = target.transform.position - usv[1].transform.position;
            var d3 = new Vector2(aTot3.x, aTot3.z);
            dis3 = d3.magnitude;
        }

        averDis = dis1 + dis2 + dis3 - 3 * range;
        sensor.AddObservation(averDis);                                  // 1

        float a1 = Vector3.SignedAngle(aTot1, aTot2, transform.up);
        float a2 = Vector3.SignedAngle(aTot1, aTot3, transform.up);
        if (a1 > 0 && a2 > 0)
        {
            if (a1 <= a2) { thet1 = a1; thet2 = a2; }
            else          { thet1 = a2; thet2 = a1; }
        }
        if (a1 > 0 && a2 < 0)
        {
            thet1 = a1;
            thet2 = a2;
        }
        if (a1 < 0 && a2 > 0)
        {
            thet1 = a2;
            thet2 = a1;
        }
        if (a1 < 0 && a2 < 0)
        {
            a1 = 360 + a1;
            a2 = 360 + a2;
            if (a1 <= a2) { thet1 = a1; thet2 = a2; }
            else          { thet1 = a2; thet2 = a1; }
        }
        if (a1 == 0 && a2 != 0) { thet1 = a1; thet2 = a2; }
        if (a2 == 0 && a1 != 0) { thet1 = a2; thet2 = a1; }
        if (a1 == 0 && a2 == 0) { thet1 = a1; thet2 = a2; }

        //Debug.Log(thet);
        sensor.AddObservation(thet1);                                    // 1 (18 observations in total)
    }

    public void MoveAgent(ActionBuffers actionBuffers)
    {
        var action = actionBuffers.DiscreteActions;
        forward = action[0];
        rotate = action[1];

        switch (forward)
        {
            case 0:
                force = Vector3.zero;
                break;
            case 1:
                force = transform.InverseTransformDirection(transform.forward);
                break;
            case 2:
                force = -transform.InverseTransformDirection(transform.forward);
                break;
        }
        agentRb.AddRelativeForce(force * moveSpeed); // moveSpeed = 1, turnSpeed = 0.6
        switch (rotate)
        {
            case 0:
                torque = Vector3.zero;
                break;
            case 1:
                torque = transform.InverseTransformDirection(transform.up) * turnSpeed;
                break;
            case 2:
                torque = -transform.InverseTransformDirection(transform.up) * turnSpeed;
                break;
        }
        agentRb.AddRelativeTorque(torque);

        // Clamp linear and angular speed
        if (agentRb.velocity.z > 12)
        {
            agentRb.velocity = 0.85f * agentRb.velocity;
        }
        if (agentRb.angularVelocity.y > 0.4)
        {
            agentRb.angularVelocity = 0.85f * agentRb.angularVelocity;
        }

        DisReward();
        AngleReward();
        CheckIfOutbound();
        rewadSingle = 0.6f * Rd + 0.4f * Rthet;
        AddReward(rewadSingle);

        // encourage moving forward
        if (transform.InverseTransformDirection(agentRb.velocity).z > 0)
        {
            AddReward(0.005f);
        }
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        MoveAgent(actions);
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        // forward
        if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = 1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = 2;
        }
        // rotate
        if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[1] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[1] = 2;
        }
    }

    public override void OnEpisodeBegin()
    {
        ec.ResetScene();
    }

    void DisReward()
    {
        float meanAverDis = Mathf.Pow(((dis1 - averDis) * (dis1 - averDis) + (dis2 - averDis) * (dis2 - averDis) + (dis3 - averDis) * (dis3 - averDis)) / 3, 0.5f);
        Rd = 1 - 0.05f * (posiDiff.magnitude - range) - 0.2f * Mathf.Exp((posiDiff.magnitude - range) / meanAverDis);
    }

    void AngleReward()
    {
        float a1 = Vector3.SignedAngle(aTot1, aTot2, transform.up);
        float a2 = Vector3.SignedAngle(aTot1, aTot3, transform.up);
        Rthet = 0.3f * Mathf.Exp(-Mathf.Abs((thet1 - 120) * Mathf.PI / 180)) - 1
              + Mathf.Exp(-Mathf.Abs((thet2 - 240) * Mathf.PI / 180)) - 1
              + 0.4f * Mathf.Exp(-(Mathf.Abs(a1 * Mathf.PI / 180) - Mathf.Abs(a2 * Mathf.PI / 180))) - 1;
    }

    private void OnCollisionEnter(Collision collision)
    {
        if (collision.gameObject.CompareTag("obstacle") || collision.gameObject.CompareTag("agent") || collision.gameObject.CompareTag("target"))
        {
            AddReward(-0.05f);
            ec.ResetScene();
        }
    }

    void CheckIfOutbound()
    {
        var bound = 1.5f;
        if (transform.position.x < -bound * ec.areaBounds.extents.x || transform.position.x > bound * ec.areaBounds.extents.x
            || transform.position.z < -bound * ec.areaBounds.extents.z || transform.position.z > bound * ec.areaBounds.extents.z)
        {
            AddReward(-0.03f);
            ec.ResetScene();
        }
    }
}
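    (As a side note for readers: the cascade of sign checks above is a common source of bugs. One simpler alternative, not byte-for-byte equivalent to the branches above since those treat mixed-sign cases specially, is to normalize both SignedAngle results into [0, 360) and sort them. Sketch only, reusing the thread's thet1/thet2 convention:

    ```
    // Sketch: map Vector3.SignedAngle's (-180, 180] output into [0, 360)
    // and order the two angles, replacing the branch cascade.
    float Normalize360(float a) => (a % 360f + 360f) % 360f;

    void SortedAngles(float a1, float a2, out float thetSmall, out float thetLarge)
    {
        a1 = Normalize360(a1);
        a2 = Normalize360(a2);
        thetSmall = Mathf.Min(a1, a2);
        thetLarge = Mathf.Max(a1, a2);
    }
    ```

    Whether this matches the intended geometry depends on how the special cases were meant to behave, so treat it as a starting point for debugging rather than a drop-in fix.)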
    environment control (contains the group reward):
    Code (CSharp):

    private void Start()
    {
        areaBounds = ground.GetComponent<Collider>().bounds;
        m_purseAgent = FindObjectOfType<Cooperate>();
        m_AgentGroup = new SimpleMultiAgentGroup();
        usv = GameObject.FindGameObjectsWithTag("agent");
        target = GameObject.FindGameObjectWithTag("target");

        foreach (var item in TargetsList)
        {
            item.StartingPos = item.target.transform.position;
            item.StartingRot = item.target.transform.rotation;
            item.T = item.target.transform;
            item.Rb = item.target.GetComponent<Rigidbody>();
            item.Col = item.target.GetComponent<Collider>();
        }

        foreach (var item in AgentsList)
        {
            item.StartingPos = item.agent.transform.position;
            item.StartingRot = item.agent.transform.rotation;
            item.Rb = item.agent.GetComponent<Rigidbody>();
            item.Col = item.agent.GetComponent<Collider>();
            m_AgentGroup.RegisterAgent(item.agent);
        }
        ResetScene();
    }

    public void ResetScene()
    {
        m_ResetTimer = 0;

        // Random platform rotation
        var rotation = Random.Range(0, 4);
        var rotationAngle = rotation * 90f;
        transform.Rotate(new Vector3(0f, rotationAngle, 0f));

        // Reset agents
        foreach (var item in AgentsList)
        {
            var pos = UseRandomAgentPosition ? GetRandomSpawnPos() : item.StartingPos;
            var rot = UseRandomAgentRotation ? GetRandomRot() : item.StartingRot;

            item.agent.transform.SetPositionAndRotation(pos, rot);
            item.Rb.velocity = Vector3.zero;
            item.Rb.angularVelocity = Vector3.zero;
            m_AgentGroup.RegisterAgent(item.agent);
        }

        // Reset targets
        foreach (var item in TargetsList)
        {
            var pos = UseRandomTargetPosition ? GetRandomSpawnPos() : item.StartingPos;
            var rot = UseRandomTargetRotation ? GetRandomRot() : item.StartingRot;

            item.T.SetPositionAndRotation(pos, rot);
            item.Rb.velocity = Vector3.zero;
            item.Rb.angularVelocity = Vector3.zero;
            item.T.gameObject.SetActive(true);
        }
    }

    public Vector3 GetRandomSpawnPos()
    {
        var foundNewSpawnLocation = false;
        var randomSpawnPos = Vector3.zero;
        while (foundNewSpawnLocation == false)
        {
            var randomPosX = Random.Range(-areaBounds.extents.x, areaBounds.extents.x);
            var randomPosZ = Random.Range(-areaBounds.extents.z, areaBounds.extents.z);
            randomSpawnPos = ground.transform.position + new Vector3(randomPosX, 3.27f, randomPosZ);
            var spawnPosTt = randomSpawnPos - target.transform.position;
            // Accept only free positions inside the area and outside the capture range
            if (Physics.CheckBox(randomSpawnPos, new Vector3(1.75f, 1f, 4.25f)) == false
                && (randomSpawnPos.x > -areaBounds.extents.x && randomSpawnPos.x < areaBounds.extents.x)
                && (randomSpawnPos.z > -areaBounds.extents.z && randomSpawnPos.z < areaBounds.extents.z)
                && (Mathf.Pow(spawnPosTt.x * spawnPosTt.x + spawnPosTt.z * spawnPosTt.z, 0.5f) > m_purseAgent.range))
            {
                foundNewSpawnLocation = true;
            }
        }
        return randomSpawnPos;
    }

    Quaternion GetRandomRot()
    {
        return Quaternion.Euler(0, Random.Range(0.0f, 360.0f), 0);
    }

    public void FixedUpdate()
    {
        m_ResetTimer += 1;
        if (m_ResetTimer >= MaxEnvironmentSteps && MaxEnvironmentSteps > 0)
        {
            m_AgentGroup.GroupEpisodeInterrupted();
            ResetScene();
        }
        SuccessPurse();

        // Small per-step group penalty to encourage finishing early
        m_AgentGroup.AddGroupReward(-20f / MaxEnvironmentSteps);
    }

    void SuccessPurse()
    {
        int flag = 0;
        int f = 0;
        Vector2 u1 = new Vector2(target.transform.position.x - usv[0].transform.position.x, target.transform.position.z - usv[0].transform.position.z);
        Vector2 u2 = new Vector2(target.transform.position.x - usv[1].transform.position.x, target.transform.position.z - usv[1].transform.position.z);
        Vector2 u3 = new Vector2(target.transform.position.x - usv[2].transform.position.x, target.transform.position.z - usv[2].transform.position.z);
        if (u1.magnitude <= m_purseAgent.range)
        {
            flag++;
        }
        if (u2.magnitude <= m_purseAgent.range)
        {
            flag++;
        }
        if (u3.magnitude <= m_purseAgent.range)
        {
            flag++;
        }
        if (Vector2.Dot(u1.normalized, u2.normalized) >= -0.5 && Vector2.Dot(u1.normalized, u2.normalized) <= -0.1)
        {
            f++;
        }
        if (Vector2.Dot(u1.normalized, u3.normalized) >= -0.5 && Vector2.Dot(u1.normalized, u3.normalized) <= -0.1)
        {
            f++;
        }
        if (flag == 3 && f == 2)
        {
            m_AgentGroup.AddGroupReward(100);
            m_AgentGroup.EndGroupEpisode();
            ResetScene();
        }
    }
}
     
  2. GamerLordMat

    GamerLordMat

    Joined:
    Oct 10, 2019
    Posts:
    177
    Very hard to tell. A single bug in your code can mess up training completely (fiddling around with angles is always a source of errors for me).
    Try debugging your code and checking that every value is what it should be.
     
  3. ice_creamer

    ice_creamer

    Joined:
    Jul 28, 2022
    Posts:
    33
    Now I'm trying to move the angle part into the group reward. I will show the result later.
     
  4. ice_creamer

    ice_creamer

    Joined:
    Jul 28, 2022
    Posts:
    33
    Hi, I have a question about reward normalization. Sometimes the mean reward sharply drops to -1e7. Is this related to normalization? I do normalize observations in trainer.yaml. If so, what is wrong? Should I normalize the total reward rather than each individual term?
    Appreciated!
     
  5. GamerLordMat

    GamerLordMat

    Joined:
    Oct 10, 2019
    Posts:
    177
    The POCA team reward should be either 1 for a win or 0 for a loss. Having it slide between 0 and 1 can be suboptimal.
    I would try giving the individual agents points for being close to each other and letting them figure out what to do. I don't know; playing around with those reward functions is the difficult part, and it really depends on your training hardware, bugs in the code, and the individual project.
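    The sparse win/lose scheme suggested above could look something like this in the environment controller. This is a sketch reusing the thread's m_AgentGroup, SuccessPurse and ResetScene names; the success test is a hypothetical placeholder, and the 1/0 values are the suggestion, not the poster's current +100:

    ```
    // Sketch of a sparse POCA group reward: +1 on success, nothing otherwise.
    // AllAgentsInRangeAndSpread() is a hypothetical helper wrapping the
    // distance and angle checks already in SuccessPurse().
    void SuccessPurse()
    {
        if (AllAgentsInRangeAndSpread())
        {
            m_AgentGroup.AddGroupReward(1f);   // win
            m_AgentGroup.EndGroupEpisode();
            ResetScene();
        }
        // No shaping on the group signal; small per-agent shaping rewards
        // (e.g. for closing distance) can still go through AddReward().
    }
    ```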
     
  6. mehdi_1234

    mehdi_1234

    Joined:
    Jun 16, 2023
    Posts:
    4
    Are you using multiple environments to learn faster?
    My observations are similar to yours:
    Code (CSharp):

    if (transform.position == agent[0].transform.position)
    {
        ...
    }
    The agents could fully learn in a single environment, but with multiple environments they couldn't.
    I think this is because of the observation shown above.
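    That position-equality pattern is fragile for two reasons: it relies on exact float equality, and GameObject.FindGameObjectsWithTag returns agents from every cloned environment, so with multiple environments an agent can match (or observe) boats from a different copy of the scene. A more robust sketch is to cache the agent's own index once by reference, and to get the agent array from the parent environment controller rather than a scene-wide tag search (agentIndex here is a new hypothetical field):

    ```
    // Sketch: identify 'this' agent by object reference instead of by position,
    // and scope the search to this environment instance.
    int agentIndex;                      // hypothetical new field

    public override void Initialize()
    {
        ec = GetComponentInParent<EnvControl>();
        usv = ec.GetAgentObjects();      // hypothetical per-environment accessor
        agentIndex = System.Array.IndexOf(usv, gameObject);  // reference compare
    }

    // In CollectObservations the two neighbours are simply the other indices:
    // e.g. for agentIndex == 0, neighbours are usv[1] and usv[2].
    ```

    With per-environment arrays, each clone only ever sees its own three boats, which may explain why training worked with one environment but broke with many.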