Search Unity

  1. Unity 6 Preview is now available. To find out what's new, have a look at our Unity 6 Preview blog post.
    Dismiss Notice
  2. Unity is excited to announce that we will be collaborating with TheXPlace for a summer game jam from June 13 - June 19. Learn more.
    Dismiss Notice
  3. Dismiss Notice

Question Training for handling multiple gameobjects

Discussion in 'ML-Agents' started by David69, Aug 31, 2023.

  1. David69

    David69

    Joined:
    Oct 2, 2018
    Posts:
    32
    I have a project with incoming missiles where the agent uses a discrete action (Fire/Don't Fire, 1/0) along with rules about when to fire or not fire, e.g. proximity and the angle between the missile and itself. The Agent is housed in the 'Gun', but it must make the fire / no-fire decision for each incoming missile. This works fine for one missile (launched from random positions in random directions), but if I then add more missiles it makes a lot of misfire decisions. I tried training with e.g. 20 or 100 missiles, but that doesn't seem to work. I guess my question is where and when to call RequestDecision: in a FixedUpdate routine? Looping through each missile at each fixed frame? Calling RequestDecision within the OnActionReceived routine? Or is there a better way to do this sort of learning to handle a lot of objects at once?

    Thanks
     
  2. smallg2023

    smallg2023

    Joined:
    Sep 2, 2018
    Posts:
    154
    If the game runs in real time, you could use the Decision Requester component rather than manual decisions, as that makes it easier to adjust the timing, but you can call RequestDecision whenever you need a decision to be made. If you want a decision every step, calling it in FixedUpdate is fine, but that is a lot of decisions (the default with the Decision Requester is one every 5 steps, which means it could fire every 0.1s instead of every 0.02s).
     
    David69 likes this.
  3. David69

    David69

    Joined:
    Oct 2, 2018
    Posts:
    32
    Can you see anything wrong with the following code?

    /// <summary>
    /// Writes one group of observations per entry in <c>Missiles</c>: the last
    /// fire action, the missile's local position, the agent's local position,
    /// the missile's angle to the agent, and whether it is inside the dome range.
    /// Also stores the angle and the range flag on each missile's <c>Launch</c> component.
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Cached in a field because OnActionReceived reads playerTransform too.
        playerTransform = me.transform;

        if (Missiles == null || Missiles.Length == 0)
        {
            return;
        }

        // Loop-invariant: the dome does not change while we iterate the missiles.
        Vector3 domeCenter = dome.transform.position;
        radius1 = dome.transform.localScale.x * 4; // Assuming a uniform scale

        // NOTE(review): this writes nine values per missile (assuming AOA and close1
        // are scalars), so the total grows with Missiles.Length. Confirm that the
        // vector observation size configured for the behavior matches, or that the
        // missile count is fixed.
        foreach (GameObject missileObject in Missiles)
        {
            missileTransform = missileObject.transform;
            Launch launch = missileObject.GetComponent<Launch>();

            Vector3 missileToPlayerLocal = playerTransform.localPosition - missileTransform.localPosition;

            // NOTE(review): missileToPlayerLocal is a parent-space offset, yet it is
            // transformed as if it were in the missile's own local space. Kept as in
            // the original because `towards` is only stored, not observed - confirm intent.
            Vector3 missileToPlayerWorld = missileTransform.TransformDirection(missileToPlayerLocal);
            towards = Vector3.Dot(missileToPlayerWorld.normalized, missileTransform.forward);

            // Below 90 degrees the missile is heading towards the agent; same formula
            // as the one used in OnActionReceived.
            angle = Vector3.Angle(missileToPlayerLocal, missileTransform.forward);
            launch.AOA = angle;
            DOT = angle;

            // Measure from THIS missile (the original used the single `missile` field),
            // and compare world position to world position so both are in one space.
            float distance = Vector3.Distance(missileTransform.position, domeCenter);
            close = distance <= radius1;
            launch.close1 = close;

            sensor.AddObservation(shouldFire);
            sensor.AddObservation(missileTransform.localPosition);
            sensor.AddObservation(me.transform.localPosition);
            sensor.AddObservation(launch.AOA);
            sensor.AddObservation(launch.close1);
        }
    }


    /// <summary>
    /// Applies the single discrete fire / don't-fire action to every active missile
    /// and accumulates a reward for each rule that matches (angle towards or away
    /// from the agent, and whether the missile is flagged as close).
    /// </summary>
    /// <param name="actions">Action buffers; discrete branch 0 is 1 for fire, 0 for hold.</param>
    public override void OnActionReceived(ActionBuffers actions)
    {
        actions_count++;

        // The action does not change during this call, so read it once.
        shouldFire = actions.DiscreteActions[0];
        bool fired = shouldFire == 1;
        bool held = shouldFire == 0;

        for (int i = 0; i < Missiles.Length; i++)
        {
            // Index into the array; the original assigned the whole array to a GameObject.
            GameObject missileObject = Missiles[i];
            if (missileObject.transform.localPosition.y > 1 && missileObject.CompareTag("Missile"))
            {
                Launch launch = missileObject.GetComponent<Launch>();

                // Each missile's angle calc
                missileTransform = missileObject.transform;
                Vector3 missileToPlayerLocal = me.transform.localPosition - missileTransform.localPosition;
                Vector3 missileDirection = missileTransform.forward;
                angle = Vector3.Angle(missileToPlayerLocal, missileDirection);
                launch.AOA = angle;

                // An angle of exactly 90 matches neither case, as in the original.
                bool incoming = launch.AOA < 90;
                bool outgoing = launch.AOA > 90;

                // AddReward rather than SetReward: SetReward replaces the step reward,
                // so inside this loop only the last matching rule would survive.

                // Fired as coming towards
                if (fired && incoming)
                {
                    DrawLine(me.transform.localPosition, missileObject.transform.localPosition, Color.green, 0.5f);
                    AddReward(1);
                    corr_shot++;
                }

                // Fired as moving away - wasted shot
                if (fired && outgoing)
                {
                    DrawLine(me.transform.localPosition, missileObject.transform.localPosition, Color.red, 0.5f);
                    AddReward(-1);
                    wasted_shot++;
                }

                // Didn't fire when should have
                if (held && incoming)
                {
                    AddReward(-1);
                    Should_have_shot++;
                }

                // Didn't fire and moving away - correct
                if (held && outgoing)
                {
                    AddReward(1);
                    corr_left++;
                }

                // Proximity rules: draw to the missile being evaluated, not the
                // single `missile` field the original referenced here.
                if (fired && launch.close1 == true)
                {
                    DrawLine(me.transform.position, missileObject.transform.position, Color.yellow, 1.0f);
                    AddReward(1);
                    corr_shot++;
                }

                if (fired && launch.close1 == false)
                {
                    DrawLine(me.transform.position, missileObject.transform.position, Color.cyan, 1.0f);
                    AddReward(-1);
                }
            }
        }
        // NOTE(review): the method's closing brace is not part of this excerpt.