Resolved Not getting expected speedups with explicit SIMD?

Discussion in 'Burst' started by vectorized-runner, Apr 25, 2022.

  1. vectorized-runner

    Joined: Jan 22, 2018
    Posts: 396
    Hi, I've just started learning to write explicit SIMD with Burst. In the code below, I'm doing a simple N-spheres vs. 1-sphere collision test and counting the number of collisions. The speedup seems to be only about 10% compared to ordinary (non-SIMD) code, and v128 and v256 perform the same. I don't know why.

    I'm using Burst 1.7.1, on Unity 2020.3.30f1, safety checks and leak detection disabled.

    I've read the generated assembly a bit. I can't judge it well, but I see 'ps' instructions on xmm registers in the v128 version and on ymm registers in the v256 version.

    Am I using the wrong example, or is there another performance bottleneck in my code?

    Also, ExpectVectorized gives a compile error saying the loop isn't vectorized, even in the SIMD case.

    Here's the code:
    Code (CSharp):
    using System;
    using System.Runtime.CompilerServices;
    using Unity.Burst;
    using Unity.Burst.Intrinsics;
    using Unity.Mathematics;
    using UnityEngine;
    using static Unity.Burst.Intrinsics.X86.Sse;
    using static Unity.Burst.Intrinsics.X86.Avx;
    using static Unity.Burst.Intrinsics.X86.Avx2;

    namespace BurstSIMDPractice
    {
        public struct Sphere
        {
            public float3 Position;
            public float Radius;
        }

        // Task: Given a list of spheres and another sphere, compute the index of the first sphere that overlaps with the additional
        // sphere, and the number of intersections in total.
        // (The methods below only compute the total intersection count.)
        [BurstCompile]
        public static unsafe class SphereCollisions
        {
            // Explicit AVX version: processes 8 spheres per iteration from SoA arrays.
            [BurstCompile]
            public static void SphereCollisionSIMD_v256([NoAlias] float* sphereXs, [NoAlias] float* sphereYs,
                                                        [NoAlias] float* sphereZs, [NoAlias] float* sphereRadii,
                                                        [NoAlias] Sphere* testSphere, int sphereCount,
                                                        out int intersectionCount)
            {
                const int v256Size = 8;

                if(sphereCount % v256Size != 0)
                    throw new NotImplementedException($"SphereCount needs to be a multiple of {v256Size}.");

                intersectionCount = 0;

                for(int i = 0; i < sphereCount; i += v256Size)
                {
                    // Unity.Burst.CompilerServices.Loop.ExpectVectorized();

                    // Load 8 consecutive floats from each component array
                    var xs = Reinterpret<v256>(sphereXs + i);
                    var ys = Reinterpret<v256>(sphereYs + i);
                    var zs = Reinterpret<v256>(sphereZs + i);

                    var xDiffs = mm256_sub_ps(xs, new v256(testSphere->Position.x));
                    var yDiffs = mm256_sub_ps(ys, new v256(testSphere->Position.y));
                    var zDiffs = mm256_sub_ps(zs, new v256(testSphere->Position.z));

                    var xDiffSq = mm256_mul_ps(xDiffs, xDiffs);
                    var yDiffSq = mm256_mul_ps(yDiffs, yDiffs);
                    var zDiffSq = mm256_mul_ps(zDiffs, zDiffs);

                    var distSq = mm256_add_ps(xDiffSq, mm256_add_ps(yDiffSq, zDiffSq));

                    var radii = Reinterpret<v256>(sphereRadii + i);
                    var radiiSum = mm256_add_ps(radii, new v256(testSphere->Radius));
                    var radiiSq = mm256_mul_ps(radiiSum, radiiSum);

                    // Compare each 128-bit half with the SSE cmple_ps and recombine into one 256-bit mask
                    var collisionMask = new v256(cmple_ps(distSq.Lo128, radiiSq.Lo128), cmple_ps(distSq.Hi128, radiiSq.Hi128));
                    intersectionCount += GetNonZeroCount(collisionMask);
                }
            }

            // Explicit SSE version: processes 4 spheres per iteration from SoA arrays.
            [BurstCompile]
            public static void SphereCollisionSIMD_v128([NoAlias] float* sphereXs, [NoAlias] float* sphereYs,
                                                        [NoAlias] float* sphereZs, [NoAlias] float* sphereRadii,
                                                        [NoAlias] Sphere* testSphere, int sphereCount,
                                                        out int intersectionCount)
            {
                const int v128Size = 4;

                if(sphereCount % v128Size != 0)
                    throw new NotImplementedException($"SphereCount needs to be a multiple of {v128Size}.");

                intersectionCount = 0;

                for(int i = 0; i < sphereCount; i += v128Size)
                {
                    // Unity.Burst.CompilerServices.Loop.ExpectVectorized();

                    // Load 4 consecutive floats from each component array
                    var xs = Reinterpret<v128>(sphereXs + i);
                    var ys = Reinterpret<v128>(sphereYs + i);
                    var zs = Reinterpret<v128>(sphereZs + i);

                    var xDiffs = sub_ps(xs, new v128(testSphere->Position.x));
                    var yDiffs = sub_ps(ys, new v128(testSphere->Position.y));
                    var zDiffs = sub_ps(zs, new v128(testSphere->Position.z));

                    var xDiffSq = mul_ps(xDiffs, xDiffs);
                    var yDiffSq = mul_ps(yDiffs, yDiffs);
                    var zDiffSq = mul_ps(zDiffs, zDiffs);

                    var distSq = add_ps(xDiffSq, add_ps(yDiffSq, zDiffSq));

                    var radii = Reinterpret<v128>(sphereRadii + i);
                    var radiiSum = add_ps(radii, new v128(testSphere->Radius));
                    var radiiSq = mul_ps(radiiSum, radiiSum);

                    // If distance squared is less than or equal to the squared sum of the radii, then collide.
                    // Note: cmple_ps is <=, while the scalar versions below use strict <, so exact ties are counted differently.
                    var collisionMask = cmple_ps(distSq, radiiSq);
                    intersectionCount += GetNonZeroCount(collisionMask);
                }
            }

            // Reinterprets a bool as 0 or 1 without a branch.
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            static int BoolToInt(bool value)
            {
                return *(byte*)&value;
            }

            // Counts the non-zero 32-bit lanes by extracting each lane to a scalar.
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            static int GetNonZeroCount(v256 v)
            {
                return
                    BoolToInt(v.SInt0 != 0) +
                    BoolToInt(v.SInt1 != 0) +
                    BoolToInt(v.SInt2 != 0) +
                    BoolToInt(v.SInt3 != 0) +
                    BoolToInt(v.SInt4 != 0) +
                    BoolToInt(v.SInt5 != 0) +
                    BoolToInt(v.SInt6 != 0) +
                    BoolToInt(v.SInt7 != 0);
            }

            // Counts the non-zero 32-bit lanes by extracting each lane to a scalar.
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            static int GetNonZeroCount(v128 v)
            {
                return
                    BoolToInt(v.SInt0 != 0) +
                    BoolToInt(v.SInt1 != 0) +
                    BoolToInt(v.SInt2 != 0) +
                    BoolToInt(v.SInt3 != 0);
            }

            // Loads a value of type T from the given address.
            [MethodImpl(MethodImplOptions.AggressiveInlining)]
            static T Reinterpret<T>(void* ptr) where T : unmanaged
            {
                return *(T*)ptr;
            }

            // Scalar version over an AoS array, counting without a branch.
            [BurstCompile]
            public static void SphereCollisionNoBranch([NoAlias] Sphere* sphereToTest, [NoAlias] Sphere* spheres,
                                                       int sphereCount, out int intersectionCount)
            {
                intersectionCount = 0;
                var spherePositionTest = sphereToTest->Position;
                var sphereRadiusTest = sphereToTest->Radius;

                for(int i = 0; i < sphereCount; i++)
                {
                    Unity.Burst.CompilerServices.Loop.ExpectNotVectorized();

                    var spherePositionA = spheres[i].Position;
                    var sphereRadiusA = spheres[i].Radius;
                    var distanceSq = math.distancesq(spherePositionA, spherePositionTest);
                    var radius = sphereRadiusA + sphereRadiusTest;
                    var radiusSq = radius * radius;
                    var inRadius = distanceSq < radiusSq;
                    intersectionCount += BoolToInt(inRadius);
                }
            }

            // Scalar version over an AoS array, counting with a branch.
            [BurstCompile]
            public static void SphereCollisionDefault([NoAlias] Sphere* sphereToTest, [NoAlias] Sphere* spheres,
                                                      int sphereCount, out int intersectionCount)
            {
                intersectionCount = 0;

                var spherePositionTest = sphereToTest->Position;
                var sphereRadiusTest = sphereToTest->Radius;

                for(int i = 0; i < sphereCount; i++)
                {
                    Unity.Burst.CompilerServices.Loop.ExpectNotVectorized();

                    var spherePositionA = spheres[i].Position;
                    var sphereRadiusA = spheres[i].Radius;
                    var distanceSq = math.distancesq(spherePositionA, spherePositionTest);
                    var radius = sphereRadiusA + sphereRadiusTest;
                    var radiusSq = radius * radius;

                    if(distanceSq < radiusSq)
                    {
                        intersectionCount++;
                    }
                }
            }

        }
    }
    Here's the benchmark code:
    Code (CSharp):
    using NUnit.Framework;
    using Unity.Collections;
    using Unity.Collections.LowLevel.Unsafe;
    using Unity.PerformanceTesting;
    using UnityEngine;
    using Random = Unity.Mathematics.Random;

    namespace BurstSIMDPractice.Tests
    {
        public static class Util
        {
            // Returns the NativeArray's backing memory as a typed pointer.
            public static unsafe T* GetUnsafePtrCast<T>(this NativeArray<T> array) where T : unmanaged
            {
                return (T*)array.GetUnsafePtr();
            }
        }

        public static unsafe class SphereCollisionBenchmarks
        {
            const int SphereCount = 100_000;
            const uint Seed = 132438297;
            const int WarmupCount = 10;
            const int IterationCount = 10;
            const int MeasurementCount = 10;

            [Test, Performance]
            public static void SphereCollisionSIMDv256Benchmark()
            {
                Random random;
                NativeArray<float> sphereXs = default;
                NativeArray<float> sphereYs = default;
                NativeArray<float> sphereZs = default;
                NativeArray<float> sphereRadii = default;
                NativeArray<Sphere> testSphere = default;

                Measure.Method(
                           () =>
                           {
                               SphereCollisions.SphereCollisionSIMD_v256(
                                   sphereXs.GetUnsafePtrCast(),
                                   sphereYs.GetUnsafePtrCast(),
                                   sphereZs.GetUnsafePtrCast(),
                                   sphereRadii.GetUnsafePtrCast(),
                                   testSphere.GetUnsafePtrCast(),
                                   SphereCount,
                                   out var intersectionCount);

                               // Note: this Debug.Log runs inside the measured lambda, so its cost is included in the timing.
                               Debug.Log($"SphereCollisionSIMDv256 IntersectionCount: {intersectionCount}");
                           })
                       .WarmupCount(WarmupCount)
                       .IterationsPerMeasurement(IterationCount)
                       .MeasurementCount(MeasurementCount)
                       .SetUp(() =>
                       {
                           // Fill the SoA arrays with seeded random spheres
                           random = new Random(Seed);
                           sphereXs = new NativeArray<float>(SphereCount, Allocator.Temp);
                           sphereYs = new NativeArray<float>(SphereCount, Allocator.Temp);
                           sphereZs = new NativeArray<float>(SphereCount, Allocator.Temp);
                           sphereRadii = new NativeArray<float>(SphereCount, Allocator.Temp);

                           for(int i = 0; i < SphereCount; i++)
                           {
                               var position = random.NextFloat3();
                               var radius = random.NextFloat();

                               sphereXs[i] = position.x;
                               sphereYs[i] = position.y;
                               sphereZs[i] = position.z;
                               sphereRadii[i] = radius;
                           }

                           testSphere = new NativeArray<Sphere>(1, Allocator.Temp);
                           testSphere[0] = new Sphere
                           {
                               Position = random.NextFloat3(),
                               Radius = random.NextFloat(),
                           };
                       })
                       .Run();
            }

            [Test, Performance]
            public static void SphereCollisionSIMDv128Benchmark()
            {
                Random random;
                NativeArray<float> sphereXs = default;
                NativeArray<float> sphereYs = default;
                NativeArray<float> sphereZs = default;
                NativeArray<float> sphereRadii = default;
                NativeArray<Sphere> testSphere = default;

                Measure.Method(
                           () =>
                           {
                               SphereCollisions.SphereCollisionSIMD_v128(
                                   sphereXs.GetUnsafePtrCast(),
                                   sphereYs.GetUnsafePtrCast(),
                                   sphereZs.GetUnsafePtrCast(),
                                   sphereRadii.GetUnsafePtrCast(),
                                   testSphere.GetUnsafePtrCast(),
                                   SphereCount,
                                   out var intersectionCount);

                               // Note: this Debug.Log runs inside the measured lambda, so its cost is included in the timing.
                               Debug.Log($"SphereCollisionSIMDv128 IntersectionCount: {intersectionCount}");
                           })
                       .WarmupCount(WarmupCount)
                       .IterationsPerMeasurement(IterationCount)
                       .MeasurementCount(MeasurementCount)
                       .SetUp(() =>
                       {
                           // Fill the SoA arrays with seeded random spheres
                           random = new Random(Seed);
                           sphereXs = new NativeArray<float>(SphereCount, Allocator.Temp);
                           sphereYs = new NativeArray<float>(SphereCount, Allocator.Temp);
                           sphereZs = new NativeArray<float>(SphereCount, Allocator.Temp);
                           sphereRadii = new NativeArray<float>(SphereCount, Allocator.Temp);

                           for(int i = 0; i < SphereCount; i++)
                           {
                               var position = random.NextFloat3();
                               var radius = random.NextFloat();

                               sphereXs[i] = position.x;
                               sphereYs[i] = position.y;
                               sphereZs[i] = position.z;
                               sphereRadii[i] = radius;
                           }

                           testSphere = new NativeArray<Sphere>(1, Allocator.Temp);
                           testSphere[0] = new Sphere
                           {
                               Position = random.NextFloat3(),
                               Radius = random.NextFloat(),
                           };
                       })
                       .Run();
            }

            [Test, Performance]
            public static void SphereCollisionDefaultBenchmark()
            {
                Random random;
                NativeArray<Sphere> spheres = default;
                NativeArray<Sphere> testSphere = default;

                Measure.Method(
                           () =>
                           {
                               SphereCollisions.SphereCollisionDefault(
                                   testSphere.GetUnsafePtrCast(),
                                   spheres.GetUnsafePtrCast(),
                                   SphereCount,
                                   out var intersectionCount);

                               // Note: this Debug.Log runs inside the measured lambda, so its cost is included in the timing.
                               Debug.Log($"SphereCollisionDefault IntersectionCount: {intersectionCount}");
                           })
                       .WarmupCount(WarmupCount)
                       .IterationsPerMeasurement(IterationCount)
                       .MeasurementCount(MeasurementCount)
                       .SetUp(() =>
                       {
                           // Fill the AoS array with seeded random spheres
                           random = new Random(Seed);
                           spheres = new NativeArray<Sphere>(SphereCount, Allocator.Temp);
                           for(int i = 0; i < spheres.Length; i++)
                           {
                               spheres[i] = new Sphere
                               {
                                   Position = random.NextFloat3(),
                                   Radius = random.NextFloat(),
                               };
                           }

                           testSphere = new NativeArray<Sphere>(1, Allocator.Temp);
                           testSphere[0] = new Sphere
                           {
                               Position = random.NextFloat3(),
                               Radius = random.NextFloat(),
                           };
                       })
                       .Run();
            }

            [Test, Performance]
            public static void SphereCollisionNoBranchBenchmark()
            {
                Random random;
                NativeArray<Sphere> spheres = default;
                NativeArray<Sphere> testSphere = default;

                Measure.Method(
                           () =>
                           {
                               SphereCollisions.SphereCollisionNoBranch(
                                   testSphere.GetUnsafePtrCast(),
                                   spheres.GetUnsafePtrCast(),
                                   SphereCount,
                                   out var intersectionCount);

                               // Note: this Debug.Log runs inside the measured lambda, so its cost is included in the timing.
                               Debug.Log($"SphereCollisionNoBranch IntersectionCount: {intersectionCount}");
                           })
                       .WarmupCount(WarmupCount)
                       .IterationsPerMeasurement(IterationCount)
                       .MeasurementCount(MeasurementCount)
                       .SetUp(() =>
                       {
                           // Fill the AoS array with seeded random spheres
                           random = new Random(Seed);
                           spheres = new NativeArray<Sphere>(SphereCount, Allocator.Temp);
                           for(int i = 0; i < spheres.Length; i++)
                           {
                               spheres[i] = new Sphere
                               {
                                   Position = random.NextFloat3(),
                                   Radius = random.NextFloat(),
                               };
                           }

                           testSphere = new NativeArray<Sphere>(1, Allocator.Temp);
                           testSphere[0] = new Sphere
                           {
                               Position = random.NextFloat3(),
                               Radius = random.NextFloat(),
                           };
                       })
                       .Run();
            }
        }
    }
    Here are the benchmark results:
    [Image: benchmark results screenshot, upload_2022-4-25_22-14-24.png]

    Last edited: Apr 25, 2022
  2. Trindenberg

    Joined: Dec 3, 2017
    Posts: 395
    I also couldn't figure out vectorisation; it takes some time to understand. What I would try here is to put your x, y, z operations in sequence, like xs, xDiffs, xDiffSq, rather than xs, ys, zs, xDiffs, yDiffs...

    I'd be interested to know if that improves anything.
     
  3. vectorized-runner

    Joined: Jan 22, 2018
    Posts: 396
    No, it doesn't make any difference. It would be pretty unexpected for Burst not to be able to figure that out, tbh.
     
  4. DreamingImLatios

    Joined: Jun 3, 2017
    Posts: 4,223
    You might be memory-bound. Also, the way you are summing might be tripping Burst up. You want either a separate count per lane, to which you add the mask & 0x1, or movmsk and popcnt.

    As for ExpectVectorized: that applies to auto-vectorization. Explicit vectorization will cause that function to always return false.
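
    The two counting approaches mentioned above could look roughly like the sketch below. It is not code from this thread: the class and method names are hypothetical, the inputs are two plain float arrays compared lane by lane (standing in for distSq and radiiSq), and count is assumed to be a multiple of 4. It also assumes the target CPU supports SSE2 and POPCNT (X86.Popcnt.IsPopcntSupported can be checked first).

    ```csharp
    using Unity.Burst;
    using Unity.Burst.Intrinsics;
    using static Unity.Burst.Intrinsics.X86;

    [BurstCompile]
    public static unsafe class MaskCountingSketch // hypothetical name
    {
        // Variant A: movmsk + popcnt.
        // movemask_ps packs the sign bit of each of the 4 float lanes into bits 0..3,
        // and popcnt_u32 counts how many of those bits are set.
        [BurstCompile]
        public static int CountLessEqualMoveMask([NoAlias] float* a, [NoAlias] float* b, int count)
        {
            int total = 0;

            for (int i = 0; i < count; i += 4)
            {
                v128 mask = Sse.cmple_ps(Sse.loadu_ps(a + i), Sse.loadu_ps(b + i));
                total += Popcnt.popcnt_u32((uint)Sse.movemask_ps(mask));
            }

            return total;
        }

        // Variant B: one counter per lane.
        // A true lane in the mask is all ones (0xFFFFFFFF); AND-ing with 1 turns it
        // into the integer 1, which is added to that lane's running count.
        // The four lane counters are summed once, after the loop.
        [BurstCompile]
        public static int CountLessEqualPerLane([NoAlias] float* a, [NoAlias] float* b, int count)
        {
            v128 counts = new v128(0);
            v128 one = new v128(1);

            for (int i = 0; i < count; i += 4)
            {
                v128 mask = Sse.cmple_ps(Sse.loadu_ps(a + i), Sse.loadu_ps(b + i));
                counts = Sse2.add_epi32(counts, Sse2.and_si128(mask, one));
            }

            return counts.SInt0 + counts.SInt1 + counts.SInt2 + counts.SInt3;
        }
    }
    ```

    Both variants keep the comparison result in the SIMD register instead of extracting each lane to a scalar, as GetNonZeroCount does in the original post. Whether this helps when the loop is memory-bound would need to be measured.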
     
  5. vectorized-runner

    Joined: Jan 22, 2018
    Posts: 396
    Can you show a simple example for this? I couldn't figure out how to sum using the SIMD registers.

    Also, how could I make this problem less memory-bound?

    Cool, I didn't know that.
     
  6. vectorized-runner

    Joined: Jan 22, 2018
    Posts: 396
    Ok, so I've tested on a better machine and the results are very different: v128 is 3.5x faster and v256 is 4x faster.

    So I guess it depends on the PC; if it can utilize the registers, then performance will be better.