Radix Sort

GPU radix sort for u32 arrays

Quick Start

import { createRadixSort } from './webgpu-market/radix-sort/radix-sort';

const sorter = createRadixSort(device, { maxElements: 1_000_000 });

// Sort a buffer in-place
sorter.sort(myBuffer, elementCount);

sorter.destroy();
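
A fuller round trip, including buffer creation and readback, looks like the sketch below. Everything here besides createRadixSort and sort is standard WebGPU; the sketch assumes sort() submits its GPU work before returning.

// Minimal end-to-end sketch (assumes `device: GPUDevice` is already acquired).
import { createRadixSort } from './webgpu-market/radix-sort/radix-sort';

const data = new Uint32Array([42, 7, 19, 3, 88]);

// STORAGE for the compute passes; COPY_SRC | COPY_DST for the internal
// ping-pong copies and the readback below.
const buffer = device.createBuffer({
  size: data.byteLength,
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
});
device.queue.writeBuffer(buffer, 0, data);

const sorter = createRadixSort(device, { maxElements: data.length });
sorter.sort(buffer, data.length);

// Read the result back through a mappable staging buffer.
const staging = device.createBuffer({
  size: data.byteLength,
  usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
});
const encoder = device.createCommandEncoder();
encoder.copyBufferToBuffer(buffer, 0, staging, 0, data.byteLength);
device.queue.submit([encoder.finish()]);

await staging.mapAsync(GPUMapMode.READ);
const sorted = new Uint32Array(staging.getMappedRange().slice(0));
staging.unmap(); // sorted is now [3, 7, 19, 42, 88]
sorter.destroy();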

Source

// GPU Radix Sort — 4-bit radix (16 buckets), 8 passes for 32-bit keys
//
// Three kernels dispatched per pass:
//   1. histogram  — count elements per bucket per workgroup
//   2. prefix_sum — exclusive scan over histograms for global offsets
//   3. scatter    — place each element at its sorted position
//
// Uses ping-pong buffers: input and output swap each pass.
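// Since the pass count is even, the sorted data ends up back in the buffer
// it started in, which is what lets the host API present sort() as in-place.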

const RADIX_BITS: u32 = 4u;
const NUM_BUCKETS: u32 = 16u; // 2^RADIX_BITS
const WORKGROUP_SIZE: u32 = 256u;

struct Uniforms {
  element_count: u32,
  bit_offset: u32,     // 0, 4, 8, 12, 16, 20, 24, 28
  workgroup_count: u32,
}

@group(0) @binding(0) var<uniform> u: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<u32>;
@group(0) @binding(2) var<storage, read_write> output: array<u32>;
@group(0) @binding(3) var<storage, read_write> histograms: array<u32>;

// ============================================================
// Kernel 1: Histogram
// Each workgroup counts how many of its elements fall into each
// of the 16 buckets for the current radix digit.
// ============================================================

var<workgroup> local_histogram: array<atomic<u32>, 16>;

@compute @workgroup_size(256)
fn histogram(
  @builtin(global_invocation_id) gid: vec3u,
  @builtin(workgroup_id) wid: vec3u,
  @builtin(local_invocation_id) lid: vec3u,
) {
  // Clear shared histogram
  if (lid.x < NUM_BUCKETS) {
    atomicStore(&local_histogram[lid.x], 0u);
  }
  workgroupBarrier();

  // Count elements in this workgroup's range
  let idx = gid.x;
  if (idx < u.element_count) {
    let digit = (input[idx] >> u.bit_offset) & (NUM_BUCKETS - 1u);
    atomicAdd(&local_histogram[digit], 1u);
  }
  workgroupBarrier();

  // Write workgroup histogram to global memory
  // Layout: histograms[bucket * workgroup_count + workgroup_id]
  if (lid.x < NUM_BUCKETS) {
    let global_idx = lid.x * u.workgroup_count + wid.x;
    histograms[global_idx] = atomicLoad(&local_histogram[lid.x]);
  }
}

// ============================================================
// Kernel 2: Prefix Sum (exclusive scan)
// Computes global offsets from the per-workgroup histograms.
// Single workgroup scans the entire histogram array.
//
// The histogram is laid out as [bucket][workgroup], so scanning
// it linearly gives us stable sort order — elements from earlier
// workgroups land before elements from later workgroups within
// the same bucket.
// ============================================================

var<workgroup> scan_scratch: array<u32, 256>;

@compute @workgroup_size(256)
fn prefix_sum(
  @builtin(local_invocation_id) lid: vec3u,
) {
  let total_entries = NUM_BUCKETS * u.workgroup_count;

  // Process multiple elements per thread if needed
  // We do a sequential scan per thread, then a workgroup-level scan
  // of the per-thread totals, then add the offsets back.

  let chunk_size = (total_entries + WORKGROUP_SIZE - 1u) / WORKGROUP_SIZE;
  let start = lid.x * chunk_size;
  let end = min(start + chunk_size, total_entries);

  // Phase 1: Sequential scan within each thread's chunk
  var thread_total = 0u;
  for (var i = start; i < end; i++) {
    let val = histograms[i];
    histograms[i] = thread_total;
    thread_total += val;
  }

  scan_scratch[lid.x] = thread_total;
  workgroupBarrier();

  // Phase 2: Exclusive scan of thread totals (Blelloch-style)
  // Up-sweep
  for (var stride = 1u; stride < WORKGROUP_SIZE; stride *= 2u) {
    let idx = (lid.x + 1u) * stride * 2u - 1u;
    if (idx < WORKGROUP_SIZE) {
      scan_scratch[idx] += scan_scratch[idx - stride];
    }
    workgroupBarrier();
  }

  if (lid.x == 0u) {
    scan_scratch[WORKGROUP_SIZE - 1u] = 0u;
  }
  workgroupBarrier();

  // Down-sweep
  for (var stride = WORKGROUP_SIZE / 2u; stride >= 1u; stride /= 2u) {
    let idx = (lid.x + 1u) * stride * 2u - 1u;
    if (idx < WORKGROUP_SIZE) {
      let temp = scan_scratch[idx - stride];
      scan_scratch[idx - stride] = scan_scratch[idx];
      scan_scratch[idx] += temp;
    }
    workgroupBarrier();
  }

  // Phase 3: Add thread-level offsets back to each element
  let offset = scan_scratch[lid.x];
  for (var i = start; i < end; i++) {
    histograms[i] += offset;
  }
}

// ============================================================
// Kernel 3: Scatter (stable)
// Each element computes its digit, then a workgroup-local prefix
// sum determines a deterministic offset within each bucket.
// This preserves input order — threads with lower indices get
// lower offsets within the same bucket, making the sort stable.
// ============================================================

var<workgroup> scatter_digits: array<u32, 256>;
var<workgroup> bucket_bases: array<u32, 16>;

@compute @workgroup_size(256)
fn scatter(
  @builtin(global_invocation_id) gid: vec3u,
  @builtin(workgroup_id) wid: vec3u,
  @builtin(local_invocation_id) lid: vec3u,
) {
  // Load this workgroup's base offsets from the prefix sum
  if (lid.x < NUM_BUCKETS) {
    let global_idx = lid.x * u.workgroup_count + wid.x;
    bucket_bases[lid.x] = histograms[global_idx];
  }

  // Each thread computes its digit
  let idx = gid.x;
  var digit = 0u;
  if (idx < u.element_count) {
    digit = (input[idx] >> u.bit_offset) & (NUM_BUCKETS - 1u);
  }
  scatter_digits[lid.x] = digit;
  workgroupBarrier();

  // Compute rank: count how many threads with a lower index
  // in this workgroup have the same digit. O(n) per thread
  // within the workgroup (n=256), trading throughput for
  // deterministic stable ordering. A parallel prefix scan
  // would be faster but adds complexity.
  var rank = 0u;
  if (idx < u.element_count) {
    for (var i = 0u; i < lid.x; i++) {
      if (scatter_digits[i] == digit) {
        rank += 1u;
      }
    }
  }
  workgroupBarrier();

  // Write to output at the deterministic position
  if (idx < u.element_count) {
    let dest = bucket_bases[digit] + rank;
    output[dest] = input[idx];
  }
}
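
For validating GPU output, a CPU reference of the same digit-by-digit algorithm is handy. The sketch below mirrors the three kernels for a single pass (histogram, exclusive scan, stable scatter); it is a test aid, not part of the shader source.

// CPU reference for one 4-bit radix pass. Running it for
// bitOffset = 0, 4, ..., 28 reproduces the full 32-bit sort.
function radixPassCPU(input: Uint32Array, bitOffset: number): Uint32Array {
  const NUM_BUCKETS = 16;

  // "Kernel 1": count elements per bucket.
  const counts = new Uint32Array(NUM_BUCKETS);
  for (const v of input) counts[(v >>> bitOffset) & 15]++;

  // "Kernel 2": exclusive prefix sum over the counts. (The GPU version
  // scans a [bucket][workgroup] layout so stability also holds across
  // workgroups.)
  const offsets = new Uint32Array(NUM_BUCKETS);
  let running = 0;
  for (let b = 0; b < NUM_BUCKETS; b++) {
    offsets[b] = running;
    running += counts[b];
  }

  // "Kernel 3": stable scatter. Iterating in input order plays the role
  // of the rank computation in the scatter kernel.
  const output = new Uint32Array(input.length);
  for (const v of input) {
    output[offsets[(v >>> bitOffset) & 15]++] = v;
  }
  return output;
}

// Full sort: 8 passes of 4 bits each.
function radixSortCPU(data: Uint32Array): Uint32Array {
  let a = data;
  for (let bit = 0; bit < 32; bit += 4) a = radixPassCPU(a, bit);
  return a;
}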

Documentation

Radix Sort

Radix sort for unsigned 32-bit integers. Pure compute — no visual output.

API

createRadixSort(device, options?)

Returns a RadixSort instance.

Option        Type     Default      Description
maxElements   number   1_000_000    Maximum number of elements (determines internal buffer sizes)
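
The sizing arithmetic implied by maxElements follows from the shader constants: with 256 threads per workgroup, maxElements fixes the workgroup count, and the histogram buffer needs one u32 counter per bucket per workgroup. A sketch of the arithmetic (the returned field names are illustrative, not the library's internals):

// Illustrative sizing math; the library's actual internal names may differ.
const WORKGROUP_SIZE = 256;
const NUM_BUCKETS = 16;

function internalBufferSizes(maxElements: number) {
  const workgroupCount = Math.ceil(maxElements / WORKGROUP_SIZE);
  return {
    // Ping-pong partner for the caller's buffer: one u32 per element.
    pingPongBytes: maxElements * 4,
    // One u32 per bucket per workgroup, laid out [bucket][workgroup].
    histogramBytes: NUM_BUCKETS * workgroupCount * 4,
  };
}
// internalBufferSizes(1_000_000) => { pingPongBytes: 4000000, histogramBytes: 250048 }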

sorter.sort(buffer, elementCount)

Sorts a GPUBuffer of u32 values in-place (ascending order).

Param          Type        Description
buffer         GPUBuffer   Buffer containing u32 values. Must have STORAGE | COPY_SRC | COPY_DST usage.
elementCount   number      Number of elements to sort (not bytes)
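
Internally, sort() has to drive the three kernels once per 4-bit digit, eight times in total, as described in the source comments. A hypothetical sketch of that loop follows; pipelines and bindGroups are illustrative stand-ins, since a real implementation would pre-build one bind group per pass to handle the ping-pong and the per-pass bit_offset uniform.

// Hypothetical pass loop matching the shader design above.
declare const pipelines: {
  histogram: GPUComputePipeline;
  prefixSum: GPUComputePipeline;
  scatter: GPUComputePipeline;
};
declare const bindGroups: GPUBindGroup[]; // one per pass: swapped buffers + bit_offset

function encodeSort(encoder: GPUCommandEncoder, elementCount: number) {
  const workgroupCount = Math.ceil(elementCount / 256);
  for (let pass = 0; pass < 8; pass++) {
    const compute = encoder.beginComputePass();
    compute.setBindGroup(0, bindGroups[pass]);
    compute.setPipeline(pipelines.histogram);
    compute.dispatchWorkgroups(workgroupCount);
    compute.setPipeline(pipelines.prefixSum);
    compute.dispatchWorkgroups(1); // single workgroup scans all histograms
    compute.setPipeline(pipelines.scatter);
    compute.dispatchWorkgroups(workgroupCount);
    compute.end();
  }
}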

sorter.sortPairs(keysBuffer, valuesBuffer, elementCount)

Sorts keys and reorders associated values to match.
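
A typical use is ordering IDs or indices by a key such as depth. A short usage sketch (the two buffers here are hypothetical and need the same usage flags as for sort()):

// Reorder particle indices so they follow ascending depth keys.
// depthKeys and particleIndices are hypothetical GPUBuffers of u32 values.
sorter.sortPairs(depthKeys, particleIndices, particleCount);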

sorter.destroy()

Releases all internal GPU buffers.

Further Reading

Resources on GPU radix sort algorithms and parallel sorting techniques.

General References

  • Cormen, Leiserson, Rivest, and Stein, "Introduction to Algorithms" (CLRS). Chapter 8 covers radix sort in the sequential setting; understanding the sequential algorithm helps when reasoning about the parallel version's correctness.

  • Duane Merrill, "CUB Library". CUDA's high-performance primitives library, including one of the most heavily optimized GPU radix sort implementations available. A good reference for advanced optimization techniques. https://github.com/NVIDIA/cub

  • Vulkan / CUDA Radix Sort Benchmarks. Practical performance comparisons across GPU sorting implementations, useful for understanding expected throughput. https://github.com/b0nes164/GPUSorting