Search Unity

  1. Unity 6 Preview is now available. To find out what's new, have a look at our Unity 6 Preview blog post.
    Dismiss Notice
  2. Unity is excited to announce that we will be collaborating with TheXPlace for a summer game jam from June 13 - June 19. Learn more.
    Dismiss Notice
  3. Dismiss Notice

Question It doesn't work when I copy the environmets for faster training - MLAgents

Discussion in 'ML-Agents' started by orhun868, Mar 5, 2024.

  1. orhun868

    orhun868

    Joined:
    Feb 22, 2024
    Posts:
    5
    When I copy the environment, the script does not work for the original environment and only works for the environment the last I copied. Normally hidden objects appear when I start it. But only the ones in the env I last copied appear and the other environment moves as the last one but does not work.
    upload_2024-3-6_1-29-48.png
    upload_2024-3-6_1-30-24.png

    The other one works fine but the main env is never resetted. When I delete the working one during the process, the main env stops. Here is my code:
    Code (CSharp):
    1. using System;
    2. using System.Runtime.CompilerServices;
    3. using System.Collections.Generic;
    4. using System.IO;
    5. using System.Linq;
    6. using TMPro;
    7. using Unity.Burst.Intrinsics;
    8. using Unity.VisualScripting;
    9. using UnityEngine;
    10. using UnityEngine.UI;
    11. using Unity.MLAgents;
    12. using Unity.MLAgents.Actuators;
    13. using Unity.MLAgents.Sensors;
    14. using UnityEditor;
    15.  
    16. public class ChildObjectManager : Agent
    17. {
    18.     private GameObject parentObject;
    19.     private int rows = 20;
    20.  
    21.     private Vector3 point1 = new Vector3(-20.0900002f,37.9300003f,0.839999974f);
    22.     private Vector3 point2 = new Vector3(-38.9300003f,37.9300003f,0.839999974f);
    23.     private Vector3 point3 = new Vector3(-20.0900002f,37.9300003f,20.0900002f);
    24.  
    25.     private string path = Application.dataPath + "/trainLog/Log.txt";
    26.     private int columns = 20;
    27.     private float minY = 1.75f;
    28.     private float maxY = 3.5f;
    29.     public float moveDist = 5f;
    30.     public float moveSpeed = 0.7f;
    31.     private int distance_lim = 8;
    32.     private Transform[,] childArray;
    33.     private float[,] movementObserveArray;
    34.     private GameObject wall;
    35.     private GameObject product;
    36.     private GameObject target;
    37.     // const float maxProductSpeed = 10;
    38.     public float angleTolerance = 40f;
    39.     private Rigidbody productRigidbody;
    40.     private TextMeshPro ui;
    41.     private Vector3 savedProductLoc;
    42.     private Vector3 transformLoc;
    43.     public int actionLimit = 800;
    44.     private int actionCount = 0;
    45.     private Vector3[,] tableLoc;
    46.     private int gameCount = 0;
    47.     private int win = 0;
    48.     private Queue<float> last_rewards;
    49.     private float lastReward = 0;
    50.    
    51.     private void Awake()
    52.     {
    53.         using (StreamWriter sw = new StreamWriter(path))
    54.         {
    55.             sw.WriteLine("\nTraining Started...");
    56.         }
    57.         last_rewards = new Queue<float>(200);
    58.         transformLoc = transform.localPosition;
    59.         product = GameObject.FindWithTag("Product");
    60.         wall = GameObject.FindWithTag("Wall");
    61.         savedProductLoc = product.transform.localPosition;
    62.         GameObject uiGameObject  = GameObject.FindWithTag("GeneralText");
    63.         productRigidbody = product.GetComponent<Rigidbody>();
    64.         ui = uiGameObject.GetComponent<TextMeshPro>();
    65.         target = GameObject.FindWithTag("Target");
    66.         parentObject = transform.gameObject;
    67.         if (parentObject != null)
    68.         {
    69.             GetChildObjects();
    70.         }
    71.         else
    72.         {
    73.             Debug.LogError("Parent object not assigned!");
    74.         }
    75.     }
    76.  
    77.     private void GetChildObjects()
    78.     {
    79.         childArray = new Transform[rows, columns];
    80.         tableLoc = new Vector3[rows, columns];
    81.         int index = 0;
    82.  
    83.         foreach (Transform child in parentObject.transform)
    84.         {
    85.             int i = index / rows;
    86.             int j = index % columns;
    87.             childArray[i, j] = child;
    88.             tableLoc[i, j] = child.transform.localPosition;
    89.             index++;
    90.         }
    91.     }
    92.  
    93.     private void Update()
    94.     {
    95.         updateUI();
    96.     }
    97.  
    98.     public void triggerReset(){
    99.         lastReward = -10;
    100.         AddReward(-10f);
    101.         AddValue(-10f);
    102.         EndEpisode();
    103.     }
    104.  
    105.     public void winReset(){
    106.             float reward = rewardCalculate(0);
    107.             AddValue(reward);
    108.             AddReward(reward);
    109.             EndEpisode();
    110.     }
    111.  
    112.     public float CurrentAverage()
    113.     {
    114.         if (last_rewards.Count == 0)
    115.         {
    116.             return 0f;
    117.         }
    118.  
    119.         return last_rewards.Average();
    120.     }
    121.  
    122.     public void AddValue(float newValue)
    123.     {
    124.         if (last_rewards.Count == 200)
    125.         {
    126.             last_rewards.Dequeue();
    127.         }
    128.  
    129.         last_rewards.Enqueue(newValue);
    130.     }
    131.  
    132.     private float rewardCalculate(float point)
    133.     {
    134.         float reward;
    135.         if (point == 0){
    136.             reward = (actionLimit-actionCount)/10f + 5f;
    137.             win++;
    138.         }
    139.         else{
    140.             reward = (28.5f-targetCloseness()*1f)/2f - 10f;
    141.         }
    142.         lastReward = reward;
    143.         return reward;
    144.     }
    145.  
    146.     private void updateUI()
    147.     {
    148.         ui.text = "Product States\nSpeed: "+getProductSpeed()+"\nPosition: "+getProductPos()+"\nDistance to Target: "+targetCloseness()+"\nAngle Correctness: "+getProductRot()+"\nAction Count: "+actionCount+"\nGame Count: "+gameCount+"\nWin Count: "+win+"\nAvg of Last 200 Rewards: "+CurrentAverage()+"\nLast Reward: "+lastReward;
    149.     }
    150.     private void MoveChildren()
    151.     {
    152.         movementObserveArray = new float[rows, columns];
    153.         int index = 0;
    154.         foreach (Transform child in childArray)
    155.         {
    156.             if (child != null)
    157.             {
    158.                 if (GetDistanceToChild(child))
    159.                 {
    160.                     float randomDirection = UnityEngine.Random.Range(-moveDist, moveDist);
    161.                     float newYPosition = child.localPosition.y + randomDirection * moveSpeed * Time.deltaTime;
    162.                     newYPosition = Mathf.Clamp(newYPosition, minY, maxY);
    163.                     child.localPosition = new Vector3(child.localPosition.x, newYPosition, child.localPosition.z);
    164.                 }
    165.                 int i = index / columns;
    166.                 int j = index % columns;
    167.                 movementObserveArray[i, j] = child.localPosition.y;
    168.                 index++;
    169.             }
    170.         }
    171.     }
    172.     private bool getProductRot()
    173.     {
    174.         Quaternion currentRotation = product.transform.localRotation;
    175.         Quaternion targetRotation = Quaternion.Euler(0f, 0f, 0f);
    176.         float angleDifference = Quaternion.Angle(currentRotation, targetRotation);
    177.         return angleDifference <= angleTolerance;
    178.     }
    179.     private float getProductSpeed()
    180.     {
    181.         return productRigidbody.velocity.magnitude;
    182.     }
    183.     private Vector3 getProductPos()
    184.     {
    185.         return product.transform.localPosition;
    186.     }
    187.     private float targetCloseness()
    188.     {
    189.         float distance = Vector3.Distance(product.transform.localPosition, target.transform.localPosition);
    190.         if (distance < 1)
    191.         {
    192.             return 0;
    193.         }
    194.         return distance;
    195.     }
    196.     bool GetDistanceToChild(Transform child)
    197.     {
    198.         float distance = Vector3.Distance(child.transform.localPosition, product.transform.localPosition);
    199.         return distance < distance_lim;
    200.     }
    201.  
    202.  
    203.     public override void OnActionReceived(ActionBuffers actions)
    204.     {
    205.         if (actionLimit < actionCount) //  || getProductRot() != true
    206.         {
    207.             float reward = rewardCalculate(1);
    208.             AddValue(reward);
    209.             AddReward(reward);
    210.             EndEpisode();
    211.         }
    212.         int index = 0;
    213.         actionCount++;
    214.         foreach (Transform child in childArray)
    215.         {
    216.             if (child != null)
    217.             {
    218.                 // if (GetDistanceToChild(child))
    219.                 float randomDirection = actions.ContinuousActions[index]*10;
    220.                 float newYPosition = child.localPosition.y + randomDirection * moveSpeed * Time.deltaTime;
    221.                 newYPosition = Mathf.Clamp(newYPosition, minY, maxY);
    222.                 child.localPosition = new Vector3(child.localPosition.x, newYPosition, child.localPosition.z);
    223.                 index++;
    224.             }
    225.             else{
    226.                 Debug.Log("Null child founded!");
    227.             }
    228.         }
    229.     }
    230.  
    231.     public override void OnEpisodeBegin()
    232.     {
    233.         productRigidbody.velocity = Vector3.zero;
    234.         transform.localPosition = transformLoc;
    235.         product.transform.localPosition = savedProductLoc;
    236.         actionCount = 0;
    237.         for (int i = 0; i < rows; i++)
    238.         {
    239.             for (int j = 0; j < columns; j++)
    240.             {
    241.                 childArray[i,j].transform.localPosition = tableLoc[i,j];
    242.             }
    243.         }
    244.         gameCount++;
    245.         if (gameCount%200==0){
    246.             Debug.Log("Game Count: "+gameCount+"\nWin Count: "+win+"\nAvg of Last 200 Rewards: "+CurrentAverage()+"\nLast Reward: "+lastReward);
    247.             using (StreamWriter sw = new StreamWriter(path))
    248.             {
    249.                 sw.WriteLine("Game Count: "+gameCount+"\nWin Count: "+win+"\nAvg of Last 200 Rewards: "+CurrentAverage()+"\nLast Reward: "+lastReward);
    250.             }
    251.         }
    252.     }
    253.  
    254.     public override void CollectObservations(VectorSensor sensor)
    255.     {
    256.         foreach (Transform child in childArray)
    257.         {
    258.             if (child != null)
    259.             {
    260.                 sensor.AddObservation(child.localPosition.y);
    261.             }
    262.         }
    263.  
    264.         sensor.AddObservation(getProductPos());
    265.         sensor.AddObservation(getProductSpeed());
    266.         sensor.AddObservation(targetCloseness());
    267.         sensor.AddObservation(getProductRot());
    268.         sensor.AddObservation(getProductRot());
    269.     }
    270.  
    271.     public override void Heuristic(in ActionBuffers actionsOut)
    272.     {
    273.         ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
    274.         for (int i = 0; i < 400; i++)
    275.         {
    276.             continuousActions[i] = UnityEngine.Random.Range(-1f, 1f);
    277.             // continuousActions[i] = Input.GetAxisRaw("Vertical");
    278.         }
    279.     }
    280.  
    281.     // void OnTriggerEnter(Collider other)
    282.     // {
    283.     //     if (other.CompareTag(object1Tag) && other.CompareTag(object2Tag))
    284.     //     {
    285.     //         Debug.Log("Collision detected between objects with tags: " + object1Tag + " and " + object2Tag);
    286.     //         AddReward(-10f);
    287.     //         EndEpisode();
    288.     //     }
    289.     //     else if (other.CompareTag(object1Tag) && other.CompareTag("TestTag"))
    290.     //     {
    291.     //         AddReward(-10f);
    292.     //         EndEpisode();
    293.     //     }
    294.     // }
    295.  
    296. }
    297.