Search Unity

  1. Welcome to the Unity Forums! Please take the time to read our Code of Conduct to familiarize yourself with the forum rules and how to post constructively.

ML-agents and navmesh

Discussion in 'ML-Agents' started by MurdoMacIver, Aug 18, 2020.

  1. MurdoMacIver

    MurdoMacIver

    Joined:
    Feb 24, 2017
    Posts:
    2
    Hi, I am trying to create an agent that learns to play a simple RTS game that I created. The problem is that all the ML-Agents examples I have looked at use a Rigidbody component for handling movement and collisions, whereas my implementation uses a NavMeshAgent for pathfinding and moving across the map, and a CapsuleCollider for collision detection.

    The four scripts I've added below show what I am trying to turn into ML-Agents. The first script is my unit script, where I set up all the behaviours (as states) that a unit can perform. The second, third and fourth scripts show how a player controls the units and how they can create new units.

    unitScript ::

    using System.Collections;
    using System.Collections.Generic;
    using UnityEngine;
    using UnityEngine.AI;
    using UnityEngine.Events;
    /// <summary>
    /// The behaviour a <c>Unit</c> is currently in; <c>Unit.Update</c> dispatches on this value every frame.
    /// Values are spelled out explicitly so their serialized integers stay stable.
    /// </summary>
    public enum UnitState
    {
        /// <summary>Standing still: the NavMeshAgent is stopped and its path cleared.</summary>
        Idle = 0,
        /// <summary>Walking to a point ordered through <c>Unit.MoveToPosition</c>.</summary>
        Move = 1,
        /// <summary>Walking to a spot next to the unit's current resource source.</summary>
        MoveToResource = 2,
        /// <summary>Periodically gathering from the unit's current resource source.</summary>
        Gather = 3,
        /// <summary>Chasing the current enemy target until within attack distance.</summary>
        MoveToEnemy = 4,
        /// <summary>Periodically damaging the current enemy target while in range.</summary>
        Attack = 5,
    }
    /// <summary>
    /// A controllable RTS unit. Its behaviour is a small state machine (see <c>UnitState</c>) ticked from
    /// <c>Update</c>. Movement goes through the NavMeshAgent on the same GameObject.
    /// Public entry points: <c>MoveToPosition</c>, <c>GatherResource</c>, <c>AttackUnit</c>,
    /// <c>TakeDamage</c> and <c>ToggleSelectionVisual</c>.
    /// </summary>
    public class Unit : MonoBehaviour
    {
        [Header("Stats")]
        public UnitState state;
        public int curHp;
        public int maxHp;
        public int minAttackDamage;
        public int maxAttackDamage;
        public float attackRate;               // seconds between attacks
        private float lastAttackTime;
        public float attackDistance;
        public float pathUpdateRate = 1.0f;    // seconds between re-paths while chasing an enemy
        private float lastPathUpdateTime;
        public int gatherAmount;
        public float gatherRate;               // seconds between gathers
        private float lastGatherTime;
        public ResourceSource curResourceSource;
        private Unit curEnemyTarget;

        [Header("Movement")]
        // Slack used when deciding whether the agent has arrived. The original compared the distance to the
        // destination with == 0.0f, which a NavMeshAgent practically never satisfies, so units never went Idle/Gather.
        public float arrivalThreshold = 0.1f;

        [Header("Components")]
        public GameObject selectionVisual;
        private NavMeshAgent navAgent;
        public UnitHealthBar healthBar;
        public Player player;

        // Set by Die() so several hits landing in the same frame cannot run the death logic twice
        // (Destroy() only takes effect after the current frame's update loop).
        private bool isDead;

        // events
        [System.Serializable]
        public class StateChangeEvent : UnityEvent<UnitState> { }
        public StateChangeEvent onStateChange;

        void Start ()
        {
            // get the components
            navAgent = GetComponent<NavMeshAgent>();
            SetState(UnitState.Idle);
        }

        // Switches state, raises onStateChange, and halts/clears the agent's path when going Idle.
        void SetState (UnitState toState)
        {
            state = toState;
            // calling the event
            if(onStateChange != null)
                onStateChange.Invoke(state);
            if(toState == UnitState.Idle)
            {
                navAgent.isStopped = true;
                navAgent.ResetPath();
            }
        }

        void Update ()
        {
            switch(state)
            {
                case UnitState.Move:
                    MoveUpdate();
                    break;
                case UnitState.MoveToResource:
                    MoveToResourceUpdate();
                    break;
                case UnitState.Gather:
                    GatherUpdate();
                    break;
                case UnitState.MoveToEnemy:
                    MoveToEnemyUpdate();
                    break;
                case UnitState.Attack:
                    AttackUpdate();
                    break;
            }
        }

        // True once the agent has finished computing its path and is within its stopping distance
        // (or arrivalThreshold, whichever is larger) of the end of that path.
        bool HasReachedDestination ()
        {
            if(navAgent.pathPending)
                return false;
            return navAgent.remainingDistance <= Mathf.Max(navAgent.stoppingDistance, arrivalThreshold);
        }

        // called every frame the 'Move' state is active
        void MoveUpdate ()
        {
            if(HasReachedDestination())
                SetState(UnitState.Idle);
        }

        // called every frame the 'MoveToResource' state is active
        void MoveToResourceUpdate ()
        {
            if(curResourceSource == null)
            {
                SetState(UnitState.Idle);
                return;
            }
            if(HasReachedDestination())
                SetState(UnitState.Gather);
        }

        // called every frame the 'Gather' state is active
        void GatherUpdate ()
        {
            if(curResourceSource == null)
            {
                SetState(UnitState.Idle);
                return;
            }
            LookAt(curResourceSource.transform.position);
            if(Time.time - lastGatherTime > gatherRate)
            {
                lastGatherTime = Time.time;
                curResourceSource.GatherResource(gatherAmount, player);
            }
        }

        // called every frame the 'MoveToEnemy' state is active
        void MoveToEnemyUpdate ()
        {
            // if our target is dead, go idle
            if(curEnemyTarget == null)
            {
                SetState(UnitState.Idle);
                return;
            }
            // re-path at most every pathUpdateRate seconds, since the target keeps moving
            if(Time.time - lastPathUpdateTime > pathUpdateRate)
            {
                lastPathUpdateTime = Time.time;
                navAgent.isStopped = false;
                navAgent.SetDestination(curEnemyTarget.transform.position);
            }
            if(Vector3.Distance(transform.position, curEnemyTarget.transform.position) <= attackDistance)
                SetState(UnitState.Attack);
        }

        // called every frame the 'Attack' state is active
        void AttackUpdate ()
        {
            // if our target is dead, go idle
            if(curEnemyTarget == null)
            {
                SetState(UnitState.Idle);
                return;
            }
            // Range is checked BEFORE attacking. The original attacked first, so it could land a hit on a
            // target that had already moved out of range.
            if(Vector3.Distance(transform.position, curEnemyTarget.transform.position) > attackDistance)
            {
                BeginChase();
                return;
            }
            // if we're still moving, stop
            if(!navAgent.isStopped)
                navAgent.isStopped = true;
            LookAt(curEnemyTarget.transform.position);
            // attack every 'attackRate' seconds; int Random.Range excludes max, hence the + 1
            if(Time.time - lastAttackTime > attackRate)
            {
                lastAttackTime = Time.time;
                curEnemyTarget.TakeDamage(Random.Range(minAttackDamage, maxAttackDamage + 1));
            }
        }

        // Enters MoveToEnemy and forces a re-path on the next tick instead of waiting up to pathUpdateRate
        // (otherwise an agent stopped by Idle/Attack would stand still until the next scheduled update).
        void BeginChase ()
        {
            lastPathUpdateTime = float.NegativeInfinity;
            SetState(UnitState.MoveToEnemy);
        }

        /// <summary>Called when an enemy unit attacks us. Ignored once the unit has died.</summary>
        /// <param name="damage">Hit points to subtract; HP is clamped at 0.</param>
        public void TakeDamage (int damage)
        {
            if(isDead)
                return;
            curHp = Mathf.Max(curHp - damage, 0);
            if(healthBar != null)
                healthBar.UpdateHealthBar(curHp, maxHp);
            if(curHp <= 0)
                Die();
        }

        // called when our health reaches 0: unregister from the owner, notify the game manager, destroy
        void Die ()
        {
            isDead = true;
            if(player != null)
                player.units.Remove(this);
            GameManager.instance.UnitDeathCheck();
            Destroy(gameObject);
        }

        /// <summary>Moves the unit to a specific position; it goes Idle on arrival.</summary>
        public void MoveToPosition (Vector3 pos)
        {
            SetState(UnitState.Move);
            navAgent.isStopped = false;
            navAgent.SetDestination(pos);
        }

        /// <summary>Walks to <paramref name="pos"/> and then starts gathering from <paramref name="resource"/>.</summary>
        public void GatherResource (ResourceSource resource, Vector3 pos)
        {
            curResourceSource = resource;
            SetState(UnitState.MoveToResource);
            navAgent.isStopped = false;
            navAgent.SetDestination(pos);
        }

        /// <summary>Chases <paramref name="target"/> and attacks it once in range; a null target makes the unit go Idle.</summary>
        public void AttackUnit (Unit target)
        {
            curEnemyTarget = target;
            BeginChase();
        }

        /// <summary>Toggles the selection ring around our feet, if one is assigned.</summary>
        public void ToggleSelectionVisual (bool selected)
        {
            if(selectionVisual != null)
                selectionVisual.SetActive(selected);
        }

        // Rotate (around Y only) to face the given position. When the position is effectively on top of us
        // there is no direction to face, so keep the current rotation instead of snapping to 0 degrees.
        void LookAt (Vector3 pos)
        {
            Vector3 dir = pos - transform.position;
            dir.y = 0;
            if(dir.sqrMagnitude < 0.0001f)
                return;
            float angle = Mathf.Atan2(dir.x, dir.z) * Mathf.Rad2Deg;
            transform.rotation = Quaternion.Euler(0, angle, 0);
        }
    }

    playerScript ::

    using System.Collections;
    using System.Collections.Generic;
    using UnityEngine;
    using UnityEngine.Events;
    /// <summary>
    /// One side of the match. It owns its units and food stockpile and spawns gatherer ("unit"),
    /// soldier ("unit2") and commander ("unit3") units by spending food.
    /// When <c>isMe</c> is set, this player also becomes <c>Player.me</c> and drives the GameUI.
    /// </summary>
    public class Player : MonoBehaviour
    {
        public bool isMe;
        [Header("Units")]
        public List<Unit> units = new List<Unit>();

        [Header("Resources")]
        public int food;
        [Header("Components")]
        #region unit components
        [Header("Gatherer Unit")]
        public GameObject unitPrefab;
        public Transform unitSpawnPos;
        [Header("Soldier Unit")]
        public GameObject unit2Prefab;
        public Transform unit2SpawnPos;
        // the soldiers spawned by the most recent successful CreateNewUnit2() call
        public GameObject[] soliderArray;
        // how many soldiers one CreateNewUnit2() call spawns (was hard-coded to 4)
        public int soldierBatchSize = 4;
        [Header("Commander Unit")]
        public GameObject unit3Prefab;
        public Transform unit3SpawnPos;
        // the commanders spawned by the most recent successful CreateNewUnit3() call
        public GameObject[] commanderArray;
        // how many commanders one CreateNewUnit3() call spawns (was hard-coded to 5; an old comment said 6)
        public int commanderBatchSize = 5;
        #endregion

        // events
        [System.Serializable]
        public class UnitCreatedEvent : UnityEvent<Unit> { }
        public UnitCreatedEvent onUnitCreated;

        #region unit costs
        // All costs are PER UNIT. A soldier batch costs unit2Cost * soldierBatchSize,
        // and a commander batch costs unit3Cost * commanderBatchSize.
        public readonly int unitCost = 50;
        public readonly int unit2Cost = 25;
        public readonly int unit3Cost = 40;
        #endregion

        // the locally controlled player (the one with isMe set)
        public static Player me;

        void Awake ()
        {
            if(isMe)
                me = this;
        }

        void Start ()
        {
            if(isMe)
            {
                GameUI.instance.UpdateUnitCountText(units.Count);
                GameUI.instance.UpdateFoodText(food);
                // NOTE(review): the soldier and commander counters are seeded with the TOTAL unit count,
                // and nothing in this class tracks per-type counts. Confirm this is intended.
                GameUI.instance.UpdateSoldierCountText(units.Count);
                GameUI.instance.UpdateCommanderCountText(units.Count);
                CameraController.instance.FocusOnPosition(unitSpawnPos.position);
            }
            // grant exactly enough food for one free starting gatherer, then spawn it
            food += unitCost;
            CreateNewUnit();
        }

        /// <summary>Called when a unit gathers a certain resource; adds it to the stockpile.</summary>
        public void GainResource (ResourceType resourceType, int amount)
        {
            switch(resourceType)
            {
                case ResourceType.Food:
                    food += amount;
                    if(isMe)
                        GameUI.instance.UpdateFoodText(food);
                    break;
            }
        }

        #region create units
        /// <summary>Spawns one gatherer if the player can afford it; otherwise does nothing.</summary>
        public void CreateNewUnit ()
        {
            if(food < unitCost)
                return;
            SpawnUnit(unitPrefab, unitSpawnPos, unitCost);
            RefreshUnitUI();
        }

        /// <summary>
        /// Spawns a batch of <c>soldierBatchSize</c> soldiers only if the WHOLE batch is affordable.
        /// The original only checked the cost of a single soldier, so food could go negative.
        /// </summary>
        public void CreateNewUnit2()
        {
            GameObject[] spawned;
            if(TrySpawnBatch(unit2Prefab, unit2SpawnPos, unit2Cost, soldierBatchSize, out spawned))
                soliderArray = spawned;
        }

        /// <summary>
        /// Spawns a batch of <c>commanderBatchSize</c> commanders only if the WHOLE batch is affordable.
        /// </summary>
        public void CreateNewUnit3()
        {
            GameObject[] spawned;
            if(TrySpawnBatch(unit3Prefab, unit3SpawnPos, unit3Cost, commanderBatchSize, out spawned))
                commanderArray = spawned;
        }

        // Spawns `count` units of one type when food covers costPerUnit * count. Returns false and spawns
        // nothing otherwise. On success, `spawned` holds the new GameObjects (the original allocated
        // these arrays but never filled them).
        bool TrySpawnBatch (GameObject prefab, Transform spawnPos, int costPerUnit, int count, out GameObject[] spawned)
        {
            spawned = null;
            if(count <= 0 || food < costPerUnit * count)
                return false;
            spawned = new GameObject[count];
            for(int i = 0; i < count; i++)
                spawned[i] = SpawnUnit(prefab, spawnPos, costPerUnit);
            RefreshUnitUI();
            return true;
        }

        // Instantiates one unit under this player, registers it, charges its cost and raises onUnitCreated.
        GameObject SpawnUnit (GameObject prefab, Transform spawnPos, int cost)
        {
            GameObject unitObj = Instantiate(prefab, spawnPos.position, Quaternion.identity, transform);
            Unit unit = unitObj.GetComponent<Unit>();
            units.Add(unit);
            unit.player = this;
            food -= cost;
            if(onUnitCreated != null)
                onUnitCreated.Invoke(unit);
            return unitObj;
        }

        // Pushes the unit count and food total to the UI, for the local player only.
        void RefreshUnitUI ()
        {
            if(!isMe)
                return;
            GameUI.instance.UpdateUnitCountText(units.Count);
            GameUI.instance.UpdateFoodText(food);
        }
        #endregion

        /// <summary>Is this my unit?</summary>
        public bool IsMyUnit (Unit unit)
        {
            return units.Contains(unit);
        }
    }

    UnitCommanderScript ::

    using System.Collections;
    using System.Collections.Generic;
    using UnityEngine;
    /// <summary>
    /// Translates right-clicks into orders for the units held by the sibling UnitSelection component.
    /// Clicking the ground moves units in formation, clicking a resource gathers it, and clicking
    /// an enemy unit attacks it.
    /// </summary>
    public class UnitCommander : MonoBehaviour
    {
        public GameObject selectionMarkerPrefab;
        public LayerMask layerMask;
        // maximum length of the command raycast (was a hard-coded 100)
        public float commandRayDistance = 100f;
        // gap passed to UnitMover.GetUnitGroupDestinations for move orders (was a hard-coded 2)
        public float formationSpacing = 2f;

        // components
        private UnitSelection unitSelection;
        private Camera cam;

        void Awake ()
        {
            // get the components
            unitSelection = GetComponent<UnitSelection>();
            cam = Camera.main;
        }

        void Update ()
        {
            // only react to a right-click while something is selected
            if(!Input.GetMouseButtonDown(1) || !unitSelection.HasUnitsSelected())
                return;

            // shoot a raycast from the mouse to see what was clicked
            Ray ray = cam.ScreenPointToRay(Input.mousePosition);
            RaycastHit hit;
            if(!Physics.Raycast(ray, out hit, commandRayDistance, layerMask))
                return;

            // Drop destroyed units BEFORE snapshotting the selection. The original took the snapshot first,
            // so a unit that died while selected was still handed to the order helpers below.
            unitSelection.RemoveNullUnitsFromSelection();
            Unit[] selectedUnits = unitSelection.GetSelectedUnits();
            if(selectedUnits == null || selectedUnits.Length == 0)
                return;

            if(hit.collider.CompareTag("Ground"))
            {
                UnitsMoveToPosition(hit.point, selectedUnits);
                CreateSelectionMarker(hit.point, false);
            }
            else if(hit.collider.CompareTag("Resource"))
            {
                // a tag without the matching component is ignored instead of throwing
                ResourceSource resource = hit.collider.GetComponent<ResourceSource>();
                if(resource != null)
                {
                    UnitsGatherResource(resource, selectedUnits);
                    CreateSelectionMarker(resource.transform.position, true);
                }
            }
            else if(hit.collider.CompareTag("Unit"))
            {
                // only attack units that exist and are not ours
                Unit enemy = hit.collider.GetComponent<Unit>();
                if(enemy != null && !Player.me.IsMyUnit(enemy))
                {
                    UnitsAttackEnemy(enemy, selectedUnits);
                    CreateSelectionMarker(enemy.transform.position, false);
                }
            }
        }

        // called when we command units to move somewhere: one formation slot per unit
        void UnitsMoveToPosition (Vector3 movePos, Unit[] units)
        {
            Vector3[] destinations = UnitMover.GetUnitGroupDestinations(movePos, units.Length, formationSpacing);
            for(int x = 0; x < units.Length; x++)
                units[x].MoveToPosition(destinations[x]);
        }

        // called when we command units to gather a resource
        void UnitsGatherResource (ResourceSource resource, Unit[] units)
        {
            // a single unit gets a random spot around the resource
            if(units.Length == 1)
            {
                units[0].GatherResource(resource, UnitMover.GetUnitDestinationAroundResource(resource.transform.position));
                return;
            }
            // otherwise spread the group evenly around it
            Vector3[] destinations = UnitMover.GetUnitGroupDestinationsAroundResource(resource.transform.position, units.Length);
            for(int x = 0; x < units.Length; x++)
                units[x].GatherResource(resource, destinations[x]);
        }

        // called when we command units to attack an enemy
        void UnitsAttackEnemy (Unit target, Unit[] units)
        {
            for(int x = 0; x < units.Length; x++)
                units[x].AttackUnit(target);
        }

        // creates a new selection marker visual just above the ground at the given position
        void CreateSelectionMarker (Vector3 pos, bool large)
        {
            GameObject marker = Instantiate(selectionMarkerPrefab, new Vector3(pos.x, 0.01f, pos.z), Quaternion.identity);
            if(large)
                marker.transform.localScale = Vector3.one * 3;
        }
    }

    unitMoverScript ::

    using System.Collections;
    using System.Collections.Generic;
    using UnityEngine;
    /// <summary>
    /// Static helpers that compute destination points for groups of units: a grid formation around
    /// a clicked point, and rings of positions around a resource.
    /// </summary>
    public class UnitMover : MonoBehaviour
    {
        /// <summary>
        /// Calculates a roughly square grid formation centred on <paramref name="moveToPos"/>.
        /// Units fill a line along +Z first (about sqrt(numUnits) slots per line), then start the next line along +X.
        /// </summary>
        /// <param name="moveToPos">Centre of the formation.</param>
        /// <param name="numUnits">Number of destinations to produce; zero or negative yields an empty array.</param>
        /// <param name="unitGap">Spacing between neighbouring slots, in world units.</param>
        public static Vector3[] GetUnitGroupDestinations (Vector3 moveToPos, int numUnits, float unitGap)
        {
            if(numUnits <= 0)
                return new Vector3[0];

            Vector3[] destinations = new Vector3[numUnits];
            // slots per line along Z, and how many lines along X are needed to fit everyone
            int slotsPerLine = Mathf.RoundToInt(Mathf.Sqrt(numUnits));
            int lineCount = Mathf.CeilToInt((float)numUnits / (float)slotsPerLine);

            // shift the whole grid so its centre (not its corner) lands on moveToPos
            float zExtent = ((float)slotsPerLine - 1) * unitGap;
            float xExtent = ((float)lineCount - 1) * unitGap;
            Vector3 centreOffset = new Vector3(xExtent / 2, 0, zExtent / 2);

            for(int i = 0; i < numUnits; i++)
            {
                int line = i / slotsPerLine;
                int slot = i % slotsPerLine;
                destinations[i] = moveToPos + (new Vector3(line, 0, slot) * unitGap) - centreOffset;
            }
            return destinations;
        }

        /// <summary>Returns <paramref name="unitsNum"/> positions evenly spaced on a circle around a resource.</summary>
        /// <param name="resourcePos">Centre of the circle.</param>
        /// <param name="unitsNum">Number of positions; zero or negative yields an empty array.</param>
        /// <param name="radius">Circle radius in world units (defaults to the previously hard-coded 1).</param>
        public static Vector3[] GetUnitGroupDestinationsAroundResource (Vector3 resourcePos, int unitsNum, float radius = 1f)
        {
            if(unitsNum <= 0)
                return new Vector3[0];

            Vector3[] destinations = new Vector3[unitsNum];
            float degreesBetweenUnits = 360.0f / (float)unitsNum;
            for(int i = 0; i < unitsNum; i++)
                destinations[i] = resourcePos + DirectionFromAngle(degreesBetweenUnits * i) * radius;
            return destinations;
        }

        /// <summary>Returns one random position on a circle of <paramref name="radius"/> around a resource.</summary>
        public static Vector3 GetUnitDestinationAroundResource (Vector3 resourcePos, float radius = 1f)
        {
            // Float overload: any angle in [0, 360]. The original used the int overload, which only
            // produced whole degrees 0..359.
            float angle = Random.Range(0f, 360f);
            return resourcePos + DirectionFromAngle(angle) * radius;
        }

        // Unit vector on the XZ plane for a compass-style angle in degrees (0 = +Z, 90 = +X).
        static Vector3 DirectionFromAngle (float degrees)
        {
            float radians = degrees * Mathf.Deg2Rad;
            return new Vector3(Mathf.Sin(radians), 0, Mathf.Cos(radians));
        }
    }

    So any help on how to make these scripts controlled by an ML-Agent would be greatly appreciated. I'm not looking for someone to do the work for me, just a nudge in the right direction with a few examples. I have thought about using a Ray Perception Sensor for the unit script, but that seems to depend on using a Rigidbody component. Then, for the player script, I thought about using a Camera Sensor to recognise how many resources the player has and then create a new unit. Basically, what do I put in the key ML-Agents methods below?

    // Empty ML-Agents override stubs from the question; the poster is asking what belongs in each one.
    // NOTE(review): these overrides only compile inside a class deriving from the ML-Agents Agent base class.
    // The float[] action signatures appear to match an older ML-Agents API; verify against the installed version.
    public override void Initialize()
    {
        // NOTE(review): presumably where component references (e.g. the NavMeshAgent / Unit) would be cached — confirm.
    }
    public override void OnEpisodeBegin()
    {
        // NOTE(review): presumably where the game state would be reset at the start of each training episode — confirm.
    }
    public override void OnActionReceived(float[] vectorAction)
    {
        // NOTE(review): presumably where vectorAction would be mapped onto calls such as MoveToPosition / AttackUnit — confirm.
    }
    public override void Heuristic(float[] actionsOut)
    {
        // NOTE(review): presumably where actionsOut would be filled from manual input for testing — confirm.
    }

    any help is greatly appreciated.
     
  2. ervteng_unity

    ervteng_unity

    Unity Technologies

    Joined:
    Dec 6, 2018
    Posts:
    150
    Hey @MurdoMacIver, check out our example environments in the repo for some examples on what to put in those methods. From what I can gather, navigation tasks like Hallway, FoodCollector and Pyramids might be most similar to your task.