How to use trained model in pytorch or tensorflow

Discussion in 'ML-Agents' started by akn22, Jul 27, 2020.

  1. akn22

    Joined: Nov 14, 2019
    Posts: 2
    I want to load a trained model (.nn file) in PyTorch or TensorFlow. Is there a script in ML-Agents that converts .nn files to a PyTorch or TensorFlow model? If not, what would be the best way to do it?

    One solution I tried was to load the checkpoints and create a TF graph, but how do I create a TensorFlow model from this graph?

    I use PyTorch and have no expertise in TensorFlow, so any help with building a TensorFlow model from a TensorFlow graph would be really helpful, since I could then translate it into a PyTorch model.
     
  2. awjuliani

    Unity Technologies

    Joined: Mar 1, 2017
    Posts: 37
    Hello. We do not support this feature. Can you share a little about what you hope to accomplish with this use case? Your intuition to use the checkpoint files is the right one, as this will allow you to reload the TensorFlow model and graph.
     
  3. akn22

    Joined: Nov 14, 2019
    Posts: 2
    Thanks for your reply!
    I am using an off-road vehicle model in Unity and using ML-Agents to train a path-following controller on it.
    I trained PPO using ML-Agents and got the checkpoints and the frozen graph. We also have the real vehicle in our lab, and the next goal is to test this controller on the real vehicle.

    I finally used the frozen graph and ran inference on it using TensorFlow 1 and TensorFlow 2. The future goal of the project is to train the policy further on the real vehicle, and I need the policy model and parameters to do that.

    Since I am more comfortable with PyTorch, I am looking for ways to create a PyTorch policy from the checkpoints or the frozen graph.

    One more issue I faced while running inference on the graph using TF 1 and TF 2 was that I get different values for the same input unless I initialize the session before every inference. To avoid this, I need the details of the policy so that I can prune dropout nodes. I just started using TensorFlow, so I am not sure whether this will actually solve the issue.
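
One possible explanation, which nothing in this thread confirms, is sampling rather than dropout: a PPO policy for continuous actions is usually stochastic, so if the exported graph samples its action, the same input yields a different output on each call. The sketch below uses only Python's standard library to illustrate the effect; the toy policy, its coefficients, and the noise scale are hypothetical and are not taken from ML-Agents.

```python
import random


def policy_mean(observation):
    """Hypothetical deterministic part of a policy: a fixed linear map."""
    return 0.5 * observation[0] - 0.25 * observation[1]


def sample_action(observation, rng, std=0.1):
    """Stochastic action: the deterministic mean plus Gaussian noise."""
    return policy_mean(observation) + rng.gauss(0.0, std)


obs = [1.0, 2.0]

# Two calls that share one random generator: same input, different noise.
rng = random.Random(0)
first = sample_action(obs, rng)
second = sample_action(obs, rng)

# Recreating the generator with the same seed before the call reproduces the
# first draw. This is only a loose analogy (an assumption) for re-initializing
# the session before every inference.
replay = sample_action(obs, random.Random(0))

# The mean involves no random draw, so it is repeatable by construction.
mean_a = policy_mean(obs)
mean_b = policy_mean(obs)

print(first, second, replay, mean_a)
```

If sampling is indeed the cause, reading the mean of the action distribution instead of the sampled action would give repeatable outputs, and no pruning of dropout nodes would be needed.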

    Is there any way of saving the TensorFlow model in .h5 format? (I was looking at the trainer.py file because I think this is the process that saves checkpoints.) I could then use it to create a PyTorch model (not sure, though) and run inference in evaluation mode, which will not affect the computational graph and will give the same output for the same input.

    I was able to control the simulation using the OpenAI Gym interface provided by ML-Agents, so training my own PPO in PyTorch will be my last resort.

    One more thing I noticed: during training I get a message that TensorFlow 2 is being used by ML-Agents, but the checkpoints are saved in TF 1 format. I don't know why this is happening.
     
    Last edited: Jul 29, 2020
  4. awjuliani

    Unity Technologies

    Joined: Mar 1, 2017
    Posts: 37
    The training process produces TensorFlow checkpoints, which can be used to re-instantiate a session and keep training from there. I imagine that converting this model to something usable in PyTorch would be quite a challenge, though, as there are many significant differences between the two frameworks. We do not support saving to .h5, but you should be able to modify the trainer code to accomplish this. We do allow exporting via ONNX, which might be an easier way to get the model into a PyTorch-usable format.
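
One concrete difference that matters when moving weights by hand is the layout of dense layers: TensorFlow stores a dense kernel as [inputs, outputs], while PyTorch's nn.Linear stores its weight as [outputs, inputs], so each kernel has to be transposed when copied across. The framework-free sketch below illustrates only that layout rule; the layer sizes, numeric values, and helper names are made up for illustration and do not come from ML-Agents.

```python
def transpose(matrix):
    """Return the transpose of a matrix given as a list of rows."""
    return [list(column) for column in zip(*matrix)]


def dense_tf(x, kernel, bias):
    """TensorFlow dense convention: y = x @ kernel + bias, kernel is [in, out]."""
    return [
        sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]


def linear_torch(x, weight, bias):
    """PyTorch nn.Linear convention: y = x @ weight.T + bias, weight is [out, in]."""
    return [
        sum(w_i * x_i for w_i, x_i in zip(row, x)) + b
        for row, b in zip(weight, bias)
    ]


# Hypothetical layer with 3 inputs and 2 outputs, in TensorFlow layout.
tf_kernel = [[0.5, -1.0], [2.0, 0.0], [1.0, 3.0]]
tf_bias = [0.1, -0.2]

# Transposing the kernel gives the PyTorch layout; the bias is copied as-is.
torch_weight = transpose(tf_kernel)

observation = [1.0, 2.0, -1.0]
tf_out = dense_tf(observation, tf_kernel, tf_bias)
torch_out = linear_torch(observation, torch_weight, tf_bias)

print(tf_out, torch_out)
```

Both functions return the same values for the same observation, which makes a handy sanity check after copying each layer: feed one input through the original graph and through the rebuilt model and compare the outputs.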

    As for TF versions, we use the TF1 compatibility layer within TF2, which allows us to use the codebase originally written for TF1 with the latest TF2 release. In the coming months we will actually be transitioning to PyTorch, which would likely make all of this easier for you, but I understand that you likely would not want to wait that long.
     