Search Unity

  1. Welcome to the Unity Forums! Please take the time to read our Code of Conduct to familiarize yourself with the forum rules and how to post constructively.
  2. Dismiss Notice

Question Ray Perception Sensor not working with CollectObservations()

Discussion in 'ML-Agents' started by smoemnk, Sep 7, 2023.

  1. smoemnk

    smoemnk

    Joined:
    Nov 1, 2021
    Posts:
    6
    currently working on training an Agent to park a car. Before using the ray perception sensor component this was my CollectObservations() method:

    Code (CSharp):
    1. Debug.Log("collect observations called");
    2.         sensor.AddObservation(transform.localPosition);
    3.         sensor.AddObservation(transform.rotation);
    4.         sensor.AddObservation(carController.GetVelocity());
    5.  
    6.         Vector3 toTarget = targetParkingSpot.localPosition - transform.localPosition;
    7.         Vector3 normalizedToTarget = toTarget.normalized;
    8.         sensor.AddObservation(toTarget.magnitude);
    9.         sensor.AddObservation(normalizedToTarget);
    10.         sensor.AddObservation(rb.velocity.magnitude / maxSpeed);
    11.         sensor.AddObservation(rb.angularVelocity.magnitude / maxAngularVelocity);
    12.  
    13.         sensor.AddObservation(targetParkingSpot.localPosition);
    14.         sensor.AddObservation(targetParkingSpot.rotation);
    15.  
    16.         Vector3[] raycastDirections = { transform.forward, transform.right, -transform.right, transform.forward + transform.right, transform.forward - transform.right };
    17.  
    18.         foreach (Vector3 direction in raycastDirections)
    19.         {
    20.             if (Physics.Raycast(transform.localPosition, direction, out RaycastHit hit, maxRaycastDistance))
    21.             {
    22.                 float normalizedDistance = hit.distance / maxRaycastDistance;
    23.                 sensor.AddObservation(normalizedDistance);
    24.             }
    25.             else
    26.             {
    27.                 sensor.AddObservation(-1f);
    28.             }
    29.         }
    then when my agent wasn't responding or doing anything I figured CollectObservations() is at fault so I switched to ray perception sensors. However it the agent is still not moving and at seemingly random i get the "Fewer observations (0) made than vector observation size (21). The observations will be padded." warning. Here is how my scene is currently set up:

    and this is my current script:
    Code (CSharp):
    1. using System.Collections;
    2. using System.Collections.Generic;
    3. using Unity.MLAgents;
    4. using Unity.MLAgents.Sensors;
    5. using UnityEngine;
    6. using Unity.MLAgents.Actuators;
    7.  
    8. public class CarParkingAgent : Agent
    9. {
    10.     public Transform targetParkingSpot;
    11.     public CarController carController;
    12.     public float maxRaycastDistance = 10f;
    13.     private Rigidbody rb;
    14.     public float maxSpeed = 10f;
    15.     public float maxAngularVelocity = 10f;
    16.     private bool isCarParked = false;
    17.     public Vector3 tensor;
    18.     private bool lastFrameParkedStatus = false;
    19.     private int currentStep = 0;
    20.     private bool useHeuristic = true; // after ur done with BC just delete this whole logic
    21.     //private float totalReward = 0f;
    22.     private RayPerceptionSensorComponent3D RayPerceptionSensorComponent;
    23.     private int observationCount = 0;
    24.  
    25.     void Awake()
    26.     {
    27.         RayPerceptionSensorComponent = GetComponent<RayPerceptionSensorComponent3D>();
    28.         Debug.Log("Awake called");
    29.         rb = GetComponent<Rigidbody>();
    30.         rb.inertiaTensor = tensor;
    31.         rb.inertiaTensor = new Vector3(1829.532f, 1974.514f, 391.8728f);
    32.         useHeuristic = true;
    33.     }
    34.  
    35.  
    36.     private void FixedUpdate()
    37.     {
    38.         if (!isCarParked)
    39.         {
    40.             Debug.Log("car aint parked");
    41.             isCarParked = IsCarOnParkingSpot();
    42.         }
    43.         //idk what this does but if i remove it my entire IsCarOnParkingSpot() stops working
    44.         if (!lastFrameParkedStatus)
    45.         {
    46.             isCarParked = IsCarOnParkingSpot();
    47.         }
    48.         lastFrameParkedStatus = isCarParked;
    49.     }
    50.  
    51.     private void Heuristic()
    52.     {
    53.         if (!useHeuristic)
    54.         {
    55.             return;
    56.         }
    57.         else
    58.         {
    59.             Debug.Log("using heuristics");
    60.         }
    61.  
    62.  
    63.         // Calculate target direction (normalized)
    64.         Vector3 toTarget = targetParkingSpot.localPosition - transform.localPosition;
    65.         Vector3 normalizedToTarget = toTarget.normalized;
    66.  
    67.         // Apply heuristic actions based on target direction
    68.         float steering = Vector3.Dot(normalizedToTarget, transform.right);
    69.         float accelerate = Mathf.Clamp01(Vector3.Dot(rb.velocity, transform.forward));
    70.         float brake = 0f;
    71.  
    72.         // Apply actions
    73.         carController.HandleMotor(accelerate, brake);
    74.         carController.HandleSteering(steering);
    75.     }
    76.    
    77.     public override void CollectObservations(VectorSensor sensor)
    78.     {
    79.         Debug.Log("collect observations called");
    80.         // Observations related to the agent's position and orientation
    81.         sensor.AddObservation(transform.localPosition);
    82.         sensor.AddObservation(transform.rotation);
    83.         sensor.AddObservation(carController.GetVelocity());
    84.  
    85.         // Observations related to the distance and direction to the target parking spot
    86.         Vector3 toTarget = targetParkingSpot.localPosition - transform.localPosition;
    87.         Vector3 normalizedToTarget = toTarget.normalized;
    88.         sensor.AddObservation(toTarget.magnitude);
    89.         sensor.AddObservation(normalizedToTarget);
    90.  
    91.         // Observations related to the agent's speed and angular velocity
    92.         sensor.AddObservation(rb.velocity.magnitude / maxSpeed);
    93.         sensor.AddObservation(rb.angularVelocity.magnitude / maxAngularVelocity);
    94.  
    95.         // Observations related to the target parking spot's position and rotation
    96.         sensor.AddObservation(targetParkingSpot.localPosition);
    97.         sensor.AddObservation(targetParkingSpot.rotation);
    98.  
    99.         observationCount += 9;
    100.         Debug.Log("Observations made: " + observationCount);
    101.  
    102.     }
    103.  
    104.     public override void OnActionReceived(ActionBuffers actions)
    105.     {
    106.         Debug.Log("OnActionReceived called");
    107.         currentStep++;
    108.         float accelerate = Mathf.Clamp(actions.ContinuousActions[0], -1f, 1f);
    109.         //Debug.Log("Accelerate: " + accelerate);
    110.         float steering = Mathf.Clamp(actions.ContinuousActions[1], -1f, 1f);
    111.         //Debug.Log("Steering: " + steering);
    112.  
    113.         float brake = Mathf.Clamp(actions.ContinuousActions[2], 0f, 1f);
    114.         //Debug.Log("Brake: " + brake);
    115.  
    116.         carController.HandleMotor(accelerate, brake);
    117.         carController.HandleSteering(steering);
    118.  
    119.     }
    120.  
    121.  
    122.  
    123.     public override void Heuristic(in ActionBuffers actionsOut)
    124.     {
    125.         var continuousActionsOut = actionsOut.ContinuousActions;
    126.         continuousActionsOut[0] = Input.GetAxis("Vertical"); // Acceleration
    127.         continuousActionsOut[1] = Input.GetAxis("Horizontal"); // Steering
    128.         continuousActionsOut[2] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f; // Brake
    129.     }
    130.  
    131.  
    132.     public override void OnEpisodeBegin()
    133.     {
    134.         observationCount = 0;
    135.         //totalReward = 0f;
    136.         currentStep = 0;
    137.         Debug.Log("new episode begun");
    138.         ResetCarPosition();
    139.         isCarParked = false;
    140.         lastFrameParkedStatus = false;
    141.     }
    142.    
    143.     public void OnTriggerEnter(Collider collision)
    144.     {
    145.         if (collision.gameObject.CompareTag("Obstacle"))
    146.         {
    147.             AddReward(-1.0f);
    148.             ResetCarPosition();
    149.             EndEpisode();
    150.             isCarParked = false;
    151.         }
    152.         else if (collision.gameObject.CompareTag("ParkingSpot") && !isCarParked)
    153.         {
    154.             IsCarOnParkingSpot();
    155.         }
    156.     }
    157.  
    158.     public bool IsCarOnParkingSpot()
    159.     {
    160.         Vector3 toParkingSpot = targetParkingSpot.position - transform.position;
    161.         //Debug.Log("Distance from parking spot: " + toParkingSpot.magnitude);
    162.         if (toParkingSpot.magnitude < 1f)
    163.         {
    164.             AddReward(1.0f);
    165.             isCarParked = true;
    166.             ResetCarPosition();
    167.             EndEpisode();
    168.             return true;
    169.         }
    170.         isCarParked = false;
    171.  
    172.         return false;
    173.     }
    174.  
    175.     private void ResetCarPosition()
    176.     {
    177.         rb.velocity = Vector3.zero;
    178.         rb.angularVelocity = Vector3.zero;
    179.  
    180.         Vector3 localPosition = new Vector3(
    181.             UnityEngine.Random.Range(-7f, -4f),
    182.             0.1f,
    183.             UnityEngine.Random.Range(-7f, 2f));
    184.  
    185.         //Debug.Log("Calculated Local Position: " + localPosition);
    186.  
    187.         transform.localPosition = localPosition;
    188.         transform.localRotation = Quaternion.Euler(Vector3.zero);
    189.  
    190.     }
    191.    
    192. }
    any help is greatly appreciated thankyou in advance.
     
  2. smallg2023

    smallg2023

    Joined:
    Sep 2, 2018
    Posts:
    102
    can't see any reason it wouldn't be calling collect observations but it's hard to tell when they are being skipped with collapse on in the console.