Search Unity

  1. Unity 6 Preview is now available. To find out what's new, have a look at our Unity 6 Preview blog post.
    Dismiss Notice
  2. Unity is excited to announce that we will be collaborating with TheXPlace for a summer game jam from June 13 - June 19. Learn more.
    Dismiss Notice
  3. Dismiss Notice

Question Ray Perception Sensor not working with CollectObservations()

Discussion in 'ML-Agents' started by smoemnk, Sep 7, 2023.

  1. smoemnk

    smoemnk

    Joined:
    Nov 1, 2021
    Posts:
    6
    currently working on training an Agent to park a car. Before using the ray perception sensor component this was my CollectObservations() method:

    Code (CSharp):
    1. Debug.Log("collect observations called");
    2.         sensor.AddObservation(transform.localPosition);
    3.         sensor.AddObservation(transform.rotation);
    4.         sensor.AddObservation(carController.GetVelocity());
    5.  
    6.         Vector3 toTarget = targetParkingSpot.localPosition - transform.localPosition;
    7.         Vector3 normalizedToTarget = toTarget.normalized;
    8.         sensor.AddObservation(toTarget.magnitude);
    9.         sensor.AddObservation(normalizedToTarget);
    10.         sensor.AddObservation(rb.velocity.magnitude / maxSpeed);
    11.         sensor.AddObservation(rb.angularVelocity.magnitude / maxAngularVelocity);
    12.  
    13.         sensor.AddObservation(targetParkingSpot.localPosition);
    14.         sensor.AddObservation(targetParkingSpot.rotation);
    15.  
    16.         Vector3[] raycastDirections = { transform.forward, transform.right, -transform.right, transform.forward + transform.right, transform.forward - transform.right };
    17.  
    18.         foreach (Vector3 direction in raycastDirections)
    19.         {
    20.             if (Physics.Raycast(transform.localPosition, direction, out RaycastHit hit, maxRaycastDistance))
    21.             {
    22.                 float normalizedDistance = hit.distance / maxRaycastDistance;
    23.                 sensor.AddObservation(normalizedDistance);
    24.             }
    25.             else
    26.             {
    27.                 sensor.AddObservation(-1f);
    28.             }
    29.         }
    then when my agent wasn't responding or doing anything I figured CollectObservations() is at fault so I switched to ray perception sensors. However it the agent is still not moving and at seemingly random i get the "Fewer observations (0) made than vector observation size (21). The observations will be padded." warning. Here is how my scene is currently set up:

    and this is my current script:
    Code (CSharp):
    1. using System.Collections;
    2. using System.Collections.Generic;
    3. using Unity.MLAgents;
    4. using Unity.MLAgents.Sensors;
    5. using UnityEngine;
    6. using Unity.MLAgents.Actuators;
    7.  
    8. public class CarParkingAgent : Agent
    9. {
    10.     public Transform targetParkingSpot;
    11.     public CarController carController;
    12.     public float maxRaycastDistance = 10f;
    13.     private Rigidbody rb;
    14.     public float maxSpeed = 10f;
    15.     public float maxAngularVelocity = 10f;
    16.     private bool isCarParked = false;
    17.     public Vector3 tensor;
    18.     private bool lastFrameParkedStatus = false;
    19.     private int currentStep = 0;
    20.     private bool useHeuristic = true; // after ur done with BC just delete this whole logic
    21.     //private float totalReward = 0f;
    22.     private RayPerceptionSensorComponent3D RayPerceptionSensorComponent;
    23.     private int observationCount = 0;
    24.  
    25.     void Awake()
    26.     {
    27.         RayPerceptionSensorComponent = GetComponent<RayPerceptionSensorComponent3D>();
    28.         Debug.Log("Awake called");
    29.         rb = GetComponent<Rigidbody>();
    30.         rb.inertiaTensor = tensor;
    31.         rb.inertiaTensor = new Vector3(1829.532f, 1974.514f, 391.8728f);
    32.         useHeuristic = true;
    33.     }
    34.  
    35.  
    36.     private void FixedUpdate()
    37.     {
    38.         if (!isCarParked)
    39.         {
    40.             Debug.Log("car aint parked");
    41.             isCarParked = IsCarOnParkingSpot();
    42.         }
    43.         //idk what this does but if i remove it my entire IsCarOnParkingSpot() stops working
    44.         if (!lastFrameParkedStatus)
    45.         {
    46.             isCarParked = IsCarOnParkingSpot();
    47.         }
    48.         lastFrameParkedStatus = isCarParked;
    49.     }
    50.  
    51.     private void Heuristic()
    52.     {
    53.         if (!useHeuristic)
    54.         {
    55.             return;
    56.         }
    57.         else
    58.         {
    59.             Debug.Log("using heuristics");
    60.         }
    61.  
    62.  
    63.         // Calculate target direction (normalized)
    64.         Vector3 toTarget = targetParkingSpot.localPosition - transform.localPosition;
    65.         Vector3 normalizedToTarget = toTarget.normalized;
    66.  
    67.         // Apply heuristic actions based on target direction
    68.         float steering = Vector3.Dot(normalizedToTarget, transform.right);
    69.         float accelerate = Mathf.Clamp01(Vector3.Dot(rb.velocity, transform.forward));
    70.         float brake = 0f;
    71.  
    72.         // Apply actions
    73.         carController.HandleMotor(accelerate, brake);
    74.         carController.HandleSteering(steering);
    75.     }
    76.    
    77.     public override void CollectObservations(VectorSensor sensor)
    78.     {
    79.         Debug.Log("collect observations called");
    80.         // Observations related to the agent's position and orientation
    81.         sensor.AddObservation(transform.localPosition);
    82.         sensor.AddObservation(transform.rotation);
    83.         sensor.AddObservation(carController.GetVelocity());
    84.  
    85.         // Observations related to the distance and direction to the target parking spot
    86.         Vector3 toTarget = targetParkingSpot.localPosition - transform.localPosition;
    87.         Vector3 normalizedToTarget = toTarget.normalized;
    88.         sensor.AddObservation(toTarget.magnitude);
    89.         sensor.AddObservation(normalizedToTarget);
    90.  
    91.         // Observations related to the agent's speed and angular velocity
    92.         sensor.AddObservation(rb.velocity.magnitude / maxSpeed);
    93.         sensor.AddObservation(rb.angularVelocity.magnitude / maxAngularVelocity);
    94.  
    95.         // Observations related to the target parking spot's position and rotation
    96.         sensor.AddObservation(targetParkingSpot.localPosition);
    97.         sensor.AddObservation(targetParkingSpot.rotation);
    98.  
    99.         observationCount += 9;
    100.         Debug.Log("Observations made: " + observationCount);
    101.  
    102.     }
    103.  
    104.     public override void OnActionReceived(ActionBuffers actions)
    105.     {
    106.         Debug.Log("OnActionReceived called");
    107.         currentStep++;
    108.         float accelerate = Mathf.Clamp(actions.ContinuousActions[0], -1f, 1f);
    109.         //Debug.Log("Accelerate: " + accelerate);
    110.         float steering = Mathf.Clamp(actions.ContinuousActions[1], -1f, 1f);
    111.         //Debug.Log("Steering: " + steering);
    112.  
    113.         float brake = Mathf.Clamp(actions.ContinuousActions[2], 0f, 1f);
    114.         //Debug.Log("Brake: " + brake);
    115.  
    116.         carController.HandleMotor(accelerate, brake);
    117.         carController.HandleSteering(steering);
    118.  
    119.     }
    120.  
    121.  
    122.  
    123.     public override void Heuristic(in ActionBuffers actionsOut)
    124.     {
    125.         var continuousActionsOut = actionsOut.ContinuousActions;
    126.         continuousActionsOut[0] = Input.GetAxis("Vertical"); // Acceleration
    127.         continuousActionsOut[1] = Input.GetAxis("Horizontal"); // Steering
    128.         continuousActionsOut[2] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f; // Brake
    129.     }
    130.  
    131.  
    132.     public override void OnEpisodeBegin()
    133.     {
    134.         observationCount = 0;
    135.         //totalReward = 0f;
    136.         currentStep = 0;
    137.         Debug.Log("new episode begun");
    138.         ResetCarPosition();
    139.         isCarParked = false;
    140.         lastFrameParkedStatus = false;
    141.     }
    142.    
    143.     public void OnTriggerEnter(Collider collision)
    144.     {
    145.         if (collision.gameObject.CompareTag("Obstacle"))
    146.         {
    147.             AddReward(-1.0f);
    148.             ResetCarPosition();
    149.             EndEpisode();
    150.             isCarParked = false;
    151.         }
    152.         else if (collision.gameObject.CompareTag("ParkingSpot") && !isCarParked)
    153.         {
    154.             IsCarOnParkingSpot();
    155.         }
    156.     }
    157.  
    158.     public bool IsCarOnParkingSpot()
    159.     {
    160.         Vector3 toParkingSpot = targetParkingSpot.position - transform.position;
    161.         //Debug.Log("Distance from parking spot: " + toParkingSpot.magnitude);
    162.         if (toParkingSpot.magnitude < 1f)
    163.         {
    164.             AddReward(1.0f);
    165.             isCarParked = true;
    166.             ResetCarPosition();
    167.             EndEpisode();
    168.             return true;
    169.         }
    170.         isCarParked = false;
    171.  
    172.         return false;
    173.     }
    174.  
    175.     private void ResetCarPosition()
    176.     {
    177.         rb.velocity = Vector3.zero;
    178.         rb.angularVelocity = Vector3.zero;
    179.  
    180.         Vector3 localPosition = new Vector3(
    181.             UnityEngine.Random.Range(-7f, -4f),
    182.             0.1f,
    183.             UnityEngine.Random.Range(-7f, 2f));
    184.  
    185.         //Debug.Log("Calculated Local Position: " + localPosition);
    186.  
    187.         transform.localPosition = localPosition;
    188.         transform.localRotation = Quaternion.Euler(Vector3.zero);
    189.  
    190.     }
    191.    
    192. }
    any help is greatly appreciated thankyou in advance.
     
  2. smallg2023

    smallg2023

    Joined:
    Sep 2, 2018
    Posts:
    154
    can't see any reason it wouldn't be calling collect observations but it's hard to tell when they are being skipped with collapse on in the console.