BarnacleJunior

D3D11 cs_5_0 radix sort performance - how to improve?

I'm reposting this from an XNA forum post I made.  I'm trying to get a respectable radix sort, but am seeing numbers well below what I think the hardware is capable of.

 

I benchmarked my own cs_5_0 radix sort to death this weekend.  It is quite difficult to write that function, as performance is extremely variable with the choice of the number of threads per threadgroup, values per thread, optimal size of the sort digit in bits, etc.  The Garland paper http://mgarland.org/files/papers/gpusort-ipdps09.pdf reports a throughput of 140 million pairs/sec on a GTX 280 with arrays of 4 million elements.  The best I was able to do on my HD 5850, which is a much faster card (although I admit I don't understand how its integer performance compares), is 64 million pairs/sec.  Garland was able to get that same throughput even on an old 8800 Ultra.  What numbers have you MS guys gotten on this important function?  I really don't know how to bring the numbers up.  I'm unsure how to resolve local data share (LDS) bank conflicts on Cypress (if that is the culprit).  I've tried perfect hashes (e.g. ((31 & i) << 5) | (i >> 5) on a 10-bit index) but it doesn't do anything.  Are there any profiling tools to help me out, or a reference cs_5_0 sort I could get my hands on?  I can clean up my own prefix sum and radix sort benchmark code a bit and distribute it here if that would be helpful.
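For reference, here is a minimal sketch of the bit-rotation "perfect hash" mentioned above, together with the usual alternative of padding the index, assuming a 32-bank, dword-wide LDS. The kernel name and sizes are made up purely for illustration.

// Illustration only: two ways to remap a groupshared index so that strided
// accesses spread across the LDS banks (assuming 32 banks, each one dword wide).
#define LDS_SIZE 1024

// (a) Padding: one extra dword per 32 entries, so offsets that are a multiple
//     of 32 apart no longer map to the same bank.
#define PADDED(i)   ((i) + ((i) >> 5))
groupshared uint paddedSum[PADDED(LDS_SIZE)];

// (b) Bit rotation ("perfect hash") of a 10-bit index: the bank-selecting low
//     5 bits of the remapped index come from the high bits of the original one.
#define PERMUTED(i) ((((i) & 31) << 5) | ((i) >> 5))
groupshared uint permutedSum[LDS_SIZE];

[numthreads(64, 1, 1)]
void IllustrateBankRemap(uint tid : SV_GroupIndex) {
    // The classic conflict case: threads 0..31 read offsets 0, 32, 64, ...,
    // which without remapping would all land in bank 0.
    uint strided = (32 * tid) & (LDS_SIZE - 1);

    // With padding, offset 32*k becomes 33*k, which lives in bank k.
    uint a = paddedSum[PADDED(strided)];

    // With the rotation, offset 32*k becomes k, again bank k.
    uint b = permutedSum[PERMUTED(strided)];

    // Consume the (uninitialized, meaningless) values so the loads survive.
    paddedSum[PADDED(tid)] = a + b;
}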

Anyway, here are the benchmarks I ran last night.  Each sort was run 200 times on arrays of 1<<22 elements.  The elements are uint2s, which I've filled with random numbers, and I'm ping-ponging between sorting the .x and .y components.  The peak throughput is for 5-bit digits, 64 threads, and 8 values per thread.  I'm really at a loss as to why 16 values per thread isn't better in almost every case, since it increases the amount of sequential work per thread without increasing the number of barriers or anything else.

.sean


digitSize=4. numThreads=16. valuesPerThread=1. Sorted 10.194M words/sec.
digitSize=4. numThreads=16. valuesPerThread=2. Sorted 17.724M words/sec.
digitSize=4. numThreads=16. valuesPerThread=4. Sorted 27.289M words/sec.
digitSize=4. numThreads=16. valuesPerThread=8. Sorted 36.673M words/sec.
digitSize=4. numThreads=16. valuesPerThread=16. Sorted 35.438M words/sec.
digitSize=4. numThreads=32. valuesPerThread=1. Sorted 16.690M words/sec.
digitSize=4. numThreads=32. valuesPerThread=2. Sorted 27.379M words/sec.
digitSize=4. numThreads=32. valuesPerThread=4. Sorted 40.724M words/sec.
digitSize=4. numThreads=32. valuesPerThread=8. Sorted 49.755M words/sec.
digitSize=4. numThreads=32. valuesPerThread=16. Sorted 45.398M words/sec.
digitSize=4. numThreads=64. valuesPerThread=1. Sorted 26.610M words/sec.
digitSize=4. numThreads=64. valuesPerThread=2. Sorted 42.152M words/sec.
digitSize=4. numThreads=64. valuesPerThread=4. Sorted 54.474M words/sec.
digitSize=4. numThreads=64. valuesPerThread=8. Sorted 57.046M words/sec.
digitSize=4. numThreads=64. valuesPerThread=16. Sorted 48.454M words/sec.
digitSize=4. numThreads=128. valuesPerThread=1. Sorted 26.931M words/sec.
digitSize=4. numThreads=128. valuesPerThread=2. Sorted 36.905M words/sec.
digitSize=4. numThreads=128. valuesPerThread=4. Sorted 42.631M words/sec.
digitSize=4. numThreads=128. valuesPerThread=8. Sorted 49.923M words/sec.
digitSize=4. numThreads=128. valuesPerThread=16. Sorted 38.714M words/sec.
digitSize=4. numThreads=256. valuesPerThread=1. Sorted 16.066M words/sec.
digitSize=4. numThreads=256. valuesPerThread=2. Sorted 27.004M words/sec.
digitSize=4. numThreads=256. valuesPerThread=4. Sorted 39.192M words/sec.
digitSize=4. numThreads=256. valuesPerThread=8. Sorted 45.952M words/sec.
digitSize=4. numThreads=512. valuesPerThread=1. Sorted 8.868M words/sec.
digitSize=4. numThreads=512. valuesPerThread=2. Sorted 14.117M words/sec.
digitSize=4. numThreads=512. valuesPerThread=4. Sorted 26.855M words/sec.
digitSize=4. numThreads=1024. valuesPerThread=1. Sorted 7.084M words/sec.
digitSize=4. numThreads=1024. valuesPerThread=2. Sorted 9.924M words/sec.
digitSize=5. numThreads=32. valuesPerThread=1. Sorted 9.296M words/sec.
digitSize=5. numThreads=32. valuesPerThread=2. Sorted 17.073M words/sec.
digitSize=5. numThreads=32. valuesPerThread=4. Sorted 25.587M words/sec.
digitSize=5. numThreads=32. valuesPerThread=8. Sorted 52.341M words/sec.
digitSize=5. numThreads=32. valuesPerThread=16. Sorted 47.432M words/sec.
digitSize=5. numThreads=64. valuesPerThread=1. Sorted 16.654M words/sec.
digitSize=5. numThreads=64. valuesPerThread=2. Sorted 26.697M words/sec.
digitSize=5. numThreads=64. valuesPerThread=4. Sorted 62.983M words/sec.
digitSize=5. numThreads=64. valuesPerThread=8. Sorted 64.209M words/sec.
digitSize=5. numThreads=64. valuesPerThread=16. Sorted 55.197M words/sec.
digitSize=5. numThreads=128. valuesPerThread=1. Sorted 21.765M words/sec.
digitSize=5. numThreads=128. valuesPerThread=2. Sorted 40.899M words/sec.
digitSize=5. numThreads=128. valuesPerThread=4. Sorted 46.560M words/sec.
digitSize=5. numThreads=128. valuesPerThread=8. Sorted 56.489M words/sec.
digitSize=5. numThreads=128. valuesPerThread=16. Sorted 43.941M words/sec.
digitSize=5. numThreads=256. valuesPerThread=1. Sorted 16.975M words/sec.
digitSize=5. numThreads=256. valuesPerThread=2. Sorted 28.706M words/sec.
digitSize=5. numThreads=256. valuesPerThread=4. Sorted 43.195M words/sec.
digitSize=5. numThreads=256. valuesPerThread=8. Sorted 52.469M words/sec.
digitSize=5. numThreads=512. valuesPerThread=1. Sorted 9.505M words/sec.
digitSize=5. numThreads=512. valuesPerThread=2. Sorted 15.017M words/sec.
digitSize=5. numThreads=512. valuesPerThread=4. Sorted 29.752M words/sec.
digitSize=5. numThreads=1024. valuesPerThread=1. Sorted 7.518M words/sec.
digitSize=5. numThreads=1024. valuesPerThread=2. Sorted 10.778M words/sec.
digitSize=6. numThreads=64. valuesPerThread=1. Sorted 8.046M words/sec.
digitSize=6. numThreads=64. valuesPerThread=2. Sorted 16.634M words/sec.
digitSize=6. numThreads=64. valuesPerThread=4. Sorted 31.434M words/sec.
digitSize=6. numThreads=64. valuesPerThread=8. Sorted 45.359M words/sec.
digitSize=6. numThreads=64. valuesPerThread=16. Sorted 53.156M words/sec.
digitSize=6. numThreads=128. valuesPerThread=1. Sorted 16.062M words/sec.
digitSize=6. numThreads=128. valuesPerThread=2. Sorted 24.877M words/sec.
digitSize=6. numThreads=128. valuesPerThread=4. Sorted 46.939M words/sec.
digitSize=6. numThreads=128. valuesPerThread=8. Sorted 56.465M words/sec.
digitSize=6. numThreads=128. valuesPerThread=16. Sorted 45.396M words/sec.
digitSize=6. numThreads=256. valuesPerThread=1. Sorted 17.468M words/sec.
digitSize=6. numThreads=256. valuesPerThread=2. Sorted 28.950M words/sec.
digitSize=6. numThreads=256. valuesPerThread=4. Sorted 43.296M words/sec.
digitSize=6. numThreads=256. valuesPerThread=8. Sorted 53.728M words/sec.
digitSize=6. numThreads=512. valuesPerThread=1. Sorted 9.217M words/sec.
digitSize=6. numThreads=512. valuesPerThread=2. Sorted 14.310M words/sec.
digitSize=6. numThreads=512. valuesPerThread=4. Sorted 30.464M words/sec.
digitSize=6. numThreads=1024. valuesPerThread=1. Sorted 7.987M words/sec.
digitSize=6. numThreads=1024. valuesPerThread=2. Sorted 11.530M words/sec.


BarnacleJunior,
For dealing with bank conflicts, make sure no threads that execute in the same cycle access LDS addresses that are a multiple of 32 dwords apart, since the LDS on the 5xxx series has 32 channels, each a single dword wide. For example, given a local (groupshared) int test[256], if thread 0 accesses test[0] and thread 1 accesses test[32], you have a bank conflict. Please read this document for information on how to optimize for this series of cards.
http://developer.amd.com/gpu/A..._Performance_Notes.pdf
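A minimal sketch of that example in HLSL, assuming the 32-channel, dword-wide LDS layout described above (the kernel name is made up):

groupshared int test[256];

[numthreads(64, 1, 1)]
void BankConflictExample(uint tid : SV_GroupIndex) {
    // Conflicting: thread 0 reads test[0], thread 1 reads test[32], and so on.
    // All of these offsets map to channel 0, so the reads serialize.
    int conflicting = test[(32 * tid) & 255];

    // Conflict-free: unit stride maps thread k to channel (k & 31), so all 32
    // channels are used in the same cycle.
    int conflictFree = test[tid];

    // Write something back so the compiler keeps the loads.
    test[tid] = conflicting + conflictFree;
}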

OK... but I'm also under the impression that if all the threads hit the same LDS offset, there is no bank conflict? If not, that might explain the poor performance, but it would also make parallel programming super difficult.


If the reads go to the same address in the same bank, the read is broadcast; otherwise it is a conflict and the accesses waterfall, which can cause a performance hit. Basically, to get peak performance you want to use all 32 channels in the same cycle and not have two threads access the same channel.
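For illustration, a small sketch of the broadcast versus waterfall cases described above (names are made up, and the same 32-channel LDS is assumed):

groupshared uint counts[256];

[numthreads(256, 1, 1)]
void BroadcastExample(uint tid : SV_GroupIndex) {
    // Same address for every thread: the value is broadcast, no conflict.
    uint total = counts[0];

    // Same channel (0) but different addresses: the reads waterfall.
    uint mine = counts[(32 * tid) & 255];

    counts[tid] = total + mine;
}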
eduardoschardong

Barnacle, the raw integer performance of the 5850 is much higher than the GTX 280's, but that's likely not the bottleneck in this algorithm. It's hard to say without looking at the code, but being a radix sort, I believe your code has a lot of loads and stores, which is an area where the 5850 is not very impressive. You may want to look at other sorting algorithms that could perform better on this chip; bitonic sort may be a good start.
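For reference, here is a minimal cs_5_0 sketch of a bitonic sort that sorts one tile per threadgroup entirely in LDS, as a point of comparison with the radix approach. This is not code from the thread: the names, SORT_SIZE, and register bindings are made up, and nothing here is tuned for Cypress.

#define SORT_SIZE 512

StructuredBuffer<uint>   bitonicKeysIn  : register(t0);
RWStructuredBuffer<uint> bitonicKeysOut : register(u0);

groupshared uint bitonicLds[SORT_SIZE];

[numthreads(SORT_SIZE, 1, 1)]
void BitonicSortBlock(uint tid : SV_GroupIndex, uint3 groupID : SV_GroupID) {
    bitonicLds[tid] = bitonicKeysIn[SORT_SIZE * groupID.x + tid];
    GroupMemoryBarrierWithGroupSync();

    // Standard bitonic network: for each subsequence length k, compare-exchange
    // elements j apart, alternating sort direction according to bit k of tid.
    for(uint k = 2; k <= SORT_SIZE; k <<= 1) {
        for(uint j = k >> 1; j > 0; j >>= 1) {
            uint a = bitonicLds[tid];
            uint b = bitonicLds[tid ^ j];
            bool ascending = (0 == (tid & k));
            bool isLower   = (0 == (tid & j));
            // The lower element of each pair keeps the min in an ascending run
            // and the max in a descending one; the upper element takes the other.
            uint result = (ascending == isLower) ? min(a, b) : max(a, b);
            GroupMemoryBarrierWithGroupSync();
            bitonicLds[tid] = result;
            GroupMemoryBarrierWithGroupSync();
        }
    }

    bitonicKeysOut[SORT_SIZE * groupID.x + tid] = bitonicLds[tid];
}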

 


Thanks eduardo.  I'm attaching the shader code.  There is a lot of macro junk I put in for benchmarking.  The optimal path seems to be VALUES_PER_THREAD = 8 (so each thread manages two uint4) and NUM_LEVELS=6 (64 threads) with a digit size of 5 bits. For sorting key/value pairs, SortBufferType just boils down to a struct with two uints, which I store in groupshared memory.

This is my best understanding of the Garland/Harris CUDA radix sort.  The C++ code that drives this just uses a prefix sum to compute bucket offsets for pass 2.

I don't know where the bottleneck is, because almost all the work is done on groupshared memory and with the integer ALUs.  Global memory is only being hit at the beginning and end of the shader.
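For reference, a minimal sketch of that middle step: an exclusive prefix sum over the per-group bucket counts that pass 1 writes out column-major (prefixSum_pass1[digit * numGroups + group]). Because of that layout, one linear scan of the whole buffer gives each (digit, group) cell the global offset where pass 2 must place that group's keys with that digit. This is not the actual driver code: the kernel name, SCAN_THREADS, and the single-threadgroup assumption are made up, and a real driver would need a multi-block scan when numGroups * NUM_BUCKETS is larger.

#define SCAN_THREADS 512

StructuredBuffer<uint>   bucketCounts  : register(t0);  // what pass 1 wrote
RWStructuredBuffer<uint> bucketTargets : register(u0);  // what pass 2 reads

cbuffer cb1 { uint totalCounts; };                      // numGroups * NUM_BUCKETS

groupshared uint scanLds[SCAN_THREADS];

[numthreads(SCAN_THREADS, 1, 1)]
void ScanBucketCounts(uint tid : SV_GroupIndex) {
    uint own = (tid < totalCounts) ? bucketCounts[tid] : 0;
    scanLds[tid] = own;
    GroupMemoryBarrierWithGroupSync();

    // Simple (work-inefficient) Hillis-Steele inclusive scan; fine for a sketch.
    for(uint offset = 1; offset < SCAN_THREADS; offset <<= 1) {
        uint x = (tid >= offset) ? scanLds[tid - offset] : 0;
        GroupMemoryBarrierWithGroupSync();
        scanLds[tid] += x;
        GroupMemoryBarrierWithGroupSync();
    }

    // Exclusive result: inclusive sum minus this cell's own count.
    if(tid < totalCounts)
        bucketTargets[tid] = scanLds[tid] - own;
}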

.sean

-- scancommon.hlsl


// Defines no shaders, only support functions.
// Including file must define VALUES_PER_THREAD and SCAN_SIZE

uint2 Inclusive2Sum(uint2 vec) {
    vec.y += vec.x;
    return vec;
}

uint4 Inclusive4Sum(uint4 vec) {
    vec.y += vec.x;
    vec.z += vec.y;
    vec.w += vec.z;
    return vec;
}

#if VALUES_PER_THREAD == 1
    #define NUM_COUNTERS 1
    #define SEQUENCE uint(0)
    typedef uint1 Counter[1];
    void InclusiveSum(Counter word, out Counter inclusive) {
        inclusive[0] = word[0];
    }
    uint HorizontalSum(Counter word) {
        return word[0].x;
    }
#elif VALUES_PER_THREAD == 2
    #define NUM_COUNTERS 1
    #define SEQUENCE uint2(0, 1)
    typedef uint2 Counter[1];
    void InclusiveSum(Counter word, out Counter inclusive) {
        inclusive[0] = Inclusive2Sum(word[0]);
    }
    uint HorizontalSum(Counter word) {
        return word[0].x + word[0].y;
    }
#elif VALUES_PER_THREAD == 4
    #define NUM_COUNTERS 1
    #define SEQUENCE uint4(0, 1, 2, 3)
    typedef uint4 Counter[1];
    void InclusiveSum(Counter word, out Counter inclusive) {
        inclusive[0] = Inclusive4Sum(word[0]);
    }
    uint HorizontalSum(Counter word) {
        uint2 xy = word[0].xy + word[0].zw;
        return xy.x + xy.y;
    }
#elif VALUES_PER_THREAD == 8
    #define NUM_COUNTERS 2
    #define SEQUENCE uint4(0, 1, 2, 3)
    typedef uint4 Counter[2];
    void InclusiveSum(Counter word, out Counter inclusive) {
        inclusive[0] = Inclusive4Sum(word[0]);
        inclusive[1] = Inclusive4Sum(word[1]) + inclusive[0].w;
    }
    uint HorizontalSum(Counter word) {
        uint4 sum = word[0] + word[1];
        uint2 xy = sum.xy + sum.zw;
        return xy.x + xy.y;
    }
#elif VALUES_PER_THREAD == 16
    #define NUM_COUNTERS 4
    #define SEQUENCE uint4(0, 1, 2, 3)
    typedef uint4 Counter[4];
    void InclusiveSum(Counter word, out Counter inclusive) {
        inclusive[0] = Inclusive4Sum(word[0]);
        inclusive[1] = Inclusive4Sum(word[1]) + inclusive[0].w;
        inclusive[2] = Inclusive4Sum(word[2]) + inclusive[1].w;
        inclusive[3] = Inclusive4Sum(word[3]) + inclusive[2].w;
    }
    uint HorizontalSum(Counter word) {
        uint4 sum = word[0] + word[1] + word[2] + word[3];
        uint2 xy = sum.xy + sum.zw;
        return xy.x + xy.y;
    }
#endif

uint Last(Counter word) {
    return word[NUM_COUNTERS - 1][3 & (VALUES_PER_THREAD - 1)];
}


///////////////////////////////////////////////////////////////////////////////////////////////////


groupshared uint sharedSum[BANK_ADDRESS(SCAN_SIZE)];

void ThreadSum(uint tid) {
    uint tid2 = BANK_ADDRESS(tid);

    [unroll]
    for(uint d = 0; d < NUM_LEVELS - 1; ++d) {
        GroupMemoryBarrierWithGroupSync();
        uint mask = (2<< d) - 1;
        uint offset = 1<< d;
        if(mask == (mask & tid))
            sharedSum[tid2] += sharedSum[BANK_ADDRESS(tid - offset)];
    }
    GroupMemoryBarrierWithGroupSync();
   
    if(0 == tid) {
        uint ai = BANK_ADDRESS(SCAN_SIZE / 2 - 1);
        uint bi = BANK_ADDRESS(SCAN_SIZE - 1);
       
        uint at = sharedSum[ai];
       
        sharedSum[ai] += sharedSum[bi];
        sharedSum[bi] += at + at;
    }
   
    [unroll]
    for(d = NUM_LEVELS - 1; d; --d) {
        GroupMemoryBarrierWithGroupSync();
        uint mask = (1<< d) - 1;
        uint offset = 1<< (d - 1);
        if(mask == (mask & tid)) {
            uint t = sharedSum[tid2];
            uint r = BANK_ADDRESS(tid - offset);
            sharedSum[tid2] += sharedSum[r];
            sharedSum[r] = t;
        }
    }
    GroupMemoryBarrierWithGroupSync();
}

-- radixsort.hlsl

#define NUM_THREADS (1<< NUM_LEVELS)
#define NUM_VALUES (VALUES_PER_THREAD * NUM_THREADS)
#define NUM_BUCKETS (1<< DIGIT_SIZE)

#define SCAN_SIZE NUM_THREADS
#define BANK_ADDRESS(x) (x + (x>> 5))

#include "scancommon.hlsl"

// Performs an LSB radix sort over DIGIT_SIZE bits, starting from the shift bit.

cbuffer cb0 {
    uint numElements;            // total number of elements to sort
    uint numGroups;                // defines the spacing between buckets
    uint shift;                    // LSB of key to sort
    uint valuesPerGroup;
};


// key types:
// SORT_KEY_UINT
// SORT_KEY_UINT2_X
// SORT_KEY_UINT2_Y

// value types:
// SORT_VALUE_NONE
// SORT_VALUE_UINT
// SORT_VALUE_UINT_PAIR
// SORT_VALUE_UINT2


struct SortBufferKey {
    uint key;
#ifdef SORT_KEY_UINT2_X
    uint keyValue;
#elif defined(SORT_KEY_UINT2_Y)
    uint keyValue;
#endif
};

struct SortBufferValue {
#ifdef SORT_VALUE_UINT
    uint value;
#elif defined(SORT_VALUE_UINT_PAIR)
    uint2 value;
#elif defined(SORT_VALUE_UINT2)
    uint2 value;
#endif
};

#ifdef SORT_KEY_UINT

StructuredBuffer<uint> sortKeys_pass1 : register(t0);
RWStructuredBuffer<uint> sortKeys_pass2 : register(u0);

void GatherKey(uint index, out SortBufferKey key) {
    key.key = sortKeys_pass1[index];
}
void ScatterKey(uint index, SortBufferKey key) {
    sortKeys_pass2[index] = key.key;
}

#elif defined(SORT_KEY_UINT2_X)

StructuredBuffer<uint2> sortKeys_pass1 : register(t0);
RWStructuredBuffer<uint2> sortKeys_pass2 : register(u0);

void GatherKey(uint index, out SortBufferKey key) {
    key.key = sortKeys_pass1[index].x;
    key.keyValue = sortKeys_pass1[index].y;
}
void ScatterKey(uint index, SortBufferKey key) {
    sortKeys_pass2[index].x = key.key;
    sortKeys_pass2[index].y = key.keyValue;    
}

#elif defined(SORT_KEY_UINT2_Y)

StructuredBuffer<uint2> sortKeys_pass1 : register(t0);
RWStructuredBuffer<uint2> sortKeys_pass2 : register(u0);

void GatherKey(uint index, out SortBufferKey key) {
    key.key = sortKeys_pass1[index].y;
    key.keyValue = sortKeys_pass1[index].x;
}
void ScatterKey(uint index, SortBufferKey key) {
    sortKeys_pass2[index].y = key.key;
    sortKeys_pass2[index].x = key.keyValue;
}

#endif


#ifdef SORT_VALUE_NONE

#elif defined(SORT_VALUE_UINT)

StructuredBuffer<uint> sortValues_pass1 : register(t1);
RWStructuredBuffer<uint> sortValues_pass2 : register(u1);

void GatherValue(uint index, out SortBufferValue value) {
    value.value = sortValues_pass1[index];
}
void ScatterValue(uint index, SortBufferValue value) {
    sortValues_pass2[index] = value.value;
}

#elif defined(SORT_VALUE_UINT2)

StructuredBuffer<uint2> sortValues_pass1 : register(t1);
RWStructuredBuffer<uint2> sortValues_pass2 : register(u1);

void GatherValue(uint index, inout SortBufferValue value) {
    value.value = sortValues_pass1[index];
}
void ScatterValue(uint index, SortBufferValue value) {
    sortValues_pass2[index] = value.value;
}

#elif defined(SORT_VALUE_UINT_PAIR)

StructuredBuffer<uint> sortValues1_pass1 : register(t1);
StructuredBuffer<uint> sortValues2_pass1 : register(t2);
RWStructuredBuffer<uint> sortValues1_pass2 : register(u1);
RWStructuredBuffer<uint> sortValues2_pass2 : register(u2);

void GatherValue(uint index, inout SortBufferValue value) {
    value.value.x = sortValues1_pass1[index];
    value.value.y = sortValues2_pass1[index];
}
void ScatterValue(uint index, SortBufferValue value) {
    sortValues1_pass2[index] = value.value.x;
    sortValues2_pass2[index] = value.value.y;
}

#endif

struct SortBufferType {
    SortBufferKey key;
#ifndef SORT_VALUE_NONE
    SortBufferValue value;
#endif
};

SortBufferType GatherBuf(uint index) {
    SortBufferType buf;
    GatherKey(index, buf.key);
#ifndef SORT_VALUE_NONE
    GatherValue(index, buf.value);
#endif
    return buf;
}
void ScatterBuf(uint index, SortBufferType buf) {
    ScatterKey(index, buf.key);
#ifndef SORT_VALUE_NONE
    ScatterValue(index, buf.value);
#endif
}


///////////////////////////////////////////////////////////////////////////////////////////////////
// Pass 1

#define BANK_ADDRESS_2(i) (i + (i>> 4))

RWStructuredBuffer<SortBufferType> target_pass1 : register(u0);
RWStructuredBuffer<uint> bucketOffsets_pass1 : register(u1);
RWStructuredBuffer<uint> prefixSum_pass1 : register(u2);

groupshared SortBufferType sharedPairs[NUM_VALUES];

#define scan(i) sharedSum[BANK_ADDRESS(i)]

groupshared uint2 offsets[NUM_BUCKETS];
#define offset_x(i) offsets[i].x
#define offset_y(i) offsets[i].y


///////////////////////////////////////////////////////////////////////////////////////////////////

[numthreads(NUM_THREADS, 1, 1)]
void RadixSortBlock_Pass1(uint tid : SV_GroupIndex, uint3 groupID : SV_GroupID) {
    uint index = VALUES_PER_THREAD * (tid + NUM_THREADS * groupID.x);
    
    uint i;
    
    // Gather the key/value pairs and write them to shared memory
    [unroll]
    for(i = 0; i < VALUES_PER_THREAD; ++i) {
        uint j = index + i;
        SortBufferType buf = GatherBuf(j);
        if(j >= numElements) buf.key.key = -1;
        sharedPairs[VALUES_PER_THREAD * tid + i] = buf;
    }
    
#ifdef UNROLL_DIGIT_PASS
    [unroll]
#endif
    for(uint level = 0; level < DIGIT_SIZE; ++level) {
        SortBufferType threadValues[VALUES_PER_THREAD];
        uint levelShift = shift + level;
        Counter word;
        
        [unroll]
        for(i = 0; i < VALUES_PER_THREAD; ++i) {
            threadValues[i] = sharedPairs[VALUES_PER_THREAD * tid + i];
            word[i / 4][3 & i] = threadValues[i].key.key;
        }
        
        [unroll]
        for(i = 0; i < NUM_COUNTERS; ++i)
            word[i] = 1 & (word[i]>> levelShift);
        
        // Prepare an exclusive scan of partial sums to establish the false_total and
        // true_before (as per the Garland radix sort paper)
        Counter inclusive;
        InclusiveSum(word, inclusive);
        
        // put the sum of all this thread's values into the scan array.
        // This indexing gives us the last component of the last array member.
        scan(tid) = Last(inclusive);
        
        // perform the prefix sum
        ThreadSum(tid);
        
        uint total = scan(0);
        uint scan = scan(tid) - total;
        uint falseTotal = NUM_VALUES - total;

        Counter exclusive, trueBefore, falseBefore, scatter;
        [unroll]
        for(i = 0; i < NUM_COUNTERS; ++i) {
            exclusive[i] = inclusive[i] - word[i];
            trueBefore[i] = scan + exclusive[i];
            falseBefore[i] = VALUES_PER_THREAD * tid + SEQUENCE + 4 * i - trueBefore[i];
            trueBefore[i] += falseTotal;
        }
        
        [unroll]
        for(i = 0; i < NUM_COUNTERS; ++i)
            scatter[i] = word[i] ? trueBefore[i] : falseBefore[i];
        
        [unroll]
        for(i = 0; i < VALUES_PER_THREAD; ++i)
            sharedPairs[scatter[i / 4][3 & i]] = threadValues[i];
            
        GroupMemoryBarrierWithGroupSync();
    }    
    
    // Serialize the sorted pairs
    [unroll]
    for(i = 0; i < VALUES_PER_THREAD; ++i)
        target_pass1[index + i] = sharedPairs[VALUES_PER_THREAD * tid + i];

    // fill the offsets
    if(tid < NUM_BUCKETS) {
        offset_x(tid) = -1;
        offset_y(tid) = -1;        
    }
    GroupMemoryBarrierWithGroupSync();
    
    // Scan sharedPairs array for bucket sizes.
    uint mask = NUM_BUCKETS - 1;
    
    [unroll]
    for(i = 0; i < VALUES_PER_THREAD; ++i) {
        uint effectiveTid = tid + NUM_THREADS * i;
        if(effectiveTid > 0) {
            uint right = mask & (sharedPairs[effectiveTid].key.key>> shift);
            uint left = mask & (sharedPairs[effectiveTid - 1].key.key>> shift);
            if(right != left) {
                // This is the first time we have encountered the right value.
                offset_x(right) = effectiveTid;
                offset_y(left) = right;
            }
        }
    }
    GroupMemoryBarrierWithGroupSync();

    if(tid < NUM_BUCKETS) {
        uint2 interval = uint2(offset_x(tid), offset_y(tid));
        if(any(uint2(-1, -1) != interval)) {
            if(-1 == interval.x) interval.x = 0;
            interval.y = (-1 == interval.y) ? NUM_VALUES : offset_x(interval.y);
        }
        uint count = interval.y - interval.x;
                
        // Serialize the bucket counts
        // We have to write the buckets into a column-major matrix, in order
        // to allow the prefix sum to scan them correctly.
        bucketOffsets_pass1[NUM_BUCKETS * groupID.x + tid] = (0xffff & interval.x) | (count<< 16);
        prefixSum_pass1[tid * numGroups + groupID.x] = count;
    }
}



///////////////////////////////////////////////////////////////////////////////////////////////////
// Pass 2 - NUM_BUCKETS depends on DIGIT_SIZE

StructuredBuffer<SortBufferType> source_pass2 : register(t0);
StructuredBuffer<uint> bucketOffsets_pass2 : register(t1);
StructuredBuffer<uint> prefixSum_pass2 : register(t2);

[numthreads(NUM_BUCKETS, 1, 1)]
void RadixSortBlock_Pass2(uint tid : SV_GroupIndex, uint3 groupID : SV_GroupID) {
    uint source = valuesPerGroup * groupID.x;
    
    uint bucket = bucketOffsets_pass2[NUM_BUCKETS * groupID.x + tid];
    uint offset = 0xffff & bucket;
    uint count = (bucket>> 16);
    
    uint target = prefixSum_pass2[tid * numGroups + groupID.x];
    
    // stream the values from each bucket
    for(uint i = 0; i < count; ++i)
        ScatterBuf(target + i, source_pass2[source + offset + i]);
}

 


Posted on the other topic, it's all the same problem in the end.
