Search Unity

Question Surface nets compute shader messes up with interpolation

Discussion in 'Scripting' started by Floris0106, Nov 14, 2020.

  1. Floris0106

    Floris0106

    Joined:
    Sep 8, 2018
    Posts:
    3
    I've written a compute shader for the surface nets algorithm, which worked fine until i implemented the interpolation. I've tried changing the interpolation up but I can't seem to pinpoint where it's going wrong. This is the compute shader:
    Code (CSharp):
    1. #pragma kernel NetSurface
    2. #pragma kernel CalcVerts
    3.  
    4. static const uint numThreads = 8;
    5.  
    6. struct Quad {
    7.     float3 vertexD;
    8.     float3 vertexC;
    9.     float3 vertexB;
    10.     float3 vertexA;
    11. };
    12.  
    13. AppendStructuredBuffer<Quad> quads;
    14. StructuredBuffer<float> points;
    15. RWStructuredBuffer<float3> verts;
    16.  
    17. uint pointsPerAxis;
    18. float surface;
    19. float interpolation;
    20.  
    21. float sampleDensity(uint x, uint y, uint z) {
    22.     return points[(z * pointsPerAxis + y) * pointsPerAxis + x];
    23. }
    24. float4 samplePoint(uint x, uint y, uint z) {
    25.     return float4(x, y, z, points[(z * pointsPerAxis + y) * pointsPerAxis + x]);
    26. }
    27. float3 calculateVert(float4 v1, float4 v2) {
    28.     float t = (surface - v1.w) / (v2.w - v1.w);
    29.     float3 interpolated = v1.xyz + t * (v2.xyz - v1.xyz);
    30.     float3 centered = (v1.xyz + v2.xyz) / 2;
    31.     return (1 - interpolation) * centered + interpolation * interpolated;
    32. }
    33. float3 sampleVert(uint3 pos) {
    34.     return verts[(pos.z * (pointsPerAxis - 1) + pos.y) * (pointsPerAxis - 1) + pos.x];
    35.     //return pos;
    36. }
    37.  
    38. static const uint3 corners[8] = {
    39.     uint3(0, 0, 0),
    40.     uint3(1, 0, 0),
    41.     uint3(0, 1, 0),
    42.     uint3(1, 1, 0),
    43.     uint3(0, 0, 1),
    44.     uint3(1, 0, 1),
    45.     uint3(0, 1, 1),
    46.     uint3(1, 1, 1)
    47. };
    48.  
    49. [numthreads(numThreads, numThreads, numThreads)]
    50. void NetSurface(uint3 id : SV_DispatchThreadID)
    51. {
    52.     //if (id.x >= pointsPerAxis - 1 || id.y >= pointsPerAxis - 1 || id.z >= pointsPerAxis - 1 || id.x == 0 || id.y == 0 || id.z == 0) return;
    53.     if (id.x >= pointsPerAxis - 1 || id.y >= pointsPerAxis - 1 || id.z >= pointsPerAxis - 1) return;
    54.  
    55.     float4 localDensities = float4(sampleDensity(id.x, id.y, id.z), sampleDensity(id.x + 1, id.y, id.z), sampleDensity(id.x, id.y + 1, id.z), sampleDensity(id.x, id.y, id.z + 1));
    56.     bool pointA = localDensities.x >= surface;
    57.     bool pointB = localDensities.y >= surface;
    58.     bool pointC = localDensities.z >= surface;
    59.     bool pointD = localDensities.w >= surface;
    60.     bool3 surfaces = bool3(pointA != pointB, pointA != pointC, pointA != pointD);
    61.  
    62.     if (pointA) {
    63.         if (surfaces.x) {
    64.             Quad quad;
    65.             quad.vertexA = sampleVert(id + corners[5]);
    66.             quad.vertexB = sampleVert(id + corners[7]);
    67.             quad.vertexC = sampleVert(id + corners[3]);
    68.             quad.vertexD = sampleVert(id + corners[1]);
    69.             quads.Append(quad);
    70.         }
    71.         if (surfaces.y) {
    72.             Quad quad;
    73.             quad.vertexA = sampleVert(id + corners[3]);
    74.             quad.vertexB = sampleVert(id + corners[7]);
    75.             quad.vertexC = sampleVert(id + corners[6]);
    76.             quad.vertexD = sampleVert(id + corners[2]);
    77.             quads.Append(quad);
    78.         }
    79.         if (surfaces.z) {
    80.             Quad quad;
    81.             quad.vertexA = sampleVert(id + corners[4]);
    82.             quad.vertexB = sampleVert(id + corners[6]);
    83.             quad.vertexC = sampleVert(id + corners[7]);
    84.             quad.vertexD = sampleVert(id + corners[5]);
    85.             quads.Append(quad);
    86.         }
    87.     }
    88.     else {
    89.         if (surfaces.x) {
    90.             Quad quad;
    91.             quad.vertexA = sampleVert(id + corners[1]);
    92.             quad.vertexB = sampleVert(id + corners[3]);
    93.             quad.vertexC = sampleVert(id + corners[7]);
    94.             quad.vertexD = sampleVert(id + corners[5]);
    95.             quads.Append(quad);
    96.         }
    97.         if (surfaces.y) {
    98.             Quad quad;
    99.             quad.vertexA = sampleVert(id + corners[2]);
    100.             quad.vertexB = sampleVert(id + corners[6]);
    101.             quad.vertexC = sampleVert(id + corners[7]);
    102.             quad.vertexD = sampleVert(id + corners[3]);
    103.             quads.Append(quad);
    104.         }
    105.         if (surfaces.z) {
    106.             Quad quad;
    107.             quad.vertexA = sampleVert(id + corners[5]);
    108.             quad.vertexB = sampleVert(id + corners[7]);
    109.             quad.vertexC = sampleVert(id + corners[6]);
    110.             quad.vertexD = sampleVert(id + corners[4]);
    111.             quads.Append(quad);
    112.         }
    113.     }
    114. }
    115.  
    116. [numthreads(numThreads, numThreads, numThreads)]
    117. void CalcVerts(uint3 id : SV_DispatchThreadID)
    118. {
    119.     if (id.x >= pointsPerAxis - 1 || id.y >= pointsPerAxis - 1 || id.z >= pointsPerAxis - 1) return;
    120.  
    121.     float4 localDensities[8] = {
    122.         samplePoint(id.x, id.y, id.z),
    123.         samplePoint(id.x + 1, id.y, id.z),
    124.         samplePoint(id.x, id.y + 1, id.z),
    125.         samplePoint(id.x + 1, id.y + 1, id.z),
    126.         samplePoint(id.x, id.y, id.z + 1),
    127.         samplePoint(id.x + 1, id.y, id.z + 1),
    128.         samplePoint(id.x, id.y + 1, id.z + 1),
    129.         samplePoint(id.x + 1, id.y + 1, id.z + 1)
    130.     };
    131.  
    132.     float3 average = float3(0, 0, 0);
    133.     /*uint surfaceIntersections = 0;
    134.     for (uint i = 0; i < 8; i++) {
    135.         for (uint j = 0; j < 8; j++) {
    136.             if ((localDensities[i].w >= surface) ^ (localDensities[j].w >= surface) && any(localDensities[i] != localDensities[j])) {
    137.                 average += calculateVert(localDensities[i], localDensities[j]);
    138.                 surfaceIntersections++;
    139.             }
    140.         }
    141.     }
    142.     if (surfaceIntersections == 0) {
    143.         verts[(id.z * (pointsPerAxis - 1) + id.y) * (pointsPerAxis - 1) + id.x] = id;
    144.         return;
    145.     }
    146.     verts[(id.z * (pointsPerAxis - 1) + id.y) * (pointsPerAxis - 1) + id.x] = average / surfaceIntersections;*/
    147.  
    148.     float3 edgeVerts[12] = {
    149.         calculateVert(localDensities[0], localDensities[1]),
    150.         calculateVert(localDensities[2], localDensities[3]),
    151.         calculateVert(localDensities[4], localDensities[5]),
    152.         calculateVert(localDensities[6], localDensities[7]),
    153.         calculateVert(localDensities[0], localDensities[2]),
    154.         calculateVert(localDensities[1], localDensities[3]),
    155.         calculateVert(localDensities[4], localDensities[6]),
    156.         calculateVert(localDensities[5], localDensities[7]),
    157.         calculateVert(localDensities[0], localDensities[4]),
    158.         calculateVert(localDensities[1], localDensities[5]),
    159.         calculateVert(localDensities[2], localDensities[6]),
    160.         calculateVert(localDensities[3], localDensities[7])
    161.     };
    162.     bool surfaceIntersections[12] = {
    163.         (localDensities[0].w >= surface) ^ (localDensities[1].w >= surface),
    164.         (localDensities[2].w >= surface) ^ (localDensities[3].w >= surface),
    165.         (localDensities[4].w >= surface) ^ (localDensities[5].w >= surface),
    166.         (localDensities[6].w >= surface) ^ (localDensities[7].w >= surface),
    167.         (localDensities[0].w >= surface) ^ (localDensities[2].w >= surface),
    168.         (localDensities[1].w >= surface) ^ (localDensities[3].w >= surface),
    169.         (localDensities[4].w >= surface) ^ (localDensities[6].w >= surface),
    170.         (localDensities[5].w >= surface) ^ (localDensities[7].w >= surface),
    171.         (localDensities[0].w >= surface) ^ (localDensities[4].w >= surface),
    172.         (localDensities[1].w >= surface) ^ (localDensities[5].w >= surface),
    173.         (localDensities[2].w >= surface) ^ (localDensities[6].w >= surface),
    174.         (localDensities[3].w >= surface) ^ (localDensities[7].w >= surface)
    175.     };
    176.     bool isSurfaceVoxel = false;
    177.     uint intersectionCount = 0;
    178.     for (uint i = 0; i < 12; i++)
    179.         if (surfaceIntersections[i]) {
    180.             average += edgeVerts[i];
    181.             isSurfaceVoxel = true;
    182.             intersectionCount++;
    183.         }
    184.     if (!isSurfaceVoxel) return;
    185.     verts[(id.z * (pointsPerAxis - 1) + id.y) * (pointsPerAxis - 1) + id.x] = average / intersectionCount;
    186. }
    And this is the script that dispatches it:
    Code (CSharp):
    1. public static class SurfaceNets
    2.     {
    3.         static readonly ComputeShader shader = Resources.Load<ComputeShader>("Compute Shaders/SurfaceNets");
    4.         const int threadGroupSize = 8;
    5.  
    6.         public static void BakeMesh(Mesh mesh, float[] terrain, int chunkSize, float surface, float interpolation = 0f, bool deduplication = false)
    7.         {
    8.             int threadsPerAxis = Mathf.CeilToInt(chunkSize / (float)threadGroupSize);
    9.  
    10.             int points = (chunkSize + 1) * (chunkSize + 1) * (chunkSize + 1);
    11.             int maxQuadCount = chunkSize * chunkSize * chunkSize * 3;
    12.  
    13.             ComputeBuffer quadBuffer = new ComputeBuffer(maxQuadCount, sizeof(float) * 12, ComputeBufferType.Append);
    14.             ComputeBuffer pointsBuffer = new ComputeBuffer(points, sizeof(float));
    15.             ComputeBuffer quadCountBuffer = new ComputeBuffer(1, sizeof(int), ComputeBufferType.Raw);
    16.             ComputeBuffer vertsBuffer = new ComputeBuffer(maxQuadCount / 3, sizeof(float) * 3);
    17.  
    18.             CommandBuffer commandBuffer = new CommandBuffer();
    19.             commandBuffer.SetExecutionFlags(CommandBufferExecutionFlags.AsyncCompute);
    20.             commandBuffer.SetComputeBufferParam(shader, 0, "points", pointsBuffer);
    21.             commandBuffer.SetComputeBufferParam(shader, 1, "points", pointsBuffer);
    22.             commandBuffer.SetComputeBufferParam(shader, 0, "quads", quadBuffer);
    23.             commandBuffer.SetComputeBufferParam(shader, 0, "verts", vertsBuffer);
    24.             commandBuffer.SetComputeBufferParam(shader, 1, "verts", vertsBuffer);
    25.             commandBuffer.SetComputeBufferData(pointsBuffer, terrain);
    26.             commandBuffer.SetComputeBufferCounterValue(quadBuffer, 0);
    27.             commandBuffer.SetComputeIntParam(shader, "pointsPerAxis", chunkSize + 1);
    28.             commandBuffer.SetComputeFloatParam(shader, "surface", surface);
    29.             commandBuffer.SetComputeFloatParam(shader, "interpolation", 1);
    30.             commandBuffer.DispatchCompute(shader, 1, threadsPerAxis, threadsPerAxis, threadsPerAxis);
    31.             commandBuffer.DispatchCompute(shader, 0, threadsPerAxis, threadsPerAxis, threadsPerAxis);
    32.             Graphics.ExecuteCommandBufferAsync(commandBuffer, ComputeQueueType.Default);
    33.  
    34.             ComputeBuffer.CopyCount(quadBuffer, quadCountBuffer, 0);
    35.             int[] quadCountArray = { 0 };
    36.             quadCountBuffer.GetData(quadCountArray);
    37.             int quadCount = quadCountArray[0];
    38.             Quad[] quads = new Quad[quadCount];
    39.             quadBuffer.GetData(quads, 0, 0, quadCount);
    40.  
    41.             commandBuffer.Dispose();
    42.             quadBuffer.Release();
    43.             pointsBuffer.Release();
    44.             quadCountBuffer.Release();
    45.  
    46.             List<Vector3> vertices = new List<Vector3>();
    47.             List<int> indices = new List<int>();
    48.             if (deduplication)
    49.             {
    50.                 for (int i = 0; i < quads.Length; i++)
    51.                     for (int j = 0; j < 4; j++)
    52.                         if (vertices.Contains(quads[i][j]))
    53.                             indices.Add(vertices.IndexOf(quads[i][j]));
    54.                         else
    55.                         {
    56.                             vertices.Add(quads[i][j]);
    57.                             indices.Add(vertices.Count - 1);
    58.                         }
    59.  
    60.                 mesh.SetVertices(vertices);
    61.                 mesh.SetIndices(indices, MeshTopology.Quads, 0);
    62.             }
    63.             else
    64.             {
    65.                 List<Vector2> uvs = new List<Vector2>();
    66.  
    67.                 for (int i = 0; i < quads.Length; i++)
    68.                     for (int j = 0; j < 4; j++)
    69.                     {
    70.                         indices.Add(i * 4 + j);
    71.                         vertices.Add(quads[i][j]);
    72.                     }
    73.  
    74.                 mesh.SetUVs(0, uvs);
    75.                 mesh.SetVertices(vertices);
    76.                 mesh.SetIndices(indices, MeshTopology.Quads, 0);
    77.             }
    78.  
    79.             mesh.name = "Terrain";
    80.             mesh.Optimize();
    81.             mesh.RecalculateNormals();
    82.         }
    83.  
    84.         private struct Quad
    85.         {
    86.             public Vector3 a, b, c, d;
    87.  
    88.             public Vector3 this[int i]
    89.             {
    90.                 get
    91.                 {
    92.                     switch (i)
    93.                     {
    94.                         case 0:
    95.                             return a;
    96.                         case 1:
    97.                             return b;
    98.                         case 2:
    99.                             return c;
    100.                         default:
    101.                             return d;
    102.                     }
    103.                 }
    104.             }
    105.  
    106.             public static Vector2 GetUV(int i)
    107.             {
    108.                 switch (i)
    109.                 {
    110.                     case 0:
    111.                         return new Vector2(0, 0);
    112.                     case 1:
    113.                         return new Vector2(0, 1);
    114.                     case 2:
    115.                         return new Vector2(1, 1);
    116.                     default:
    117.                         return new Vector2(1, 0);
    118.                 }
    119.             }
    120.         }
    121.     }
    Thanks in advance for trying to help me.