Question: Barracuda gives a different output. Is there an unsupported operation?

Discussion in 'Barracuda' started by Jamesika, Aug 17, 2022.

  1. Jamesika

    Joined: May 13, 2017
    Posts: 5
    This is the PyTorch model, and I have uploaded the ONNX model. Is there any operation that Barracuda doesn't support?

    Code (Python):
    import torch
    import torch.nn as nn

    # Custom Bi-LSTM, because Unity Barracuda doesn't support "bidirectional=True"
    class BILSTM(nn.Module):
        def __init__(self, inputSize, hiddenSize, numLayers, dropOut):
            super().__init__()
            # Two unidirectional LSTMs: biLayer1 reads the sequence forward, biLayer2 reads it reversed.
            # No .cuda() here: move the whole model with model.to(device) so every layer lands on the same device.
            self.biLayer1 = nn.LSTM(input_size=inputSize, hidden_size=hiddenSize, num_layers=numLayers, batch_first=True, dropout=dropOut)
            self.biLayer2 = nn.LSTM(input_size=inputSize, hidden_size=hiddenSize, num_layers=numLayers, batch_first=True, dropout=dropOut)

        def forward(self, x):
            # x: (batch, time, features), because batch_first=True
            out1, (hidden1, _) = self.biLayer1(x)
            # Reverse the time axis, run the second LSTM, then flip its outputs back into the original order
            out2, (hidden2, _) = self.biLayer2(torch.flip(x, dims=[1]))
            out2 = torch.flip(out2, dims=[1])
            # Shape (2 * numLayers, batch, hiddenSize): all forward layers first, then all backward layers
            hidden = torch.cat([hidden1, hidden2], dim=0)
            return (out1, out2), (hidden, 0)

    class SimNN(nn.Module):
        def __init__(self, inputSize, hiddenSize, numLayers):
            super().__init__()
            self.BILSTM = BILSTM(inputSize * 2, hiddenSize, numLayers, 0.5)
            self.classifyLayer = nn.Linear(hiddenSize, 2)
            self.dropOut = nn.Dropout(p=0.2)

        def forward(self, x):
            # x: (batch, 2, time, features); split the left/right streams and join them on the feature axis
            xL = x[:, 0, :, :]
            xR = x[:, 1, :, :]
            x = torch.cat([xL, xR], dim=2)
            # Call the module instead of .forward() directly so PyTorch hooks still run
            _, (h_n, c_n) = self.BILSTM(x)
            # With numLayers=2, index 3 is the last layer of the backward LSTM
            out = h_n[3]
            out = self.classifyLayer(out)
            # Dropout is only active in training mode; call model.eval() before exporting to ONNX
            out = self.dropOut(out)
            return out
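A quick sanity check you can run in PyTorch before blaming the export: for a single layer, the two-LSTM workaround should reproduce `nn.LSTM(bidirectional=True)` once the built-in module's weights (including the `*_reverse` ones) are copied into the two unidirectional LSTMs. This is a sketch, not part of the original post; the tensor sizes are arbitrary test values. Keep in mind that the two are not equivalent for `num_layers > 1`: in PyTorch's bidirectional LSTM, every layer after the first takes the concatenated forward and backward outputs of the previous layer as input, whereas `BILSTM` runs two independent stacks. The `h_n` layout also differs: PyTorch interleaves directions per layer (`[layer0 fwd, layer0 bwd, layer1 fwd, layer1 bwd]`), while `torch.cat([hidden1, hidden2])` puts all forward layers first. With `numLayers=2`, index 3 is the last backward layer in both layouts.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, F, H = 3, 7, 5, 4  # batch, time steps, features, hidden size (arbitrary test sizes)

# Reference: PyTorch's built-in single-layer bidirectional LSTM
bi = nn.LSTM(input_size=F, hidden_size=H, num_layers=1, batch_first=True, bidirectional=True)

# Workaround: two unidirectional LSTMs, as in the BILSTM class above
fwd = nn.LSTM(input_size=F, hidden_size=H, num_layers=1, batch_first=True)
bwd = nn.LSTM(input_size=F, hidden_size=H, num_layers=1, batch_first=True)

# Copy forward-direction weights into fwd and "_reverse" weights into bwd
with torch.no_grad():
    for name in ["weight_ih_l0", "weight_hh_l0", "bias_ih_l0", "bias_hh_l0"]:
        getattr(fwd, name).copy_(getattr(bi, name))
        getattr(bwd, name).copy_(getattr(bi, name + "_reverse"))

x = torch.randn(B, T, F)
with torch.no_grad():
    out_bi, (h_bi, _) = bi(x)
    out1, (h1, _) = fwd(x)
    # Backward pass: run on the time-reversed input, then flip outputs back to original order
    out2, (h2, _) = bwd(torch.flip(x, dims=[1]))
    out2 = torch.flip(out2, dims=[1])

# The built-in output is [forward | backward] along the last axis
fwd_out_matches = torch.allclose(out_bi[..., :H], out1, atol=1e-5)
bwd_out_matches = torch.allclose(out_bi[..., H:], out2, atol=1e-5)
# For one layer, h_bi has shape (2, B, H): index 0 is forward, index 1 is backward
hidden_matches = torch.allclose(h_bi, torch.cat([h1, h2], dim=0), atol=1e-5)
print(fwd_out_matches, bwd_out_matches, hidden_matches)
```

If all three checks pass, the workaround itself matches PyTorch, and a mismatch in Unity points at how the exported ops (for example the LSTM or the flip) are handled on the Barracuda side.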
     
  2. Jamesika

    I switched to a CNN; it seems LSTM is not supported in this situation.