Search Unity

  1. Unity support for visionOS is now available. Learn more in our blog post.
    Dismiss Notice

Question: Can I train two different AIs at the same time?

Discussion in 'ML-Agents' started by rnjsehdus, Jun 23, 2023.

  1. rnjsehdus

    rnjsehdus

    Joined:
    Jun 9, 2023
    Posts:
    1
    I want to make one AI that collects the target while avoiding the enemy, and another AI that collects the target and catches the first AI. I want to train them both at the same time, but I don't know how to do it. Does anyone know how to train two AIs at the same time?

    here's the code

    This is the AI that collects the target and avoids the enemy:
    Code (CSharp):
    1. using System.Collections;
    2. using System.Collections.Generic;
    3. using UnityEngine;
    4. using Unity.MLAgents;
    5. using Unity.MLAgents.Sensors;
    6. using Unity.MLAgents.Actuators;
    7.  
    /// <summary>
    /// Agent that pushes itself around with two continuous actions, is rewarded for
    /// getting close to <see cref="Target"/> and penalized (episode ends) for getting
    /// close to <see cref="Enemy"/> or dropping below y = 0.
    /// </summary>
    public class RollerAgent : Agent
    {
        // Distance below which the target counts as collected / the enemy counts as touching.
        // Declared as double so the float distances are compared exactly as a bare 1.42 literal would.
        const double ContactDistance = 1.42;

        Rigidbody body;
        public Transform Target;
        public Transform Enemy;
        public float forceMultiplier = 10;

        void Start()
        {
            body = GetComponent<Rigidbody>();
        }

        /// <summary>
        /// Starts an episode: recovers the agent if it is below y = 0, then moves the target.
        /// </summary>
        public override void OnEpisodeBegin()
        {
            // NOTE(review): the agent is only repositioned after a fall, so an episode that
            // ended on enemy contact restarts with the agent where it was - confirm intended.
            bool fellOff = transform.localPosition.y < 0;
            if (fellOff)
            {
                body.angularVelocity = Vector3.zero;
                body.velocity = Vector3.zero;
                transform.localPosition = new Vector3(0, 0.5f, 0);
            }

            targetNewPosition();
        }

        /// <summary>
        /// Observes the target, own and enemy local positions plus own x/z velocity.
        /// </summary>
        /// <param name="sensor">Vector sensor the observations are appended to.</param>
        public override void CollectObservations(VectorSensor sensor)
        {
            // Positions (three Vector3 values).
            sensor.AddObservation(Target.localPosition);
            sensor.AddObservation(transform.localPosition);
            sensor.AddObservation(Enemy.localPosition);

            // Planar velocity (two floats).
            Vector3 velocity = body.velocity;
            sensor.AddObservation(velocity.x);
            sensor.AddObservation(velocity.z);
        }

        /// <summary>
        /// Applies the action as a force on the x/z plane, then hands out rewards.
        /// </summary>
        /// <param name="actionBuffers">Continuous action 0 is the x push, action 1 the z push.</param>
        public override void OnActionReceived(ActionBuffers actionBuffers)
        {
            var actions = actionBuffers.ContinuousActions;
            var push = new Vector3(actions[0], 0f, actions[1]);
            body.AddForce(push * forceMultiplier);

            // Both distances are measured before any reward handling moves things around.
            Vector3 here = transform.localPosition;
            float toTarget = Vector3.Distance(here, Target.localPosition);
            float toEnemy = Vector3.Distance(here, Enemy.localPosition);

            bool reachedTarget = toTarget < ContactDistance;
            if (reachedTarget)
            {
                AddReward(1.0f);
                targetNewPosition();
            }

            bool touchedEnemy = toEnemy < ContactDistance;
            if (touchedEnemy)
            {
                AddReward(-1.0f);
                EndEpisode();
            }

            // Read the position again rather than reusing 'here': an EndEpisode() above
            // may already have run OnEpisodeBegin and moved the agent.
            if (transform.localPosition.y < 0)
            {
                EndEpisode();
            }
        }

        /// <summary>
        /// Moves the target to a new random spot at height 0.5, with x and z each
        /// computed as <c>Random.value * 8 - 4</c>.
        /// </summary>
        public void targetNewPosition()
        {
            float x = Random.value * 8 - 4;
            float z = Random.value * 8 - 4;
            Target.localPosition = new Vector3(x, 0.5f, z);
        }

        // Disabled manual control: maps the "Horizontal" / "Vertical" input axes
        // onto the two continuous actions.
        // public override void Heuristic(in ActionBuffers actionsOut)
        // {
        //     var continuousActionsOut = actionsOut.ContinuousActions;
        //     continuousActionsOut[0] = Input.GetAxis("Horizontal");
        //     continuousActionsOut[1] = Input.GetAxis("Vertical");
        // }
    }
    82.  
    This is the code for the AI that catches the other AI:
    Code (CSharp):
    1. using System.Collections;
    2. using System.Collections.Generic;
    3. using UnityEngine;
    4. using Unity.MLAgents;
    5. using Unity.MLAgents.Sensors;
    6. using Unity.MLAgents.Actuators;
    7.  
    /// <summary>
    /// Agent that pushes itself around with two continuous actions and is rewarded for
    /// getting close to <see cref="Roller"/> (which ends its episode) and, with a smaller
    /// reward, for getting close to <see cref="Target"/>.
    /// </summary>
    public class EnemyAgent : Agent
    {
        // Distance below which the roller counts as caught / the target counts as collected.
        // Declared as double so the float distances are compared exactly as a bare 1.42 literal would.
        const double ContactDistance = 1.42;

        // Where the agent is put back after dropping below y = 0.
        static readonly Vector3 SpawnPosition = new Vector3(4, .5f, 4);

        Rigidbody rBody;
        public Transform Target;
        public Transform Roller;
        public float forceMultiplier = 7;

        void Start()
        {
            rBody = GetComponent<Rigidbody>();
        }

        /// <summary>
        /// Starts an episode: recovers the agent if it is below y = 0.
        /// </summary>
        public override void OnEpisodeBegin()
        {
            // NOTE(review): the agent is only repositioned after a fall, so an episode that
            // ended on catching the roller restarts with both agents where they were - confirm intended.
            if (transform.localPosition.y < 0)
            {
                ResetToSpawn();
            }
        }

        /// <summary>
        /// Observes the target, own and roller local positions plus own x/z velocity.
        /// </summary>
        /// <param name="sensor">Vector sensor the observations are appended to.</param>
        public override void CollectObservations(VectorSensor sensor)
        {
            sensor.AddObservation(Target.localPosition);
            sensor.AddObservation(transform.localPosition);
            sensor.AddObservation(Roller.localPosition);

            sensor.AddObservation(rBody.velocity.x);
            sensor.AddObservation(rBody.velocity.z);
        }

        /// <summary>
        /// Applies the action as a force on the x/z plane, then hands out rewards.
        /// </summary>
        /// <param name="actionBuffers">Continuous action 0 is the x push, action 1 the z push.</param>
        public override void OnActionReceived(ActionBuffers actionBuffers)
        {
            Vector3 controlSignal = Vector3.zero;
            controlSignal.x = actionBuffers.ContinuousActions[0];
            controlSignal.z = actionBuffers.ContinuousActions[1];
            rBody.AddForce(controlSignal * forceMultiplier);

            float distanceToTarget = Vector3.Distance(transform.localPosition, Target.localPosition);
            float distanceToRoller = Vector3.Distance(transform.localPosition, Roller.localPosition);

            // Reward the target first. Rewards added after EndEpisode() would be credited
            // to the next episode, so everything earned this step is added before the
            // episode can end.
            if (distanceToTarget < ContactDistance)
            {
                AddReward(0.4f);
                targetNewPosition();
            }

            if (distanceToRoller < ContactDistance)
            {
                AddReward(1.0f);
                EndEpisode();
                // EndEpisode() runs OnEpisodeBegin, which already performs the fall recovery.
                return;
            }

            // Falling off does not end the episode for this agent; it is simply put back.
            if (transform.localPosition.y < 0)
            {
                ResetToSpawn();
            }
        }

        /// <summary>
        /// Moves the target to a new random spot at height 0.5, with x and z each
        /// computed as <c>Random.value * 8 - 4</c>.
        /// </summary>
        public void targetNewPosition()
        {
            Target.localPosition = new Vector3(Random.value * 8 - 4, 0.5f, Random.value * 8 - 4);
        }

        // Stops all motion and places the agent back at SpawnPosition.
        void ResetToSpawn()
        {
            rBody.angularVelocity = Vector3.zero;
            rBody.velocity = Vector3.zero;
            transform.localPosition = SpawnPosition;
        }

        // Disabled manual control: maps the "Horizontal" / "Vertical" input axes
        // onto the two continuous actions.
        // public override void Heuristic(in ActionBuffers actionsOut)
        // {
        //     var continuousActionsOut = actionsOut.ContinuousActions;
        //     continuousActionsOut[0] = Input.GetAxis("Horizontal");
        //     continuousActionsOut[1] = Input.GetAxis("Vertical");
        // }
    }
    84.  
     
  2. smallg2023

    smallg2023

    Joined:
    Sep 2, 2018
    Posts:
    130
    Yes, this could be done with self-play. Look at the soccer vs. goalie demo.