
[Resolved] Problems converting C# code to Compute Shader

Discussion in 'Scripting' started by UserNobody, Apr 19, 2021.

  1. UserNobody

    Joined: Oct 3, 2019
    Posts: 144

    Hello there,
    I have some code that generates a new mesh from a provided height map. At first, I ran this code on a separate thread so that it would not affect the main thread; however, it still did. I would always get frame drops/freezes whenever this code executed.
    So I decided to move this code into a compute shader. I have converted all of it, but there are some issues. When I run my game in Unity, I get this error message:

    [Screenshot of the error message: upload_2021-4-19_20-5-26.png]

    I have run much bigger tasks with compute shaders without any problems, so I am a bit confused by this error.

    Does anyone know what is causing this?

    This is the original C# code:
    Code (CSharp):
    public static class MeshGenerator
    {
        public static MeshData GenerateTerrainMesh(float[,] heightMap, MeshSettings meshSettings, int levelOfDetail)
        {
            int skipIncrement = (levelOfDetail == 0) ? 1 : levelOfDetail * 2;
            int numVertsPerLine = meshSettings.numVertsPerLine;

            Vector2 topLeft = new Vector2(-1, 1) * meshSettings.meshWorldSize / 2f;

            MeshData meshData = new MeshData(numVertsPerLine, skipIncrement);

            int[,] vertexIndicesMap = new int[numVertsPerLine, numVertsPerLine];
            int meshVertexIndex = 0;
            int outOfMeshVertexIndex = -1;

            for (int y = 0; y < numVertsPerLine; y++)
            {
                for (int x = 0; x < numVertsPerLine; x++)
                {
                    bool isOutOfMeshVertex = y == 0 || y == numVertsPerLine - 1 || x == 0 || x == numVertsPerLine - 1;
                    bool isSkippedVertex = x > 2 && x < numVertsPerLine - 3 && y > 2 && y < numVertsPerLine - 3 && ((x - 2) % skipIncrement != 0 || (y - 2) % skipIncrement != 0);
                    if (isOutOfMeshVertex)
                    {
                        vertexIndicesMap[x, y] = outOfMeshVertexIndex;
                        outOfMeshVertexIndex--;
                    }
                    else if (!isSkippedVertex)
                    {
                        vertexIndicesMap[x, y] = meshVertexIndex;
                        meshVertexIndex++;
                    }
                }
            }

            for (int y = 0; y < numVertsPerLine; y++)
            {
                for (int x = 0; x < numVertsPerLine; x++)
                {
                    bool isSkippedVertex = x > 2 && x < numVertsPerLine - 3 && y > 2 && y < numVertsPerLine - 3 && ((x - 2) % skipIncrement != 0 || (y - 2) % skipIncrement != 0);

                    if (!isSkippedVertex)
                    {
                        bool isOutOfMeshVertex = y == 0 || y == numVertsPerLine - 1 || x == 0 || x == numVertsPerLine - 1;
                        bool isMeshEdgeVertex = (y == 1 || y == numVertsPerLine - 2 || x == 1 || x == numVertsPerLine - 2) && !isOutOfMeshVertex;
                        bool isMainVertex = (x - 2) % skipIncrement == 0 && (y - 2) % skipIncrement == 0 && !isOutOfMeshVertex && !isMeshEdgeVertex;
                        bool isEdgeConnectionVertex = (y == 2 || y == numVertsPerLine - 3 || x == 2 || x == numVertsPerLine - 3) && !isOutOfMeshVertex && !isMeshEdgeVertex && !isMainVertex;

                        int vertexIndex = vertexIndicesMap[x, y];
                        Vector2 percent = new Vector2(x - 1, y - 1) / (numVertsPerLine - 3);
                        Vector2 vertexPosition2D = topLeft + new Vector2(percent.x, -percent.y) * meshSettings.meshWorldSize;

                        float height = heightMap[x, y];

                        if (isEdgeConnectionVertex)
                        {
                            bool isVertical = x == 2 || x == numVertsPerLine - 3;
                            int dstToMainVertexA = ((isVertical) ? y - 2 : x - 2) % skipIncrement;
                            int dstToMainVertexB = skipIncrement - dstToMainVertexA;
                            float dstPercentFromAToB = dstToMainVertexA / (float)skipIncrement;

                            float heightMainVertexA = heightMap[(isVertical) ? x : x - dstToMainVertexA, (isVertical) ? y - dstToMainVertexA : y];
                            float heightMainVertexB = heightMap[(isVertical) ? x : x + dstToMainVertexB, (isVertical) ? y + dstToMainVertexB : y];

                            height = heightMainVertexA * (1 - dstPercentFromAToB) + heightMainVertexB * dstPercentFromAToB;
                        }

                        meshData.AddVertex(new Vector3(vertexPosition2D.x, height, vertexPosition2D.y), percent, vertexIndex);

                        bool createTriangle = x < numVertsPerLine - 1 && y < numVertsPerLine - 1 && (!isEdgeConnectionVertex || (x != 2 && y != 2));

                        if (createTriangle)
                        {
                            int currentIncrement = (isMainVertex && x != numVertsPerLine - 3 && y != numVertsPerLine - 3) ? skipIncrement : 1;

                            int a = vertexIndicesMap[x, y];
                            int b = vertexIndicesMap[x + currentIncrement, y];
                            int c = vertexIndicesMap[x, y + currentIncrement];
                            int d = vertexIndicesMap[x + currentIncrement, y + currentIncrement];
                            meshData.AddTriangle(a, d, c);
                            meshData.AddTriangle(d, a, b);
                        }
                    }
                }
            }
            meshData.ProcessMesh();
            return meshData;
        }
    }
    And this is the compute shader code:

    Code (HLSL):
    #pragma kernel Bake

    int skipIncrement;
    int numVertsPerLine;
    float2 topLeft;
    float meshWorldSize;

    RWBuffer<int> vertexIndicesMap;
    RWBuffer<float> heightMap;

    RWBuffer<float3> vertices;
    RWBuffer<int> triangles;
    RWBuffer<float2> uvs;
    RWBuffer<float3> outOfMeshVertices;
    RWBuffer<int> outOfMeshTriangles;

    RWBuffer<int> triangleInOutIndexValues;

    void AddVertex(float3 vertexPosition, float2 uv, int vertexIndex)
    {
        if (vertexIndex < 0)
        {
            outOfMeshVertices[-vertexIndex - 1] = vertexPosition;
        }
        else
        {
            vertices[vertexIndex] = vertexPosition;
            uvs[vertexIndex] = uv;
        }
    }
    void AddTriangle(int a, int b, int c)
    {
        if (a < 0 || b < 0 || c < 0)
        {
            outOfMeshTriangles[triangleInOutIndexValues[1] + 0] = a;
            outOfMeshTriangles[triangleInOutIndexValues[1] + 1] = b;
            outOfMeshTriangles[triangleInOutIndexValues[1] + 2] = c;
            triangleInOutIndexValues[1] += 3;
        }
        else
        {
            triangles[triangleInOutIndexValues[0] + 0] = a;
            triangles[triangleInOutIndexValues[0] + 1] = b;
            triangles[triangleInOutIndexValues[0] + 2] = c;
            triangleInOutIndexValues[0] += 3;
        }
    }

    [numthreads(1,1,1)]
    void Bake (uint3 id : SV_DispatchThreadID)
    {
        int x = numVertsPerLine - id.x;
        int y = numVertsPerLine - id.y;

        bool isSkippedVertex = x > 2 && x < numVertsPerLine - 3 && y > 2 && y < numVertsPerLine - 3 && ((x - 2) % skipIncrement != 0 || (y - 2) % skipIncrement != 0);

        if (!isSkippedVertex)
        {
            bool isOutOfMeshVertex = y == 0 || y == numVertsPerLine - 1 || x == 0 || x == numVertsPerLine - 1;
            bool isMeshEdgeVertex = (y == 1 || y == numVertsPerLine - 2 || x == 1 || x == numVertsPerLine - 2) && !isOutOfMeshVertex;
            bool isMainVertex = (x - 2) % skipIncrement == 0 && (y - 2) % skipIncrement == 0 && !isOutOfMeshVertex && !isMeshEdgeVertex;
            bool isEdgeConnectionVertex = (y == 2 || y == numVertsPerLine - 3 || x == 2 || x == numVertsPerLine - 3) && !isOutOfMeshVertex && !isMeshEdgeVertex && !isMainVertex;

            int vertexIndex = vertexIndicesMap[x + (y * numVertsPerLine)];
            float2 percent = float2(x - 1, y - 1) / (numVertsPerLine - 3);
            float2 vertexPosition2D = topLeft + float2(percent.x, -percent.y) * meshWorldSize;

            float height = heightMap[x + (y * numVertsPerLine)];

            if (isEdgeConnectionVertex)
            {
                bool isVertical = x == 2 || x == numVertsPerLine - 3;
                int dstToMainVertexA = ((isVertical) ? y - 2 : x - 2) % skipIncrement;
                int dstToMainVertexB = skipIncrement - dstToMainVertexA;
                float dstPercentFromAToB = dstToMainVertexA / (float)skipIncrement;

                float heightMainVertexA = heightMap[(isVertical ? x : x - dstToMainVertexA) + ((isVertical ? y - dstToMainVertexA : y) * numVertsPerLine)];
                float heightMainVertexB = heightMap[(isVertical ? x : x + dstToMainVertexB) + ((isVertical ? y + dstToMainVertexB : y) * numVertsPerLine)];

                height = heightMainVertexA * (1 - dstPercentFromAToB) + heightMainVertexB * dstPercentFromAToB;
            }
            AddVertex(float3(vertexPosition2D.x, height, vertexPosition2D.y), percent, vertexIndex);

            bool createTriangle = x < numVertsPerLine - 1 && y < numVertsPerLine - 1 && (!isEdgeConnectionVertex || (x != 2 && y != 2));

            if (createTriangle)
            {
                int currentIncrement = (isMainVertex && x != numVertsPerLine - 3 && y != numVertsPerLine - 3) ? skipIncrement : 1;

                int a = vertexIndicesMap[x + (y * numVertsPerLine)];
                int b = vertexIndicesMap[(x + currentIncrement) + (y * numVertsPerLine)];
                int c = vertexIndicesMap[x + ((y + currentIncrement) * numVertsPerLine)];
                int d = vertexIndicesMap[(x + currentIncrement) + ((y + currentIncrement) * numVertsPerLine)];
                AddTriangle(a, d, c);
                AddTriangle(d, a, b);
            }
        }
    }
    And this is the C# code that dispatches the shader:

    Code (CSharp):
    public static MeshData GenerateMeshData(float[] heightMap, MeshSettings meshSettings, int levelOfDetail, ComputeShader baker)
    {
        int skipIncrement = (levelOfDetail == 0) ? 1 : levelOfDetail * 2;
        int numVertsPerLine = meshSettings.numVertsPerLine;

        NativeArray<int> vertexIndicesMap = new NativeArray<int>(numVertsPerLine * numVertsPerLine, Allocator.TempJob);

        GenerateVertexIndicesMapJob vertexIndicesMapJob = new GenerateVertexIndicesMapJob(vertexIndicesMap, skipIncrement, numVertsPerLine);
        JobHandle handle = vertexIndicesMapJob.Schedule(vertexIndicesMap.Length, 1);
        handle.Complete();

        ComputeBuffer indicesMap = new ComputeBuffer(vertexIndicesMap.Length, sizeof(int));
        indicesMap.SetData(vertexIndicesMap.ToArray());
        vertexIndicesMap.Dispose();

        ComputeBuffer heightMapBuffer = new ComputeBuffer(heightMap.Length, sizeof(float));
        heightMapBuffer.SetData(heightMap);

        Vector2 topLeft = new Vector2(-1, 1) * meshSettings.meshWorldSize / 2f;

        int numMeshEdgeVertices = (numVertsPerLine - 2) * 4 - 4;
        int numEdgeConnectionVertices = (skipIncrement - 1) * (numVertsPerLine - 5) / skipIncrement * 4;
        int numMainVerticesPerLine = (numVertsPerLine - 5) / skipIncrement + 1;
        int numMainVertices = numMainVerticesPerLine * numMainVerticesPerLine;

        ComputeBuffer verticesBuffer = new ComputeBuffer(numMeshEdgeVertices + numEdgeConnectionVertices + numMainVertices, sizeof(float) * 3);
        ComputeBuffer uvBuffer = new ComputeBuffer(verticesBuffer.count, sizeof(float) * 2);

        int numMeshEdgeTriangles = 8 * (numVertsPerLine - 4);
        int numMainTriangles = (numMainVerticesPerLine - 1) * (numMainVerticesPerLine - 1) * 2;
        ComputeBuffer trianglesBuffer = new ComputeBuffer((numMeshEdgeTriangles + numMainTriangles) * 3, sizeof(int));

        ComputeBuffer outOfMeshVertices = new ComputeBuffer(numVertsPerLine * 4 - 4, sizeof(float) * 3);
        ComputeBuffer outOfMeshTriangles = new ComputeBuffer(24 * (numVertsPerLine - 2), sizeof(int));

        Vector3[] normals = new Vector3[verticesBuffer.count];

        int kernel = baker.FindKernel("Bake");
        baker.SetInt("skipIncrement", skipIncrement);
        baker.SetInt("numVertsPerLine", numVertsPerLine);
        baker.SetFloats("topLeft", topLeft.x, topLeft.y);
        baker.SetFloat("meshWorldSize", meshSettings.meshWorldSize);

        baker.SetBuffer(kernel, "vertexIndicesMap", indicesMap);
        baker.SetBuffer(kernel, "heightMap", heightMapBuffer);
        baker.SetBuffer(kernel, "vertices", verticesBuffer);
        baker.SetBuffer(kernel, "triangles", trianglesBuffer);
        baker.SetBuffer(kernel, "uvs", uvBuffer);
        baker.SetBuffer(kernel, "outOfMeshVertices", outOfMeshVertices);
        baker.SetBuffer(kernel, "outOfMeshTriangles", outOfMeshTriangles);

        ComputeBuffer triangleInOutIndexValues = new ComputeBuffer(2, sizeof(int));
        baker.SetBuffer(kernel, "triangleInOutIndexValues", triangleInOutIndexValues);

        baker.Dispatch(kernel, Mathf.CeilToInt(numVertsPerLine / 8), Mathf.CeilToInt(numVertsPerLine / 8), 1);

        Vector3[] vertices = new Vector3[verticesBuffer.count];
        verticesBuffer.GetData(vertices);

        Vector2[] uv = new Vector2[uvBuffer.count];
        uvBuffer.GetData(uv);

        int[] triangles = new int[trianglesBuffer.count];
        trianglesBuffer.GetData(triangles);

        indicesMap.Release();
        heightMapBuffer.Release();
        verticesBuffer.Release();
        uvBuffer.Release();
        trianglesBuffer.Release();
        outOfMeshVertices.Release();
        outOfMeshTriangles.Release();
        triangleInOutIndexValues.Release();

        MeshData meshData = new MeshData(vertices, triangles, uv, normals);

        return meshData;
    }
     
    Last edited: Apr 19, 2021
  2. Kurt-Dekker

    Joined: Mar 16, 2013
    Posts: 38,742

    I'm not 100% clear on what you're doing, but have you tried doing all the work on the second thread, signaling the main thread when the MeshData is ready, and then having the main thread tell the compute shader to get busy? Or vice versa; honestly, I don't understand which parts above are doing what.
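    Here is a minimal sketch of the hand-off described above, written in plain .NET so it runs outside Unity. A worker thread builds the data and posts it to a thread-safe queue. The main thread drains that queue once per frame and only then touches Unity objects such as a Mesh or a ComputeShader, because those must be used from the main thread. BuiltMesh, MeshBuildQueue, and the placeholder loop are hypothetical stand-ins for the thread's MeshData code.

```csharp
using System;
using System.Collections.Concurrent;
using System.Threading.Tasks;

// Hypothetical result type standing in for the thread's MeshData.
public sealed class BuiltMesh
{
    public float[] Heights { get; }
    public BuiltMesh(float[] heights) => Heights = heights;
}

public static class MeshBuildQueue
{
    // Thread-safe hand-off from worker threads to the main thread.
    private static readonly ConcurrentQueue<BuiltMesh> Ready = new ConcurrentQueue<BuiltMesh>();

    // Runs the CPU-heavy part on a thread-pool thread (no Unity API calls allowed here).
    public static Task BuildOnWorker(int numVertsPerLine)
    {
        return Task.Run(() =>
        {
            var heights = new float[numVertsPerLine * numVertsPerLine];
            for (int i = 0; i < heights.Length; i++)
            {
                heights[i] = i % numVertsPerLine; // placeholder for the real mesh work
            }
            Ready.Enqueue(new BuiltMesh(heights)); // "signal" the main thread
        });
    }

    // Called once per frame on the main thread; returns how many results were consumed.
    public static int DrainOnMainThread(Action<BuiltMesh> onReady)
    {
        int count = 0;
        while (Ready.TryDequeue(out BuiltMesh mesh))
        {
            onReady(mesh); // safe place to build the Mesh or dispatch a compute shader
            count++;
        }
        return count;
    }
}
```

    In a Unity project, DrainOnMainThread would typically be called from a MonoBehaviour's Update(), so the main thread never blocks waiting for the worker.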
     
  3. Neto_Kokku

    Joined: Feb 15, 2018
    Posts: 1,751

    What is happening is that your dispatch is taking longer than the emergency GPU timeout (Windows' Timeout Detection and Recovery, or TDR), so Windows resets the GPU because it thinks it has crashed.

    At a glance, using numthreads(1,1,1) will make your compute shader run 32x slower on NVIDIA GPUs and 64x slower on AMD GPUs, because you are only using one SIMD lane per shader core (NVIDIA has 32 lanes, AMD has 64). For maximum throughput, the total number of threads per group (x * y * z in numthreads) should be a multiple of 64.
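    The sketch below makes the lane arithmetic concrete. It uses the simplified model in this post: a thread group is packed into warps/wavefronts of a fixed width (32 lanes for NVIDIA, 64 for AMD as stated above), and leftover lanes sit idle. GroupOccupancy is a hypothetical helper, not a Unity or driver API.

```csharp
using System;

public static class GroupOccupancy
{
    // Number of warps/wavefronts needed to hold one thread group (ceiling division).
    public static int WavesPerGroup(int threadsPerGroup, int waveWidth)
    {
        return (threadsPerGroup + waveWidth - 1) / waveWidth;
    }

    // Fraction of SIMD lanes doing work for one group declared as numthreads(x, y, z).
    public static double LaneUtilization(int x, int y, int z, int waveWidth)
    {
        int threadsPerGroup = x * y * z;
        int lanes = WavesPerGroup(threadsPerGroup, waveWidth) * waveWidth;
        return (double)threadsPerGroup / lanes;
    }
}
```

    Under this model, numthreads(1,1,1) fills 1 of 32 or 1 of 64 lanes. numthreads(32,1,1) fills a 32-wide warp but only half of a 64-wide wavefront. numthreads(8,8,1), which is 64 threads, fills both completely. That is why a multiple of 64 is the safe choice across vendors.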
     
  4. UserNobody

    Joined: Oct 3, 2019
    Posts: 144

    The first C# code is the old code that I previously ran on the second thread; however, that produced lag, so I moved that code into a compute shader instead, which is the second code block. The last C# code is what I now use to dispatch the compute shader. First I initialize the vertex indices map in a Unity Job (I haven't included that piece of code), and then I dispatch the shader, passing in those vertex indices as well. I've added some extra lines to the last code block. That GenerateMeshData function is called by the Terrain Generator class from the main thread.
     
  5. UserNobody

    Joined: Oct 3, 2019
    Posts: 144

    The reason I use (1,1,1) is that numVertsPerLine, which I use in the dispatch, is a really annoying number: 149. It doesn't divide evenly by anything (149 is prime), so I'm not sure how to do this. And I need exactly numVertsPerLine * numVertsPerLine iterations.
     
  6. Neto_Kokku

    Joined: Feb 15, 2018
    Posts: 1,751

    You change numthreads to (32,1,1) and dispatch 5 groups (Mathf.CeilToInt(149/32.0f)). Then you pass numVertsPerLine to the compute shader as a parameter and add an if at the start of the kernel that returns if the thread ID is >= numVertsPerLine. This makes the excess threads do nothing, so they don't try to read from or write to places they shouldn't.
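    The following CPU-side sketch checks this recipe before you touch the shader. The group count is an integer ceiling division, which gives the same result as Mathf.CeilToInt(149 / 32.0f). Every emulated thread runs the same early-return guard the kernel would. The .0f matters. In the original dispatch call, numVertsPerLine / 8 is integer division, so Mathf.CeilToInt receives 18 and rounds nothing up. Even with 8-thread groups, 18 groups would cover only 144 of the 149 columns.

    The sketch uses the thread ID directly as x, which is what the guard assumes. The kernel in the question instead computes x = numVertsPerLine - id.x, which maps id.x = 0 to 149, one past the last valid index. DispatchEmulator is a hypothetical helper, and covering y with one-thread-high groups is an assumption, since the post only gives the x dimension.

```csharp
using System;

public static class DispatchEmulator
{
    // Integer ceiling division: matches Mathf.CeilToInt(count / (float)groupSize) for positive inputs.
    public static int GroupCount(int count, int groupSize)
    {
        return (count + groupSize - 1) / groupSize;
    }

    // Emulates Dispatch(groupsX, groupsY, 1) with numthreads(threadsX, threadsY, 1)
    // and returns how many times each (x, y) cell was processed.
    public static int[,] Run(int numVertsPerLine, int threadsX, int threadsY)
    {
        int groupsX = GroupCount(numVertsPerLine, threadsX);
        int groupsY = GroupCount(numVertsPerLine, threadsY);
        var visits = new int[numVertsPerLine, numVertsPerLine];

        for (int gy = 0; gy < groupsY; gy++)
        {
            for (int gx = 0; gx < groupsX; gx++)
            {
                for (int ty = 0; ty < threadsY; ty++)
                {
                    for (int tx = 0; tx < threadsX; tx++)
                    {
                        // SV_DispatchThreadID = group ID * group size + thread index within the group.
                        int idX = gx * threadsX + tx;
                        int idY = gy * threadsY + ty;

                        // The early-return guard: excess threads in the last group do nothing.
                        if (idX >= numVertsPerLine || idY >= numVertsPerLine)
                        {
                            continue;
                        }
                        visits[idX, idY]++;
                    }
                }
            }
        }
        return visits;
    }
}
```

    With numthreads(32,1,1), Run(149, 32, 1) corresponds to a Unity call along the lines of baker.Dispatch(kernel, Mathf.CeilToInt(numVertsPerLine / 32.0f), numVertsPerLine, 1). It processes every one of the 149 x 149 cells exactly once.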

    Thread efficiency (the ratio of threads performing useful work) will be 93% on NVIDIA hardware and 77% on AMD, which is much better than your current code, which is only 3.1% efficient on NVIDIA and 1.5% on AMD.
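    These percentages can be reproduced with the same lane model, applied to a whole 149-thread row instead of a single group. The sketch assumes each group is rounded up to whole 32-lane warps or 64-lane wavefronts, and DispatchEfficiency is a hypothetical helper. Under this model, 93% is (32,1,1) groups on 32-wide hardware (149 of 160 lanes). The 3.1% and 1.5% figures are numthreads(1,1,1) on 32- and 64-wide hardware, truncated. The 77% AMD figure appears to correspond to 64-thread groups (149 of 192 lanes). If a 32-thread group occupies a full 64-lane wavefront, (32,1,1) groups on that hardware would come out nearer 47% (149 of 320 lanes).

```csharp
using System;

public static class DispatchEfficiency
{
    // Useful threads divided by SIMD lanes occupied when `count` items are covered
    // by groups of `threadsPerGroup`, on hardware with `waveWidth`-wide warps/wavefronts.
    public static double Efficiency(int count, int threadsPerGroup, int waveWidth)
    {
        int groups = (count + threadsPerGroup - 1) / threadsPerGroup;       // groups dispatched
        int wavesPerGroup = (threadsPerGroup + waveWidth - 1) / waveWidth; // warps per group
        long lanes = (long)groups * wavesPerGroup * waveWidth;             // lanes occupied
        return (double)count / lanes;
    }
}
```

    For example, Efficiency(149, 32, 32) is 149/160 and Efficiency(149, 64, 64) is 149/192. Efficiency(149, 1, 64) is exactly 1/64.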
     
    Last edited: Apr 21, 2021