
Setting reward when the agent is getting closer to the target

Discussion in 'ML-Agents' started by Shachartz0, Jul 8, 2021.

  1. Shachartz0

    Joined:
    Dec 28, 2020
    Posts:
    13
    Hi all,
    I am building a maze game that uses raycasts as observations.
    I would like to reward my agent for getting closer to the target, so I want to find the length of a ray that hits an object with a certain tag (a rough sketch of what I mean follows the code below).
    Note that the ray that hits the target could be any one of the agent's 12 rays, so the solution shouldn't rely on a particular direction.

    Thank you all

    Code (CSharp):
    using UnityEngine;
    using Unity.MLAgents;
    using Unity.MLAgents.Actuators;
    using Unity.MLAgents.Sensors;
    using System.Collections;
    using System.Collections.Generic;

    public class SimpleCollectorAgent : Agent
    {
        [SerializeField]
        private GameObject target;

        [SerializeField]
        private Material successMaterial;

        [SerializeField]
        private Material failMaterial;

        [SerializeField]
        private Material defaultMaterial;

        [SerializeField]
        private MeshRenderer groundMeshRenderer;

        private SimpleCharacterController characterController;
        // 'new' hides the deprecated Component.rigidbody property
        private new Rigidbody rigidbody;
        private Vector3 originalPosition;
        private Vector3 originalTargetPosition;

        /// <summary>
        /// Called once when the agent is first initialized
        /// </summary>
        public override void Initialize()
        {
            characterController = GetComponent<SimpleCharacterController>();
            rigidbody = GetComponent<Rigidbody>();
            originalPosition = transform.localPosition;
            originalTargetPosition = target.transform.localPosition;
        }

        /// <summary>
        /// Called every time an episode begins. This is where we reset the challenge.
        /// </summary>
        public override void OnEpisodeBegin()
        {
            // Reset agent and target to their original positions
            transform.localPosition = originalPosition;
            target.transform.localPosition = originalTargetPosition;

            // Then move the target to a random position that is not blocked
            while (true)
            {
                target.transform.localPosition = new Vector3(Random.Range(-4.5f, 4.5f), 0.5f, Random.Range(-4.5f, 4.5f));
                if (!Physics.CheckSphere(target.transform.position, 0.1f))
                {
                    break;
                }
            }
        }

        /// <summary>
        /// Controls the agent with human input
        /// </summary>
        /// <param name="actionsOut">The actions parsed from keyboard input</param>
        public override void Heuristic(in ActionBuffers actionsOut)
        {
            // Read input values and round them. GetAxisRaw works better in this case
            // because of the DecisionRequester, which only requests new decisions periodically.
            int vertical = Mathf.RoundToInt(Input.GetAxisRaw("Vertical"));
            int horizontal = Mathf.RoundToInt(Input.GetAxisRaw("Horizontal"));

            // Convert the input to discrete choices (0, 1, 2)
            ActionSegment<int> actions = actionsOut.DiscreteActions;
            actions[0] = vertical >= 0 ? vertical : 2;
            actions[1] = horizontal >= 0 ? horizontal : 2;
        }

        /// <summary>
        /// React to actions coming from either the neural net or human input
        /// </summary>
        /// <param name="actions">The actions received</param>
        public override void OnActionReceived(ActionBuffers actions)
        {
            // Convert actions from discrete (0, 1, 2) to the expected input
            // values (-1, 0, +1) of the character controller
            float vertical = actions.DiscreteActions[0] <= 1 ? actions.DiscreteActions[0] : -1;
            float horizontal = actions.DiscreteActions[1] <= 1 ? actions.DiscreteActions[1] : -1;

            characterController.MoveVertical(vertical);
            characterController.MoveHorizontal(horizontal);

            // Small existential penalty to encourage reaching the target quickly
            AddReward(-1f / MaxStep);
        }

        /// <summary>
        /// Respond to entering a trigger collider
        /// </summary>
        /// <param name="other">The object (with trigger collider) that was touched</param>
        private void OnTriggerEnter(Collider other)
        {
            // If the other object is a collectible, reward and end the episode
            if (other.CompareTag("collectible"))
            {
                AddReward(1f);
                EndEpisode();
                StartCoroutine(SwapGroundMaterial(successMaterial, 0.5f));
            }
        }

        /// <summary>
        /// Called when the agent collides with something solid
        /// </summary>
        /// <param name="collision">The collision info</param>
        private void OnCollisionEnter(Collision collision)
        {
            if (collision.collider.CompareTag("walls"))
            {
                // Collided with the area boundary, give a negative reward
                AddReward(-.5f);
            }
        }

        /// <summary>
        /// Briefly swaps the ground material to signal success or failure
        /// </summary>
        private IEnumerator SwapGroundMaterial(Material mat, float time)
        {
            groundMeshRenderer.material = mat;
            yield return new WaitForSeconds(time);
            groundMeshRenderer.material = defaultMaterial;
        }
    }
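
    The kind of shaping I have in mind would compare the agent-to-target distance between steps and reward the change, something like this (untested sketch; the previousDistance field and the 0.1f scale are placeholders, not part of my working code):

    Code (CSharp):
    private float previousDistance;

    public override void OnEpisodeBegin()
    {
        // ... existing reset logic from the class above ...
        previousDistance = Vector3.Distance(transform.localPosition, target.transform.localPosition);
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        // ... existing movement and step-penalty logic ...

        // Positive when this step moved the agent toward the target,
        // negative when it moved away
        float currentDistance = Vector3.Distance(transform.localPosition, target.transform.localPosition);
        AddReward(0.1f * (previousDistance - currentDistance));
        previousDistance = currentDistance;
    }

    The catch is that in a maze the straight-line distance can mislead the agent, since a wall may separate two nearby points, which is why I'd rather measure along a ray that actually sees the target.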
     
  2. pablobandinopla

    Joined:
    Jul 2, 2021
    Posts:
    97
    I'm a beginner in AI myself, but if I had to guess: using the same logic as the race car examples with checkpoints, why not, during training, use Unity's NavMeshAgent as a hint for whether the agent is at least walking in the direction of the path that leads to the goal?

    So let the NavMeshAgent find the right path, then use that path as a reference to provide rewards whenever the chosen step moves the agent closer to the nearest point on that path (see the sketch below). It could work...

    I repeat, I don't know much about AI...

    PS: but this will result in the final AI learning one particular maze. If you made the maze random, I'd guess it would never learn anything, because there's no analytical way to solve a maze you don't know; it is guessing and guessing and guessing. Unless you put some kind of markings or visual guides in a logical pattern, so the AI can build a structure it can actually reason with, reinforcement learning alone plus a walk direction will never learn anything structured about maze solving.
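
    Something like this, maybe (untested, and assuming the maze floor is baked as a NavMesh; the helper name MazePathUtil is made up):

    Code (CSharp):
    using UnityEngine;
    using UnityEngine.AI;

    // Hypothetical helper: remaining path length from one point to another,
    // measured along the baked NavMesh instead of the straight line
    public static class MazePathUtil
    {
        public static float PathLength(Vector3 from, Vector3 to)
        {
            var path = new NavMeshPath();
            if (!NavMesh.CalculatePath(from, to, NavMesh.AllAreas, path))
            {
                return float.PositiveInfinity; // no path found
            }

            float length = 0f;
            for (int i = 1; i < path.corners.Length; i++)
            {
                length += Vector3.Distance(path.corners[i - 1], path.corners[i]);
            }
            return length;
        }
    }

    Then in the agent you'd store the previous path length and AddReward(scale * (previousLength - currentLength)) each step, like a straight-line version, except the length only shrinks when the agent makes real progress through the corridors.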
     
  3. Shachartz0

    Joined:
    Dec 28, 2020
    Posts:
    13
    I want the agent to understand that it should get to the target, but it fails to do so. Some of that is exploration, of course, but for now the agent gets stuck in a corner and all it learns is to avoid walls.
    I saw there is a RayPerceptionOutput.RayOutput struct, but I couldn't find an explanation of how to use it in my code.
    Could this be the solution?
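
    What I have in mind is something like this (untested sketch; I'm assuming ML-Agents 2.x, that the agent's RayPerceptionSensorComponent3D lists the target's tag under Detectable Tags, and the component name TargetRayDistance is just a placeholder):

    Code (CSharp):
    using UnityEngine;
    using Unity.MLAgents.Sensors;

    public class TargetRayDistance : MonoBehaviour
    {
        [SerializeField]
        private RayPerceptionSensorComponent3D raySensor;

        // Returns the distance to the nearest ray hit on an object whose tag
        // sits at tagIndex in the sensor's Detectable Tags list, or -1 if no
        // ray currently hits such an object
        public float DistanceToTag(int tagIndex)
        {
            RayPerceptionInput input = raySensor.GetRayPerceptionInput();
            RayPerceptionOutput output = RayPerceptionSensor.Perceive(input);

            float best = -1f;
            foreach (var ray in output.RayOutputs)
            {
                if (ray.HasHit && ray.HitTagIndex == tagIndex)
                {
                    // HitFraction is the hit distance as a fraction of the ray
                    // length (assuming an unscaled agent transform)
                    float distance = ray.HitFraction * input.RayLength;
                    if (best < 0f || distance < best)
                    {
                        best = distance;
                    }
                }
            }
            return best;
        }
    }

    Since Perceive does a fresh set of casts, and the distance only exists while some ray actually sees the target (behind a wall it returns -1), I guess I'd still need distance- or path-based shaping for most of the maze.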