Search Unity

Question Struggling to teach an AI 3D pong

Discussion in 'ML-Agents' started by allymacv, Apr 23, 2023.

  1. allymacv

    allymacv

    Joined:
    Apr 1, 2023
    Posts:
    2
    Hey everyone,

    I created a simple game environment that's just 6 walls, 2 paddles and a ball. The paddles can move along the Y and Z axis, with the X axis being the space between the paddles.

    I've been trying to use PPO but haven't had much success, My goal is to get the Agent to have an infinite rally, but even after 20 million steps it sruggles to get more than 1 hit per episode.

    I've tried playing around with the hyperparameters and the reward system, but the results aren't getting any better. Is there something that I'm missing here?

    Code (CSharp using System.Collections; using System.Collections.Generic; using UnityEngine; using Unity.MLAgents; using Unity.MLAgents.Sensors; using Unity.MLAgents.Actuators; public class Paddle : Agent { private Vector3 initialPosition; private Vector3 movement; private Rigidbody paddleRigidbody; public BallMovement ball; public float speed = 20f; // Start is called before the first frame update void Start() { initialPosition = transform.position; paddleRigidbody = GetComponent(); } public override void OnEpisodeBegin() { ResetPaddles(); } public override void CollectObservations(VectorSensor sensor) { // Normalize the paddle's velocity based on the actual limits Vector3 toBall = new Vector3((ball.ballRigidbody.transform.position.x - transform.position.x), (ball.ballRigidbody.transform.position.y - transform.position.y), (ball.ballRigidbody.transform.position.z - transform.position.z)); sensor.AddObservation(toBall.normalized); sensor.AddObservation(paddleRigidbody.velocity.normalized); sensor.AddObservation(transform.position.normalized); sensor.AddObservation(ball.transform.position.normalized); sensor.AddObservation(ball.ballRigidbody.velocity.normalized); } public override void OnActionReceived(ActionBuffers actions) { // Extract the actions from the ActionBuffers object float moveVertical = actions.ContinuousActions[0):
    1. ;
    2.         float moveHorizontal = actions.ContinuousActions[1];
    3.  
    4.         // Apply actions to the paddle
    5.         movement = new Vector3(0f, moveVertical, moveHorizontal);
    6.         paddleRigidbody.velocity = movement * speed;
    7.  
    8.         //Add a small reward when the ball is in play
    9.         AddReward(0.01f);
    10.     }
    11.  
    12.     public override void Heuristic(in ActionBuffers actionsOut)
    13.     {
    14.         // Use the existing input system for human-controlled paddle
    15.         float moveVertical = 0f;
    16.         float moveHorizontal = 0f;
    17.  
    18.         if (Input.GetKey(KeyCode.UpArrow))
    19.         {
    20.             moveVertical = 1f;
    21.         }
    22.         if (Input.GetKey(KeyCode.DownArrow))
    23.         {
    24.             moveVertical = -1f;
    25.         }
    26.  
    27.         if (Input.GetKey(KeyCode.LeftArrow))
    28.         {
    29.             moveHorizontal = -1f;
    30.         }
    31.  
    32.         if (Input.GetKey(KeyCode.RightArrow))
    33.         {
    34.             moveHorizontal = 1f;
    35.         }
    36.  
    37.         actionsOut.ContinuousActions.Array[0] = moveVertical;
    38.         actionsOut.ContinuousActions.Array[1] = moveHorizontal;
    39.     }
    40.  
    41.     public void ResetPaddles()
    42.     {
    43.         // Reset paddles to initial position
    44.         transform.position = initialPosition;
    45.     }
    46.  
    47.     //Reward Functions
    48.     void OnCollisionEnter(Collision collision)
    49.     {
    50.         if (collision.gameObject.tag == "ball")
    51.         {
    52.             AddReward(1f);
    53.         }
    54.     }
    55. }
    56.  
    57.  

    Code (CSharp):
    1.  
    2. using System.Collections;
    3. using System.Collections.Generic;
    4. using UnityEngine;
    5.  
    6. public class BallMovement : MonoBehaviour
    7. {
    8.     public Rigidbody ballRigidbody;
    9.     public Vector3 direction;
    10.     public Paddle paddle1;
    11.     public Paddle paddle2;
    12.  
    13.     private Vector3 initialPosition;
    14.     private Paddle lastHit;
    15.     private float speed;
    16.     private int leftOrRight;
    17.  
    18.     void Start()
    19.     {
    20.         ballRigidbody = GetComponent<Rigidbody>();
    21.         initialPosition = ballRigidbody.transform.position;
    22.         speed = 20f;
    23.  
    24.         // Randomly choose left or right
    25.         if (Random.Range(0f, 1f) < 0.5f)
    26.         {
    27.             leftOrRight = -1;
    28.         }
    29.         else
    30.         {
    31.             leftOrRight = 1;
    32.         }
    33.  
    34.         // Generate random direction
    35.         direction = new Vector3(leftOrRight, Random.Range(-1f, 1f), Random.Range(-1f, 1f)).normalized;
    36.  
    37.         // Apply force in random direction
    38.         ballRigidbody.velocity = direction.normalized * speed;
    39.     }
    40.  
    41.     void OnCollisionEnter(Collision collision)
    42.     {
    43.         // If ball hits wall, reset game
    44.         if (collision.gameObject.tag == "backWall")
    45.         {
    46.             if (lastHit == paddle1)
    47.             {
    48.                 paddle1.AddReward(-1f);
    49.             }
    50.             else if (lastHit == paddle2)
    51.             {
    52.                 paddle2.AddReward(-1f);
    53.             }
    54.             else
    55.             {
    56.                 paddle1.AddReward(-1f);
    57.                 paddle2.AddReward(-1f);
    58.             }
    59.             ResetGame();
    60.         }
    61.         else if (collision.gameObject.tag == "paddle")
    62.         {
    63.             lastHit = collision.gameObject.GetComponent<Paddle>();
    64.  
    65.             // Calculate hit factor for y and z axes
    66.             float y = yHit(transform.position, collision.transform.position, collision.collider.bounds.size.y);
    67.             float z = zHit(transform.position, collision.transform.position, collision.collider.bounds.size.z);
    68.  
    69.             //x value depends on paddle hit
    70.             float x = collision.gameObject.name == "Paddle1" ? -1 : 1;
    71.  
    72.             // Calculate direction, make length=1 via .normalized
    73.             direction = new Vector3(x, y, z).normalized;
    74.  
    75.             // Set Velocity with direction * speed
    76.             GetComponent<Rigidbody>().velocity = direction * speed;
    77.         }
    78.     }
    79.  
    80.     float yHit(Vector3 ballPos, Vector3 racketPos, float racketHeight) {
    81.         return (ballPos.y - racketPos.y) / (racketHeight / 2f);
    82.     }
    83.  
    84.     float zHit(Vector3 ballPos, Vector3 racketPos, float racketWidth) {
    85.         return (ballPos.z - racketPos.z) / (racketWidth / 2f);
    86.     }
    87.  
    88.     void ResetGame()
    89.     {
    90.         // Reset ball to initial position
    91.         paddle1.ResetPaddles();
    92.         paddle2.ResetPaddles();
    93.         paddle1.EndEpisode();
    94.         paddle2.EndEpisode();
    95.         ballRigidbody.transform.position = initialPosition;
    96.         Start();
    97.     }
    98. }
    99.  

    Code (CSharp):
    1.  
    2. behaviors:
    3.   My Behavior:
    4.     trainer_type: ppo
    5.     hyperparameters:
    6.       batch_size: 4096
    7.       buffer_size: 409600
    8.       learning_rate: 0.0002
    9.       beta: 0.003
    10.       epsilon: 0.15
    11.       lambd: 0.93
    12.       num_epoch: 6
    13.       learning_rate_schedule: linear
    14.     network_settings:
    15.       normalize: true
    16.       hidden_units: 512
    17.       num_layers: 3
    18.       vis_encode_type: simple
    19.     reward_signals:
    20.       extrinsic:
    21.         gamma: 0.99
    22.         strength: 1.0
    23.     keep_checkpoints: 5
    24.     max_steps: 5000000
    25.     time_horizon: 100000
    26.     summary_freq: 20000
    27.  
    28. 0

    I appreciate any help in advance

    Thanks!
     
    Last edited: Apr 23, 2023
  2. allymacv

    allymacv

    Joined:
    Apr 1, 2023
    Posts:
    2
    Here are the parameters I have set in Unity.
     

    Attached Files: