
Parallel prefix sum ComputeShader

Discussion in 'Shaders' started by cecarlsen, Feb 20, 2018.

  1. cecarlsen

    cecarlsen

    Joined:
    Jun 30, 2006
    Posts:
    858
    I am attempting to implement "Fast Fixed Radius Nearest Neighbours", a method presented by Nvidia in 2013.
    http://on-demand.gputechconf.com/gt...17-fast-fixed-radius-nearest-neighbor-gpu.pdf

    The example pointed to is Fluid v3, which is written in CUDA. It contains an implementation (copyrighted by Nvidia) of a parallel prefix sum algorithm.
    https://github.com/rchoetzlein/fluids3/blob/master/fluids/prefix_sum.cu

    Prefix sum is also called "prefix scan". It takes an array like { 1, 2, 1, 2 } and outputs the cumulative sum from left to right, either exclusive { 0, 1, 3, 4 } or inclusive { 1, 3, 4, 6 }.
    https://en.wikipedia.org/wiki/Prefix_sum
    http://www.umiacs.umd.edu/~ramani/cmsc828e_gpusci/ScanTalk.pdf
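    As a quick CPU reference (illustrative Python, not part of the GPU problem), the two variants for the example array above:

```python
from itertools import accumulate

data = [1, 2, 1, 2]
inclusive = list(accumulate(data))   # running total: [1, 3, 4, 6]
exclusive = [0] + inclusive[:-1]     # shifted right by one: [0, 1, 3, 4]
```

    The hard part is of course doing this in parallel on the GPU, not the sequential definition.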

    It seems like a very basic algorithm for parallel computing, so I am surprised that I can't find a standard implementation for DirectCompute anywhere. I wonder if anyone in the Unity community has already overcome this obstacle.

    EDIT:

    I found an HLSL implementation by Takahiro Harada from 2011. Note that this implementation is limited to 524,288 elements (groupshared size at 2048).
    http://www.heterogeneouscompute.org/?page_id=7
    https://github.com/erwincoumans/exp...ves/AdlPrimitives/Scan/PrefixScanKernels.hlsl

    And this one by Microsoft, limited to 16,384 elements (which could probably be increased by upping the groupshared memory from 128).
    https://github.com/walbourn/directx-sdk-samples/blob/master/AdaptiveTessellationCS40/ScanCS.hlsl

    I've been trying to port both, but no luck so far.
     
    Last edited: Feb 20, 2018
  2. customphase

    customphase

    Joined:
    Aug 19, 2012
    Posts:
    245
    I also stumbled into that and ended up just brute-forcing it with a loop in one thread. It worked fine for me, but I only had 32k buckets; it probably won't scale well beyond that.
     
    cecarlsen likes this.
  3. scrawk

    scrawk

    Joined:
    Nov 22, 2012
    Posts:
    804
    What part is causing you issues exactly?

    The two HLSL shaders look like they should work in Unity with minimal changes.
     
  4. cecarlsen

    cecarlsen

    Joined:
    Jun 30, 2006
    Posts:
    858
    I just got ScanCS by MS working. Takahiro's version kept crashing Unity, and I still don't understand why. The most difficult part was translating how the "shader resource views" (StructuredBuffer) and "unordered access views" (RWStructuredBuffer) were set in this host file:
    https://github.com/walbourn/directx-sdk-samples/blob/master/AdaptiveTessellationCS40/ScanCS.cpp

    That ended up looking like this:

    Code (CSharp):
    // Compute how many thread groups we will need.
    // Note: integer division truncates, so this assumes count is a multiple of threadsPerGroup.
    int threadGroupCount = count / threadsPerGroup;

    // ScanInBucket.
    int scanInBucketKernel = exclusive ? _scanInBucketExclusiveKernel : _scanInBucketInclusiveKernel;
    _computeShader.SetBuffer( scanInBucketKernel, _inputPropId, countBuffer.computeBuffer );
    _computeShader.SetBuffer( scanInBucketKernel, _resultPropId, resultBuffer.computeBuffer );
    _computeShader.Dispatch( scanInBucketKernel, threadGroupCount, 1, 1 );

    // ScanBucketResult.
    _computeShader.SetBuffer( _scanBucketResultKernel, _inputPropId, resultBuffer.computeBuffer );
    _computeShader.SetBuffer( _scanBucketResultKernel, _resultPropId, _auxBuffer.computeBuffer );
    _computeShader.Dispatch( _scanBucketResultKernel, 1, 1, 1 );

    // ScanAddBucketResult.
    _computeShader.SetBuffer( _scanAddBucketResultKernel, _inputPropId, _auxBuffer.computeBuffer );
    _computeShader.SetBuffer( _scanAddBucketResultKernel, _resultPropId, resultBuffer.computeBuffer );
    _computeShader.Dispatch( _scanAddBucketResultKernel, threadGroupCount, 1, 1 );
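    For anyone following along, the effect of those three dispatches can be mimicked on the CPU. This is only an illustrative Python sketch (a small bucket size B stands in for THREADS_PER_GROUP; all names are invented), not the actual host code:

```python
def inclusive_scan(xs):
    out, total = [], 0
    for v in xs:
        total += v
        out.append(total)
    return out

B = 4                       # stand-in for THREADS_PER_GROUP
data = list(range(1, 25))   # 24 elements = 6 full buckets

# 1) ScanInBucket: independent inclusive scan inside each bucket.
result = []
for start in range(0, len(data), B):
    result += inclusive_scan(data[start:start + B])

# 2) ScanBucketResult: scan the bucket totals shifted by one bucket
#    (reading result[g*B - 1]; g == 0 must yield 0).
shifted = [0 if g == 0 else result[g * B - 1] for g in range(len(data) // B)]
aux = inclusive_scan(shifted)

# 3) ScanAddBucketResult: add each bucket's offset back in.
final = [result[i] + aux[i // B] for i in range(len(data))]
assert final == inclusive_scan(data)
```

    The shifted read in step 2 is what makes aux an exclusive scan of the bucket totals, so adding aux[g] to every element of bucket g completes the full scan.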

    To make it an "exclusive scan", I added a reading offset in the first kernel:

    Code (CSharp):
    [numthreads( THREADS_PER_GROUP, 1, 1 )]
    void ScanInBucketExclusive( uint DTid : SV_DispatchThreadID, uint GI : SV_GroupIndex )
    {
        uint x = DTid == 0 ? 0 : _Input[ DTid-1 ];
        CSScan( DTid, GI, x );
    }
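    The same trick works on the CPU: running an inclusive scan over the input shifted right by one (with a zero in front) yields the exclusive scan. A quick Python check:

```python
from itertools import accumulate

data = [1, 2, 1, 2]
shifted = [0] + data[:-1]              # mirrors: x = DTid == 0 ? 0 : _Input[DTid - 1]
exclusive = list(accumulate(shifted))  # [0, 1, 3, 4]
assert exclusive == [0] + list(accumulate(data))[:-1]
```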
    BTW @scrawk, thank you for the shader tutorials you once uploaded, they were a great resource when I was starting to get into all this.
     
    Last edited: Feb 22, 2018
    marcell_h likes this.
  5. cecarlsen

    cecarlsen

    Joined:
    Jun 30, 2006
    Posts:
    858
    I was asked to share the full code, so here goes. Luckily, it was MIT licensed in the first place.

    Code (CSharp):
    /*
        The MIT License (MIT)

        Copyright (c) 2004-2019 Microsoft Corp
        Modified by Carl Emil Carlsen 2018.

        Permission is hereby granted, free of charge, to any person obtaining a copy of this
        software and associated documentation files (the "Software"), to deal in the Software
        without restriction, including without limitation the rights to use, copy, modify,
        merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
        permit persons to whom the Software is furnished to do so, subject to the following
        conditions:

        The above copyright notice and this permission notice shall be included in all copies
        or substantial portions of the Software.

        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
        INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
        PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
        HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
        CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
        OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

        From directx-sdk-samples by Chuck Walbourn:
        https://github.com/walbourn/directx-sdk-samples/blob/master/AdaptiveTessellationCS40/ScanCS.hlsl
    */

    #pragma kernel ScanInBucketInclusive
    #pragma kernel ScanInBucketExclusive
    #pragma kernel ScanBucketResult
    #pragma kernel ScanAddBucketResult

    #define THREADS_PER_GROUP 512 // Ensure that this equals the 'threadsPerGroup' const in the host script.

    StructuredBuffer<uint> _Input;
    RWStructuredBuffer<uint> _Result;

    groupshared uint2 bucket[THREADS_PER_GROUP];

    void CSScan( uint3 DTid, uint GI, uint x )
    {
        // CS40 only supports one groupshared block per shader, so we use the .x and .y
        // components of a uint2 as ping-pong buffers for the down sweep.
        bucket[GI].x = x;
        bucket[GI].y = 0;

        // Up sweep.
        [unroll]
        for( uint stride = 2; stride <= THREADS_PER_GROUP; stride <<= 1 )
        {
            GroupMemoryBarrierWithGroupSync();
            if( (GI & (stride - 1)) == (stride - 1) ) bucket[GI].x += bucket[GI - stride/2].x;
        }

        if( GI == (THREADS_PER_GROUP - 1) ) bucket[GI].x = 0;

        // Down sweep, ping-ponging between .x and .y each pass.
        bool n = true;
        [unroll]
        for( stride = THREADS_PER_GROUP / 2; stride >= 1; stride >>= 1 )
        {
            GroupMemoryBarrierWithGroupSync();

            uint a = stride - 1;
            uint b = stride | a;

            if( n )
            {
                if( (GI & b) == b )
                {
                    bucket[GI].y = bucket[GI-stride].x + bucket[GI].x;
                }
                else if( (GI & a) == a )
                {
                    bucket[GI].y = bucket[GI+stride].x;
                }
                else
                {
                    bucket[GI].y = bucket[GI].x;
                }
            }
            else
            {
                if( (GI & b) == b )
                {
                    bucket[GI].x = bucket[GI-stride].y + bucket[GI].y;
                }
                else if( (GI & a) == a )
                {
                    bucket[GI].x = bucket[GI+stride].y;
                }
                else
                {
                    bucket[GI].x = bucket[GI].y;
                }
            }

            n = !n;
        }

        _Result[DTid.x] = bucket[GI].y + x;
    }

    // Scan within each bucket.
    [numthreads( THREADS_PER_GROUP, 1, 1 )]
    void ScanInBucketInclusive( uint DTid : SV_DispatchThreadID, uint GI : SV_GroupIndex )
    {
        uint x = _Input[DTid];
        CSScan( DTid, GI, x );
    }

    // Scan within each bucket, reading with an offset of one to make the result exclusive.
    [numthreads( THREADS_PER_GROUP, 1, 1 )]
    void ScanInBucketExclusive( uint DTid : SV_DispatchThreadID, uint GI : SV_GroupIndex )
    {
        uint x = DTid == 0 ? 0 : _Input[ DTid-1 ];
        CSScan( DTid, GI, x );
    }

    // Record and scan the sum of each bucket.
    [numthreads( THREADS_PER_GROUP, 1, 1 )]
    void ScanBucketResult( uint DTid : SV_DispatchThreadID, uint GI : SV_GroupIndex )
    {
        uint x = _Input[DTid*THREADS_PER_GROUP - 1];
        CSScan( DTid, GI, x );
    }

    // Add the scanned bucket results to each bucket to get the final result.
    [numthreads( THREADS_PER_GROUP, 1, 1 )]
    void ScanAddBucketResult( uint Gid : SV_GroupID, uint3 DTid : SV_DispatchThreadID )
    {
        _Result[DTid.x] = _Result[DTid.x] + _Input[Gid];
    }
     
    IsobelShasha and Cery_ like this.
  6. Mytino

    Mytino

    Joined:
    Jan 28, 2019
    Posts:
    16
    Edit 2: Never mind! Somehow I thought the C# code you posted earlier in the thread showed alternative ways to use the shader rather than a sequence of dispatches. It works now! And it's super fast :) Thanks again!

    Thank you so much for posting this! I've been using it for my SPH simulation based on the same talk as you. However, I've been wondering. Is this code limited to a compute buffer with THREADS_PER_GROUP elements? GPUs won't allow you to set that higher than 1024, and my grid search can have more than 1024 elements (in the count buffer with one count per grid cell). A 10x10x10 grid is already 1000 elements, and I assume it's usually much bigger. Do I have to split the buffer up? Or is there something I'm missing? It's working perfectly when the element count is less than or equal to THREADS_PER_GROUP.

    Edit: Here's my code for dispatch. Do I have to use any of the other kernels as well? (Btw, I changed the naming of _Input to InputBuffer and _Result to ResultBuffer)
    Code (CSharp):
    kernelIndex = scanCs.FindKernel("ScanInBucketExclusive");
    scanCs.SetBuffer(kernelIndex, "InputBuffer", countBuffer);
    scanCs.SetBuffer(kernelIndex, "ResultBuffer", minIndexInCellBuffer);
    scanCs.Dispatch(kernelIndex, gridThreadGroups.x, gridThreadGroups.y, gridThreadGroups.z);
     
    Last edited: Oct 10, 2019
    cecarlsen likes this.
  7. WamboGer

    WamboGer

    Joined:
    Dec 13, 2015
    Posts:
    2
    I could put this to use as well; that shader code from directx-sdk-samples was a good find, cecarlsen. Thanks for the translation to Unity as well.

    I want to add my findings on the algorithm's limitations and requirements:
    • THREADS_PER_GROUP must be a power of two with an odd exponent.
      • The power-of-two requirement comes from the bitwise operations.
      • The odd-exponent requirement comes from the down-sweep loop: the ping-pong between .x and .y ends with the result in .y (which the shader reads) only after an odd number of passes.
    • Therefore the maximum value for THREADS_PER_GROUP is 2^9 = 512.
      • The hardware limit of 1024 sadly does not work, because 2^10 has an even exponent.
    • The maximum number of elements that can be processed is 512 × 512 = 262,144.
    I hope my findings are correct and helpful.
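    The parity claim can be checked by simulating the CSScan up/down sweep sequentially on the CPU. A Python sketch of the same loops (bx/by stand in for bucket[GI].x/.y; the thread loop replaces the barriers), hypothetical test values:

```python
def cs_scan(data):
    T = len(data)            # T plays the role of THREADS_PER_GROUP
    bx = list(data)          # bucket[gi].x
    by = [0] * T             # bucket[gi].y
    stride = 2
    while stride <= T:       # up sweep
        for gi in range(T):
            if (gi & (stride - 1)) == stride - 1:
                bx[gi] += bx[gi - stride // 2]
        stride <<= 1
    bx[T - 1] = 0
    n = True
    stride = T // 2
    while stride >= 1:       # down sweep, ping-ponging between bx and by
        a, b = stride - 1, stride | (stride - 1)
        src, dst = (bx, by) if n else (by, bx)
        for gi in range(T):
            if (gi & b) == b:
                dst[gi] = src[gi - stride] + src[gi]
            elif (gi & a) == a:
                dst[gi] = src[gi + stride]
            else:
                dst[gi] = src[gi]
        n = not n
        stride >>= 1
    # The shader always reads .y at the end, correct only after an odd number of passes.
    return [by[gi] + data[gi] for gi in range(T)]

def reference(xs):
    out, s = [], 0
    for v in xs:
        s += v
        out.append(s)
    return out

print(cs_scan(list(range(1, 9))) == reference(list(range(1, 9))))  # T = 8 = 2^3: True
print(cs_scan([1, 2, 3, 4]) == reference([1, 2, 3, 4]))            # T = 4 = 2^2: False
```

    The down sweep runs log2(T) passes, so for T = 2^even the final values end up in .x while the shader reads .y, matching the findings above.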
     
    Mytino and cecarlsen like this.
  8. arkano22

    arkano22

    Joined:
    Sep 20, 2012
    Posts:
    1,891
    Won't this:

    Code (CSharp):
    uint x = _Input[DTid*THREADS_PER_GROUP - 1];
    blow up when the thread id is 0, by accessing _Input[-1]? :rolleyes:
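    (Strictly speaking, since DTid is a uint the HLSL subtraction wraps rather than producing a literal -1, but the resulting index is still far out of bounds. A quick Python check emulating 32-bit unsigned arithmetic:)

```python
THREADS_PER_GROUP = 512
DTid = 0
index = (DTid * THREADS_PER_GROUP - 1) & 0xFFFFFFFF  # emulate uint wraparound
print(index)  # 4294967295
```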
     
    rustinlee likes this.
  9. cecarlsen

    cecarlsen

    Joined:
    Jun 30, 2006
    Posts:
    858
    It's been a while since I touched this ... but as far as I remember, on my Nvidia card if you read out of bounds it just returns zero.
     
  10. burningmime

    burningmime

    Joined:
    Jan 25, 2014
    Posts:
    845
    On DirectX 11 it won't; the API guarantees that. But newer APIs (DX12, Metal, Vulkan) don't spec what it will do. Some GPUs, especially mobile ones, have different semantics and might return garbage (whatever was in that memory location).
     
  11. Neto_Kokku

    Neto_Kokku

    Joined:
    Feb 15, 2018
    Posts:
    1,751
    On some consoles that code will happily try to access out-of-bounds memory, and either read whatever value is there or outright crash if it's an unmapped address.

    Heck, you can actually crash Unity itself in DX11 too if you're writing using a bogus index (*raises hand*). Texture accesses usually have extra safeguards against that, but with buffers all bets are off.
     
    arkano22 likes this.
  12. arkano22

    arkano22

    Joined:
    Sep 20, 2012
    Posts:
    1,891
    I can only speak from experience: for me (Metal on macOS, AMD FirePro D300) this prefix sum implementation either crashes (around 30% of the time) or returns incorrect results, since the first intermediate bucket in the algorithm gets random garbage when reading from _Input[-1]; this value then gets added to all entries in the buffer.

    I fixed it by checking for DTid == 0 and returning zero in the ScanBucketResult kernel. Just wanted to leave this here for future visitors! :)
     
    Last edited: Feb 8, 2022
    rustinlee and cecarlsen like this.
  13. Mytino

    Mytino

    Joined:
    Jan 28, 2019
    Posts:
    16
    I started out using the scan/prefix sum from this thread, but then I came across another scan implementation in some code by https://twitter.com/nialltl, which they were kind enough to give me. When I also added unrolling to that one, it was quite a bit faster than the one in this thread in my performance tests.

    Here's the code for it; it's a short one. It might also support 1024 threads, but I'm not sure.
    Code (CSharp):
    void Scan(uint id, uint gi, uint x) {
        bucket[gi] = x;

        [unroll]
        for (uint t = 1; t < THREADS_PER_GROUP; t <<= 1) {
            GroupMemoryBarrierWithGroupSync();
            uint temp = bucket[gi];
            if (gi >= t) temp += bucket[gi - t];
            GroupMemoryBarrierWithGroupSync();
            bucket[gi] = temp;
        }

        OutputBufW[id] = bucket[gi];
    }
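    For reference, this is a Hillis–Steele style scan (Wikipedia's "algorithm 1"). A sequential Python sketch of the same loop, where building a new list plays the role of the two barriers (read everything, then write everything), with made-up test data:

```python
def hillis_steele(data):
    bucket = list(data)
    t = 1
    while t < len(bucket):
        # Each element adds its neighbour t positions to the left, if one exists.
        bucket = [bucket[i] + (bucket[i - t] if i >= t else 0)
                  for i in range(len(bucket))]
        t <<= 1
    return bucket

print(hillis_steele([3, 1, 7, 0, 4, 1, 6, 3]))  # [3, 4, 11, 11, 15, 16, 22, 25]
```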
    Here in a complete compute shader, based on the one I got from this thread:
    Code (CSharp):
    #define THREADS_PER_GROUP 512 // Ensure that this equals the "threadsPerGroup" variables in the host scripts using this.

    StructuredBuffer<uint> InputBufR;
    RWStructuredBuffer<uint> OutputBufW;

    groupshared uint bucket[THREADS_PER_GROUP];

    void Scan(uint id, uint gi, uint x) {
        bucket[gi] = x;

        [unroll]
        for (uint t = 1; t < THREADS_PER_GROUP; t <<= 1) {
            GroupMemoryBarrierWithGroupSync();
            uint temp = bucket[gi];
            if (gi >= t) temp += bucket[gi - t];
            GroupMemoryBarrierWithGroupSync();
            bucket[gi] = temp;
        }

        OutputBufW[id] = bucket[gi];
    }

    // Perform isolated scans within each group.
    #pragma kernel ScanInGroupsInclusive
    [numthreads(THREADS_PER_GROUP, 1, 1)]
    void ScanInGroupsInclusive(uint id : SV_DispatchThreadID, uint gi : SV_GroupIndex) {
        uint x = InputBufR[id];
        Scan(id, gi, x);
    }

    // Perform isolated scans within each group. Shift the input so as to make the final
    // result (obtained after the ScanSums and AddScannedSums calls) exclusive.
    #pragma kernel ScanInGroupsExclusive
    [numthreads(THREADS_PER_GROUP, 1, 1)]
    void ScanInGroupsExclusive(uint id : SV_DispatchThreadID, uint gi : SV_GroupIndex) {
        uint x = (id == 0) ? 0 : InputBufR[id - 1];
        Scan(id, gi, x);
    }

    // Scan the sums of each of the groups (partial sums) from the preceding ScanInGroupsInclusive/Exclusive call.
    #pragma kernel ScanSums
    [numthreads(THREADS_PER_GROUP, 1, 1)]
    void ScanSums(uint id : SV_DispatchThreadID, uint gi : SV_GroupIndex) {
        uint x = (id == 0) ? 0 : InputBufR[id * THREADS_PER_GROUP - 1];
        Scan(id, gi, x);
    }

    // Add the scanned sums to the output of the first kernel call, to get the final, complete prefix sum.
    #pragma kernel AddScannedSums
    [numthreads(THREADS_PER_GROUP, 1, 1)]
    void AddScannedSums(uint id : SV_DispatchThreadID, uint gid : SV_GroupID) {
        OutputBufW[id] += InputBufR[gid];
    }
    If you look at Wikipedia: https://en.wikipedia.org/wiki/Prefix_sum
    They mention two algorithms. I think the original one in this thread is based on algorithm 2, and this other one is based on algorithm 1. I expected algorithm 2 to be faster, since it's "work-efficient", but I didn't manage to make it run faster. Perhaps if you dispatch or profile differently than I did, or use a different GPU, you'll get different results. I'm very interested in hearing how well it runs for you.

    In theory, when choosing between these two algorithms, I think the optimal case is to use algorithm 1 when the input length is small enough to be done "efficiently" with the number of processors you have, and algorithm 2 otherwise, or something like that.
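    To put rough numbers on "work-efficient": per scan of n elements (n a power of two), algorithm 1 performs O(n log n) additions while algorithm 2 performs O(n). A quick Python count for n = 512 (operation counts only, not a performance measurement):

```python
n = 512
# Algorithm 1 (Hillis–Steele): at step t, the n - t highest elements each add once.
algo1_adds = sum(n - t for t in (1 << k for k in range(n.bit_length() - 1)))
# Algorithm 2 (Blelloch): n - 1 adds in the up sweep, n - 1 more in the down sweep.
algo2_adds = 2 * (n - 1)
print(algo1_adds, algo2_adds)  # 4097 1022
```

    On a GPU the extra adds of algorithm 1 can still be cheaper than algorithm 2's extra passes and divergence, which matches the benchmark result described above.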

    Edit: I asked nialltl and they provided me a link to a reddit post where they got some help finalizing the implementation, which is based on a slide from an NVIDIA GDC presentation that they link to in the post. https://old.reddit.com/r/GraphicsPr...dentifying_inconsistent_behaviour_in_compute/
    The GDC presentation, see slide 39: https://www.nvidia.com/content/PDF/GDC2011/Nathan_Hoobler.pdf
     
    Last edited: Feb 11, 2022
    IsobelShasha and Cery_ like this.
  14. burningmime

    burningmime

    Joined:
    Jan 25, 2014
    Posts:
    845
    It's one of those "you shouldn't do it, but DX11 lets you" things. It's undefined on Metal and Vulkan (and on DX12, depending on the type of binding).

    I'm not sure about writing to an OOB index. The DX11 spec guarantees that *reading* from an out-of-bounds index on a UAV will always return 0. I can find a couple of references to it; the exact wording is buried somewhere deep in the spec. Of course, there can be implementation issues too.
     
  15. initialNeil

    initialNeil

    Joined:
    Jan 2, 2019
    Posts:
    1
    Great work! The code from Mytino worked for me. In case anyone wants to try the compute shader, here is the script for calling InclusiveScan.

    - Prepare group buffers as a pyramid (in case there are more than 512*512, or even more than 512*512*512, elements)
    - Scan the groups in a V-loop manner

    The modified compute shader (adds support for N, the total number of elements to scan):
    Code (CSharp):
    // https://forum.unity.com/threads/parallel-prefix-sum-computeshader.518397/#post-7887517

    #define THREADS_PER_GROUP 512 // Ensure that this equals the "threadsPerGroup" variables in the host scripts using this.

    int N;
    StructuredBuffer<uint> InputBufR;
    RWStructuredBuffer<uint> OutputBufW;

    groupshared uint bucket[THREADS_PER_GROUP];

    void Scan(uint id, uint gi, uint x)
    {
        bucket[gi] = x;

        [unroll]
        for (uint t = 1; t < THREADS_PER_GROUP; t <<= 1) {
            GroupMemoryBarrierWithGroupSync();
            uint temp = bucket[gi];
            if (gi >= t) temp += bucket[gi - t];
            GroupMemoryBarrierWithGroupSync();
            bucket[gi] = temp;
        }

        OutputBufW[id] = bucket[gi];
    }

    // Perform isolated scans within each group.
    #pragma kernel ScanInGroupsInclusive
    [numthreads(THREADS_PER_GROUP, 1, 1)]
    void ScanInGroupsInclusive(uint id : SV_DispatchThreadID, uint gi : SV_GroupIndex)
    {
        uint x = 0;
        if ((int)id < N)
            x = InputBufR[id];

        Scan(id, gi, x);
    }

    // Perform isolated scans within each group. Shift the input so as to make the final
    // result (obtained after the ScanSums and AddScannedSums calls) exclusive.
    #pragma kernel ScanInGroupsExclusive
    [numthreads(THREADS_PER_GROUP, 1, 1)]
    void ScanInGroupsExclusive(uint id : SV_DispatchThreadID, uint gi : SV_GroupIndex)
    {
        //uint x = (id == 0) ? 0 : InputBufR[id - 1];

        uint idx = (id - 1);
        uint x = 0;
        if ((int)idx >= 0 && (int)idx < N)
            x = InputBufR[idx];

        Scan(id, gi, x);
    }

    // Scan the sums of each of the groups (partial sums) from the preceding ScanInGroupsInclusive/Exclusive call.
    #pragma kernel ScanSums
    [numthreads(THREADS_PER_GROUP, 1, 1)]
    void ScanSums(uint id : SV_DispatchThreadID, uint gi : SV_GroupIndex)
    {
        //uint x = (id == 0) ? 0 : InputBufR[id * THREADS_PER_GROUP - 1];

        uint idx = (id * THREADS_PER_GROUP - 1);
        uint x = 0;
        if ((int)idx >= 0 && (int)idx < N)
            x = InputBufR[idx];

        Scan(id, gi, x);
    }

    // Add the scanned sums to the output of the first kernel call, to get the final, complete prefix sum.
    #pragma kernel AddScannedSums
    [numthreads(THREADS_PER_GROUP, 1, 1)]
    void AddScannedSums(uint id : SV_DispatchThreadID, uint gid : SV_GroupID)
    {
        if ((int)id < N)
            OutputBufW[id] += InputBufR[gid];
    }
    The script to call the compute shader.
    Code (CSharp):
    // Note: NUM_GROUPS is not defined in this snippet; from its usage it is
    // presumably a ceiling-division helper, e.g. (count + groupSize - 1) / groupSize.
    struct ScanHelper
    {
        const int threadsPerGroup = 512;   // THREADS_PER_GROUP in ScanOperations.compute
        public int size;
        public List<ComputeBuffer> group_buffer;
        public List<int> work_size;

        public void InclusiveScan(int num, ComputeShader scanOperations,
            ComputeBuffer inputs, ComputeBuffer outputs)
        {
            this.RequireBuffer(num);

            // 1. Per group scan
            int kernelScan = scanOperations.FindKernel("ScanInGroupsInclusive");
            scanOperations.SetInt("N", num);
            scanOperations.SetBuffer(kernelScan, "InputBufR", inputs);
            scanOperations.SetBuffer(kernelScan, "OutputBufW", outputs);
            scanOperations.Dispatch(kernelScan, NUM_GROUPS(num, threadsPerGroup), 1, 1);

            if (num < threadsPerGroup)
                return;

            int kernelScanSums = scanOperations.FindKernel("ScanSums");
            int kernelAdd = scanOperations.FindKernel("AddScannedSums");

            // 2. Scan per group sum
            scanOperations.SetInt("N", num);
            scanOperations.SetBuffer(kernelScanSums, "InputBufR", outputs);
            scanOperations.SetBuffer(kernelScanSums, "OutputBufW", this.group_buffer[0]);
            scanOperations.Dispatch(kernelScanSums, NUM_GROUPS(this.work_size[0], threadsPerGroup), 1, 1);

            // Continue down the pyramid
            for (int l = 0; l < this.group_buffer.Count - 1; ++l)
            {
                int work_sz = this.work_size[l];
                // 2. Scan per group sum
                scanOperations.SetInt("N", work_sz);
                scanOperations.SetBuffer(kernelScanSums, "InputBufR", this.group_buffer[l]);
                scanOperations.SetBuffer(kernelScanSums, "OutputBufW", this.group_buffer[l+1]);
                scanOperations.Dispatch(kernelScanSums, NUM_GROUPS(this.work_size[l+1], threadsPerGroup), 1, 1);
            }

            for (int l = this.group_buffer.Count - 1; l > 0; --l)
            {
                int work_sz = this.work_size[l - 1];
                // 3. Add scanned group sum
                scanOperations.SetInt("N", work_sz);
                scanOperations.SetBuffer(kernelAdd, "InputBufR", this.group_buffer[l]);
                scanOperations.SetBuffer(kernelAdd, "OutputBufW", this.group_buffer[l - 1]);
                scanOperations.Dispatch(kernelAdd, NUM_GROUPS(work_sz, threadsPerGroup), 1, 1);
            }

            // 3. Add scanned group sum
            scanOperations.SetInt("N", num);
            scanOperations.SetBuffer(kernelAdd, "InputBufR", this.group_buffer[0]);
            scanOperations.SetBuffer(kernelAdd, "OutputBufW", outputs);
            scanOperations.Dispatch(kernelAdd, this.work_size[0], 1, 1);
        }

        public void RequireBuffer(int alloc_sz)
        {
            if (this.size < alloc_sz)
            {
                this.Release();
                this.size = (int)(alloc_sz * 1.5);
                this.group_buffer = new List<ComputeBuffer>();
                this.work_size = new List<int>();

                int work_sz = this.size;
                while (work_sz > threadsPerGroup)
                {
                    work_sz = NUM_GROUPS(work_sz, threadsPerGroup);
                    this.group_buffer.Add(new ComputeBuffer(work_sz, sizeof(uint)));
                    this.work_size.Add(work_sz);
                }
            }
        }

        public void Release()
        {
            if (group_buffer != null)
            {
                foreach (ComputeBuffer buffer in group_buffer)
                    if (buffer != null)
                        buffer.Dispose();
                group_buffer = null;
            }
        }
    }


    [SerializeField] ComputeShader scanOperations;
    ScanHelper mScanHelper;

    mScanHelper.InclusiveScan(N, scanOperations, inputs, outputs);
    - Make sure THREADS_PER_GROUP and threadsPerGroup are the same. I've tested that both 512 and 1024 work.
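    The pyramid idea generalizes to any input length by recursing on the group sums until a single group remains. An illustrative Python sketch (a small group size G stands in for THREADS_PER_GROUP; not the actual host code):

```python
from itertools import accumulate

G = 4  # small stand-in for THREADS_PER_GROUP

def scan_in_groups(data):
    # Independent inclusive scan inside each group of G elements.
    out = []
    for start in range(0, len(data), G):
        out += list(accumulate(data[start:start + G]))
    return out

def inclusive_scan(data):
    out = scan_in_groups(data)
    n_groups = (len(data) + G - 1) // G     # ceiling division, like NUM_GROUPS
    if n_groups == 1:
        return out
    # Exclusive scan of the per-group totals, obtained by recursing: one pyramid level.
    shifted = [0 if g == 0 else out[g * G - 1] for g in range(n_groups)]
    offsets = inclusive_scan(shifted)
    for i in range(G, len(out)):            # group 0 needs no offset
        out[i] += offsets[i // G]
    return out

data = list(range(1, 101))                  # 100 > G*G, so two pyramid levels are needed
assert inclusive_scan(data) == list(accumulate(data))
```

    Each recursion level divides the problem size by G, which mirrors the group_buffer pyramid and the down-then-up V-loop of dispatches above.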
     
    Last edited: Aug 31, 2023
    IsobelShasha likes this.