Radix Sort
GPU radix sort for u32 arrays
Quick Start
```ts
import { createRadixSort } from './webgpu-market/radix-sort/radix-sort';

const sorter = createRadixSort(device, { maxElements: 1_000_000 });

// Sort a buffer in-place
sorter.sort(myBuffer, elementCount);

sorter.destroy();
```
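Here `myBuffer` is assumed to be any GPUBuffer of u32 keys with the usage flags the sorter requires (see the API notes below). A minimal setup sketch, with illustrative data:

```ts
const elementCount = 4;

// The sorter copies in and out of this buffer, so all three flags are needed
const myBuffer = device.createBuffer({
  size: elementCount * 4, // one u32 per element
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
});
device.queue.writeBuffer(myBuffer, 0, new Uint32Array([42, 7, 19, 3]));
```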
Source

```wgsl
// GPU Radix Sort — 4-bit radix (16 buckets), 8 passes for 32-bit keys
//
// Three kernels dispatched per pass:
//   1. histogram  — count elements per bucket per workgroup
//   2. prefix_sum — exclusive scan over histograms for global offsets
//   3. scatter    — place each element at its sorted position
//
// Uses ping-pong buffers: input and output swap each pass.
const RADIX_BITS: u32 = 4u;
const NUM_BUCKETS: u32 = 16u; // 2^RADIX_BITS
const WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
    element_count: u32,
    bit_offset: u32, // 0, 4, 8, 12, 16, 20, 24, 28
    workgroup_count: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<u32>;
@group(0) @binding(2) var<storage, read_write> output: array<u32>;
@group(0) @binding(3) var<storage, read_write> histograms: array<u32>;
// ============================================================
// Kernel 1: Histogram
// Each workgroup counts how many of its elements fall into each
// of the 16 buckets for the current radix digit.
// ============================================================
var<workgroup> local_histogram: array<atomic<u32>, 16>;

@compute @workgroup_size(256)
fn histogram(
    @builtin(global_invocation_id) gid: vec3u,
    @builtin(workgroup_id) wid: vec3u,
    @builtin(local_invocation_id) lid: vec3u,
) {
    // Clear shared histogram
    if (lid.x < NUM_BUCKETS) {
        atomicStore(&local_histogram[lid.x], 0u);
    }
    workgroupBarrier();

    // Count elements in this workgroup's range
    let idx = gid.x;
    if (idx < u.element_count) {
        let digit = (input[idx] >> u.bit_offset) & (NUM_BUCKETS - 1u);
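        // e.g. key = 0x000000B7, bit_offset = 4u:
        //   (0xB7 >> 4) & 0xF = 0xB, so this key counts toward bucket 11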
        atomicAdd(&local_histogram[digit], 1u);
    }
    workgroupBarrier();

    // Write workgroup histogram to global memory
    // Layout: histograms[bucket * workgroup_count + workgroup_id]
    if (lid.x < NUM_BUCKETS) {
        let global_idx = lid.x * u.workgroup_count + wid.x;
        histograms[global_idx] = atomicLoad(&local_histogram[lid.x]);
    }
}
// ============================================================
// Kernel 2: Prefix Sum (exclusive scan)
// Computes global offsets from the per-workgroup histograms.
// Single workgroup scans the entire histogram array.
//
// The histogram is laid out as [bucket][workgroup], so scanning
// it linearly gives us stable sort order — elements from earlier
// workgroups land before elements from later workgroups within
// the same bucket.
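//
// Example (workgroup_count = 2): histograms = [b0w0, b0w1, b1w0, b1w1, ...].
// An exclusive scan of counts [2, 1, 0, 3, ...] gives offsets
// [0, 2, 3, 3, ...]: bucket 0 of workgroup 1 starts right after
// bucket 0 of workgroup 0.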
// ============================================================
var<workgroup> scan_scratch: array<u32, 256>;
@compute @workgroup_size(256)
fn prefix_sum(
    @builtin(local_invocation_id) lid: vec3u,
) {
    let total_entries = NUM_BUCKETS * u.workgroup_count;

    // Process multiple elements per thread if needed.
    // We do a sequential scan per thread, then a workgroup-level scan
    // of the per-thread totals, then add the offsets back.
    let chunk_size = (total_entries + WORKGROUP_SIZE - 1u) / WORKGROUP_SIZE;
    let start = lid.x * chunk_size;
    let end = min(start + chunk_size, total_entries);

    // Phase 1: Sequential scan within each thread's chunk
    var thread_total = 0u;
    for (var i = start; i < end; i++) {
        let val = histograms[i];
        histograms[i] = thread_total;
        thread_total += val;
    }
    scan_scratch[lid.x] = thread_total;
    workgroupBarrier();

    // Phase 2: Exclusive scan of thread totals (Blelloch-style)
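    // e.g. thread totals [3, 1, 7, 0]: up-sweep builds [3, 4, 7, 11],
    // the last element is zeroed, and the down-sweep yields the
    // exclusive prefix sums [0, 3, 4, 11].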
    // Up-sweep
    for (var stride = 1u; stride < WORKGROUP_SIZE; stride *= 2u) {
        let idx = (lid.x + 1u) * stride * 2u - 1u;
        if (idx < WORKGROUP_SIZE) {
            scan_scratch[idx] += scan_scratch[idx - stride];
        }
        workgroupBarrier();
    }

    if (lid.x == 0u) {
        scan_scratch[WORKGROUP_SIZE - 1u] = 0u;
    }
    workgroupBarrier();

    // Down-sweep
    for (var stride = WORKGROUP_SIZE / 2u; stride >= 1u; stride /= 2u) {
        let idx = (lid.x + 1u) * stride * 2u - 1u;
        if (idx < WORKGROUP_SIZE) {
            let temp = scan_scratch[idx - stride];
            scan_scratch[idx - stride] = scan_scratch[idx];
            scan_scratch[idx] += temp;
        }
        workgroupBarrier();
    }

    // Phase 3: Add thread-level offsets back to each element
    let offset = scan_scratch[lid.x];
    for (var i = start; i < end; i++) {
        histograms[i] += offset;
    }
}
// ============================================================
// Kernel 3: Scatter (stable)
// Each element computes its digit, then a workgroup-local prefix
// sum determines a deterministic offset within each bucket.
// This preserves input order — threads with lower indices get
// lower offsets within the same bucket, making the sort stable.
// ============================================================
var<workgroup> scatter_digits: array<u32, 256>;
var<workgroup> scatter_rank: array<u32, 256>;
var<workgroup> bucket_bases: array<u32, 16>;
@compute @workgroup_size(256)
fn scatter(
    @builtin(global_invocation_id) gid: vec3u,
    @builtin(workgroup_id) wid: vec3u,
    @builtin(local_invocation_id) lid: vec3u,
) {
    // Load this workgroup's base offsets from the prefix sum
    if (lid.x < NUM_BUCKETS) {
        let global_idx = lid.x * u.workgroup_count + wid.x;
        bucket_bases[lid.x] = histograms[global_idx];
    }

    // Each thread computes its digit
    let idx = gid.x;
    var digit = 0u;
    if (idx < u.element_count) {
        digit = (input[idx] >> u.bit_offset) & (NUM_BUCKETS - 1u);
    }
    scatter_digits[lid.x] = digit;
    workgroupBarrier();

    // Compute rank: count how many threads with a lower index
    // in this workgroup have the same digit. O(n) per thread
    // within the workgroup (n = 256), trading throughput for
    // deterministic stable ordering. A parallel prefix scan
    // would be faster but adds complexity.
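    // e.g. digits in this workgroup = [2, 5, 2, 2]: the thread at
    // lid.x = 3 sees two earlier 2s, so rank = 2 and it writes to
    // bucket_bases[2] + 2.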
    var rank = 0u;
    if (idx < u.element_count) {
        for (var i = 0u; i < lid.x; i++) {
            if (scatter_digits[i] == digit) {
                rank += 1u;
            }
        }
    }
    workgroupBarrier();

    // Write to output at the deterministic position
    if (idx < u.element_count) {
        let dest = bucket_bases[digit] + rank;
        output[dest] = input[idx];
    }
}
```
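For intuition (and for testing the GPU path against a known-good result), here is a sequential CPU sketch of a single 4-bit pass. It is our own illustration, not part of the module; the three loops correspond to the histogram, prefix_sum, and scatter kernels above:

```ts
// One 4-bit radix pass on the CPU.
function radixPassCPU(input: Uint32Array, bitOffset: number): Uint32Array {
  const NUM_BUCKETS = 16;

  // "histogram": count keys per bucket
  const counts = new Array<number>(NUM_BUCKETS).fill(0);
  for (const key of input) {
    counts[(key >>> bitOffset) & (NUM_BUCKETS - 1)]++;
  }

  // "prefix_sum": exclusive scan gives each bucket's start offset
  const offsets = new Array<number>(NUM_BUCKETS).fill(0);
  for (let b = 1; b < NUM_BUCKETS; b++) {
    offsets[b] = offsets[b - 1] + counts[b - 1];
  }

  // "scatter": stable (equal digits keep their input order)
  const output = new Uint32Array(input.length);
  for (const key of input) {
    output[offsets[(key >>> bitOffset) & (NUM_BUCKETS - 1)]++] = key;
  }
  return output;
}

// Eight passes over successive 4-bit digits reproduce the full sort:
//   let a = data;
//   for (let p = 0; p < 8; p++) a = radixPassCPU(a, p * 4);
```

The host code below drives the same three kernels on the GPU.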
```ts
// GPU Radix Sort
// Sorts arrays of unsigned 32-bit integers on the GPU using a parallel
// radix sort. Processes 4 bits per pass (16 buckets), 8 passes total.
//
// Default WGSL loading uses a ?raw import (supported natively by Vite;
// other bundlers typically need a text/raw loader configured).
// Alternative: load via fetch — see README.md for details.
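// A fetch-based sketch, assuming the .wgsl file is served as a static
// asset (the path here is illustrative):
//   const shaderSource = await (await fetch('/radix-sort.wgsl')).text();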
import shaderSource from './radix-sort.wgsl?raw';
export interface RadixSortOptions {
  maxElements?: number;
}

export interface RadixSort {
  sort(buffer: GPUBuffer, elementCount: number): void;
  sortPairs(keysBuffer: GPUBuffer, valuesBuffer: GPUBuffer, elementCount: number): void;
  destroy(): void;
}

const WORKGROUP_SIZE = 256;
const NUM_BUCKETS = 16;
const NUM_PASSES = 8; // 32 bits / 4 bits per pass
export function createRadixSort(device: GPUDevice, options: RadixSortOptions = {}): RadixSort {
  const maxElements = options.maxElements ?? 1_000_000;
  const workgroupCount = Math.ceil(maxElements / WORKGROUP_SIZE);

  // Internal buffers
  const bufferA = device.createBuffer({
    size: maxElements * 4,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
  });
  const bufferB = device.createBuffer({
    size: maxElements * 4,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
  });

  // Per-workgroup histograms: NUM_BUCKETS * workgroupCount entries
  const histogramBuffer = device.createBuffer({
    size: NUM_BUCKETS * workgroupCount * 4,
    usage: GPUBufferUsage.STORAGE
  });
  const uniformBuffer = device.createBuffer({
    size: 16, // 3 x u32; the struct's size rounds up to 16 in WGSL's uniform address space
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
  });
  const shaderModule = device.createShaderModule({ code: shaderSource });

  // One bind group layout is shared by all three pipelines; bind groups
  // are created per pass because input and output swap each pass.
  const bindGroupLayout = device.createBindGroupLayout({
    entries: [
      { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } },
      { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } },
      { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } },
      { binding: 3, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } }
    ]
  });
  const pipelineLayout = device.createPipelineLayout({ bindGroupLayouts: [bindGroupLayout] });

  const histogramPipeline = device.createComputePipeline({
    layout: pipelineLayout,
    compute: { module: shaderModule, entryPoint: 'histogram' }
  });
  const prefixSumPipeline = device.createComputePipeline({
    layout: pipelineLayout,
    compute: { module: shaderModule, entryPoint: 'prefix_sum' }
  });
  const scatterPipeline = device.createComputePipeline({
    layout: pipelineLayout,
    compute: { module: shaderModule, entryPoint: 'scatter' }
  });
  function createBindGroup(input: GPUBuffer, output: GPUBuffer): GPUBindGroup {
    return device.createBindGroup({
      layout: bindGroupLayout,
      entries: [
        { binding: 0, resource: { buffer: uniformBuffer } },
        { binding: 1, resource: { buffer: input } },
        { binding: 2, resource: { buffer: output } },
        { binding: 3, resource: { buffer: histogramBuffer } }
      ]
    });
  }
  function sort(buffer: GPUBuffer, elementCount: number): void {
    if (elementCount <= 0) return;
    if (elementCount > maxElements) {
      throw new Error(`elementCount (${elementCount}) exceeds maxElements (${maxElements})`);
    }
    const wgCount = Math.ceil(elementCount / WORKGROUP_SIZE);
    const uniformData = new Uint32Array([elementCount, 0, wgCount]);

    // Copy user data into internal buffer A
    const copyEncoder = device.createCommandEncoder();
    copyEncoder.copyBufferToBuffer(buffer, 0, bufferA, 0, elementCount * 4);
    device.queue.submit([copyEncoder.finish()]);
    // Each pass must be a separate submit so that writeBuffer
    // updates the uniform before the pass reads it.
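    // (A possible alternative: pre-write all 8 uniform values into one
    // buffer at 256-byte-aligned offsets and bind with dynamic offsets,
    // which would allow a single submit per sort.)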
    for (let pass = 0; pass < NUM_PASSES; pass++) {
      uniformData[1] = pass * 4;
      device.queue.writeBuffer(uniformBuffer, 0, uniformData);

      const isEvenPass = pass % 2 === 0;
      const inputBuf = isEvenPass ? bufferA : bufferB;
      const outputBuf = isEvenPass ? bufferB : bufferA;
      const bindGroup = createBindGroup(inputBuf, outputBuf);

      const encoder = device.createCommandEncoder();
      const passEncoder = encoder.beginComputePass();

      passEncoder.setPipeline(histogramPipeline);
      passEncoder.setBindGroup(0, bindGroup);
      passEncoder.dispatchWorkgroups(wgCount);

      passEncoder.setPipeline(prefixSumPipeline);
      passEncoder.setBindGroup(0, bindGroup);
      passEncoder.dispatchWorkgroups(1);

      passEncoder.setPipeline(scatterPipeline);
      passEncoder.setBindGroup(0, bindGroup);
      passEncoder.dispatchWorkgroups(wgCount);

      passEncoder.end();
      device.queue.submit([encoder.finish()]);
    }

    // NUM_PASSES = 8 (even), so the final result is in bufferA
    const resultEncoder = device.createCommandEncoder();
    resultEncoder.copyBufferToBuffer(bufferA, 0, buffer, 0, elementCount * 4);
    device.queue.submit([resultEncoder.finish()]);
  }
  function sortPairs(): never {
    throw new Error('sortPairs is not yet implemented — use sort() for key-only sorting');
  }

  function destroy(): void {
    bufferA.destroy();
    bufferB.destroy();
    histogramBuffer.destroy();
    uniformBuffer.destroy();
  }

  return { sort, sortPairs, destroy };
}
```
Documentation
Radix Sort
Radix sort for unsigned 32-bit integers. Pure compute — no visual output.
API
createRadixSort(device, options?)
Returns a RadixSort instance.
| Option | Type | Default | Description |
|---|---|---|---|
| maxElements | number | 1_000_000 | Maximum number of elements (determines internal buffer sizes) |
sorter.sort(buffer, elementCount)
Sorts a GPUBuffer of u32 values in-place (ascending order).
| Param | Type | Description |
|---|---|---|
| buffer | GPUBuffer | Buffer containing u32 values. Must have STORAGE \| COPY_SRC \| COPY_DST usage. |
| elementCount | number | Number of elements to sort (not bytes) |
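A sketch of a complete round trip, assuming a configured device and sorter. The staging buffer is standard WebGPU readback (STORAGE buffers cannot be mapped directly) and is illustration only, not part of this module:

```ts
const data = new Uint32Array([42, 7, 19, 3]);

const buffer = device.createBuffer({
  size: data.byteLength,
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
});
device.queue.writeBuffer(buffer, 0, data);

sorter.sort(buffer, data.length);

// Read back through a mappable staging buffer
const staging = device.createBuffer({
  size: data.byteLength,
  usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
});
const encoder = device.createCommandEncoder();
encoder.copyBufferToBuffer(buffer, 0, staging, 0, data.byteLength);
device.queue.submit([encoder.finish()]);

await staging.mapAsync(GPUMapMode.READ);
console.log(new Uint32Array(staging.getMappedRange().slice(0))); // [3, 7, 19, 42]
staging.unmap();
```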
sorter.sortPairs(keysBuffer, valuesBuffer, elementCount)
Sorts keys and reorders associated values to match. Not yet implemented — the current version throws; use sort() for key-only sorting.
sorter.destroy()
Releases all internal GPU buffers.
Further Reading
Resources on GPU radix sort algorithms and parallel sorting techniques.
Core Papers
Satish, Harris, Garland, "Designing Efficient Sorting Algorithms for Manycore GPUs" (2009) The foundational paper on GPU radix sort. Introduces the histogram–prefix-sum–scatter pattern used in this module. Benchmarks against CPU sorting and other GPU approaches. https://mgarland.org/files/papers/nvr-2008-001.pdf
Merrill, Grimshaw, "High Performance and Scalable Radix Sorting" (2011) Advances the state of the art with a more efficient single-pass scan and better workgroup utilization. Basis for CUB's radix sort. https://escholarship.org/uc/item/8051s4pj
Harada, Howes, "Introduction to GPU Radix Sort" (2011) A practical, accessible walkthrough of implementing radix sort on the GPU. Good companion to the academic papers. http://www.heterogeneouscompute.org/wordpress/wp-content/uploads/2011/06/RadixSort.pdf
Prefix Sum (Scan)
Blelloch, "Prefix Sums and Their Applications" (1990) The classic reference on parallel prefix sum (scan). The Blelloch up-sweep/down-sweep pattern is used in this module's prefix sum kernel. https://www.cs.cmu.edu/~blelloch/papers/Ble90.pdf
Harris, Sengupta, Owens, "Parallel Prefix Sum (Scan) with CUDA" (GPU Gems 3, Chapter 39) Practical GPU implementation of parallel prefix sum with work-efficient algorithms and bank conflict avoidance. https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
WebGPU-Specific
WGSL Specification — Variables and Address Spaces Covers workgroup-shared variables and barriers in WGSL, which are critical for the histogram and prefix sum kernels. https://www.w3.org/TR/WGSL/#var-and-value
WebGPU Specification — Compute Shaders The official spec for compute shader dispatch, workgroup sizes, and synchronization primitives. https://www.w3.org/TR/webgpu/#compute-passes
General References
Cormen, Leiserson, Rivest, Stein, "Introduction to Algorithms" (CLRS) Chapter 8 covers radix sort in the sequential setting. Understanding the sequential algorithm helps reason about the parallel version's correctness.
Duane Merrill, "CUB Library" CUDA's high-performance primitives library, including the most optimized GPU radix sort implementation. A good reference for advanced optimization techniques. https://github.com/NVIDIA/cub
Vulkan / CUDA Radix Sort Benchmarks Practical performance comparisons across GPU sorting implementations, useful for understanding expected throughput. https://github.com/b0nes164/GPUSorting