CUDA Block Scheduler
The CUDA programming model supports launching a dynamic number of blocks in a grid, and a dynamic number of threads in a block. To make this possible, the underlying hardware must be able to dynamically schedule blocks onto the actual cores - the streaming multiprocessors (SMs) - and dynamically schedule warps onto warp execution units.
Typically there are many more blocks than SMs, and many more warps than warp execution units. When a kernel is launched, the hardware resources used by each block are known; the most important ones are:
- Number of warps used per block
- Amount of shared memory used
- Number of registers used
For example, for an RTX 3090:
Device 0: NVIDIA GeForce RTX 3090
Compute capability: 8.6
Total global memory: 25327697920 bytes
Multi-processor count: 82
Max threads per block: 1024
Warp size: 32
Max threads dimensions: (1024, 1024, 64)
Max grid size: (2147483647, 65535, 65535)
Memory clock rate (kHz): 9751000
Memory bus width (bits): 384
L2 cache size: 6291456
Concurrent kernels: 1
Cooperative launch: Supported
Shared memory per block: 49152 bytes
Registers per block: 65536
Max shared memory per multiprocessor: 102400 bytes
Max threads per multiprocessor: 1536
Max registers per multiprocessor: 65536
Max registers per thread: 255
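For reference, a dump like the one above can be produced with cudaGetDeviceProperties; a minimal sketch (the printed subset is my choice, not necessarily the exact program that generated the output above):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    printf("Device 0: %s\n", prop.name);
    printf("Compute capability: %d.%d\n", prop.major, prop.minor);
    printf("Multi-processor count: %d\n", prop.multiProcessorCount);
    printf("Max threads per block: %d\n", prop.maxThreadsPerBlock);
    printf("Shared memory per block: %zu bytes\n", prop.sharedMemPerBlock);
    printf("Registers per block: %d\n", prop.regsPerBlock);
    printf("Max shared memory per multiprocessor: %zu bytes\n", prop.sharedMemPerMultiprocessor);
    printf("Max threads per multiprocessor: %d\n", prop.maxThreadsPerMultiProcessor);
    return 0;
}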
As this paper suggests, the block scheduler uses a best-fit algorithm: it assigns each block to the least resource-constrained SM, and breaks ties in a pre-determined order (which seems to be generation-specific).
Regardless, my working model for thinking about the block scheduler is as follows:
- launch kernel K, which uses some amount of registers, shared memory and threads.
- the block scheduler will place blocks onto SMs that fit, until there are no more blocks.
- each SM can execute multiple blocks concurrently; within each SM, the warp scheduler decides which non-stalled warp to execute next. Warps can stall while waiting on memory accesses or on various compute pipelines.
- after all the blocks of the current kernel finish, the next kernel in the stream can start.
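The "fit" in the second step can be queried directly: given a kernel's register, shared memory, and thread usage, the occupancy API reports how many of its blocks can be resident on one SM. A minimal sketch, with myKernel as a placeholder:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void myKernel(float* data) { /* placeholder */ }

int main() {
    int blockSize = 256;
    int blocksPerSm = 0;
    // How many blocks of myKernel fit on a single SM, given its
    // register, shared memory, and thread usage.
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocksPerSm, myKernel, blockSize, /*dynamicSMemSize=*/0);
    printf("Resident blocks per SM: %d\n", blocksPerSm);
    return 0;
}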
This dynamism on the part of the block and warp schedulers is good for overlapping compute with communication, and for more fully utilizing all available hardware resources. It can hide a "naive" kernel's performance issues through concurrency at the warp scheduler level: multiple warps executing on the same SM can hide each other's latency, for example, by taking turns executing on different units.
One can demonstrate this with a toy example that simulates a memory-bound workload. The workload is:
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x; // total threads in the grid
for (int s = 0; s < stages; s++) {
    // grid-stride loop: each thread handles every stride-th element
    for (int i = tid; i < N; i += stride) {
        b[i] = elementwise(a[i]);
    }
    swap(a, b); // ping-pong the buffers for the next stage
}
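Note that the swap is only safe if every block finishes stage s before any block starts stage s+1. This requires either a host-side loop launching one kernel per stage, or a single cooperative kernel with a grid-wide barrier; I take the latter to be the "grid sync" variant in the results below. A sketch of the cooperative version, with elementwise as a placeholder op:

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__device__ float elementwise(float x) { return 2.0f * x + 1.0f; } // placeholder

__global__ void stagedKernel(float* a, float* b, int N, int stages) {
    cg::grid_group grid = cg::this_grid();
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    int stride = blockDim.x * gridDim.x;
    for (int s = 0; s < stages; s++) {
        for (int i = tid; i < N; i += stride) {
            b[i] = elementwise(a[i]);
        }
        grid.sync();                 // all blocks finish stage s before s+1 begins
        float* t = a; a = b; b = t;  // ping-pong the buffers
    }
}
// Must be launched with cudaLaunchCooperativeKernel, and the whole grid must be
// co-resident on the device - which is exactly where blocks-per-SM limits bite.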
One can choose to partition this workload among the blocks in many ways, as sketched below. One way is to launch as many blocks as possible, nblocks = N / nThreads, which produces many more blocks than SMs and leaves it up to the block scheduler to distribute them among the available SMs. Another way is to place k blocks on each SM, and divide the work N among k * numSMs blocks.
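Concretely, the two strategies might be set up as follows (nThreads and k are free parameters; the values here are illustrative):

int nThreads = 256;

// Strategy 1: as many blocks as there are tiles of nThreads elements.
// Far more blocks than SMs; the block scheduler distributes them.
int nblocksMax = (N + nThreads - 1) / nThreads;

// Strategy 2: k resident blocks per SM; each block covers its share of N
// via a grid-stride loop.
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
int k = 2;
int nblocksPersistent = k * prop.multiProcessorCount;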
In the following experiments, seq-naive refers to a naive kernel much like the pseudocode above: a straight load, compute, and store. seq refers to a slightly optimized version that pipelines the memory accesses to reduce memory-access dependencies, and thus warp stalling.
Kernel     stages  blocks_per_sm  time (ms)
seq-naive  256     1              36.562
seq        256     1              20.020
grid sync  256     1              20.558
seq-naive  256     2              24.090
seq        256     2              20.026
grid sync  256     2              20.249
seq-naive  256     4              20.099
seq        256     4              19.971
grid sync  256     4              20.247
One can see that the naive kernel's performance improves drastically as the number of concurrent blocks per SM increases, from 36.6 ms to 24.1 ms to 20.1 ms, while the optimized version barely changes. The naive kernel reaches only about 50% of the device's peak memory bandwidth (at one block per SM), while the slightly optimized version reaches about 90%. So it seems that occupancy, the number of blocks the SMs can accommodate, matters only insofar as it helps the kernel make effective use of device resources, be it the memory, scalar math, or tensor pipelines.
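For intuition, the difference between the two variants might look like the sketch below; this is my reconstruction rather than the exact benchmarked code. The pipelined version issues the next load before consuming the current value, so each warp spends less time stalled on back-to-back dependent loads:

__device__ float elementwise(float x) { return 2.0f * x + 1.0f; } // placeholder

// seq-naive: each iteration stalls on its own load before it can compute.
__global__ void seqNaive(const float* a, float* b, int N) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    int stride = blockDim.x * gridDim.x;
    for (int i = tid; i < N; i += stride)
        b[i] = elementwise(a[i]);
}

// seq: prefetch the next element so the load latency overlaps the compute.
__global__ void seqPipelined(const float* a, float* b, int N) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    int stride = blockDim.x * gridDim.x;
    float cur = (tid < N) ? a[tid] : 0.0f;
    for (int i = tid; i < N; i += stride) {
        int next = i + stride;
        float nxt = (next < N) ? a[next] : 0.0f; // issue next load early
        b[i] = elementwise(cur);                 // compute while the load is in flight
        cur = nxt;
    }
}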
The paper also goes into some detail about the block scheduler in a multi-stream context. There are interesting contention effects when two kernels run concurrently and the block scheduler places blocks from both kernels on the same SMs, resulting in slowdowns as warps contend for the same hardware pipelines. Another interesting detail is that the block scheduler sometimes behaves suboptimally in non-obvious ways: a minor change in launch configuration can lead to a poor assignment of blocks to SMs, with non-negligible performance degradation.
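To observe this kind of contention first-hand, one can make two kernels eligible to run concurrently by launching them on separate streams; a minimal sketch, with kernelA and kernelB as placeholders:

cudaStream_t s1, s2;
cudaStreamCreate(&s1);
cudaStreamCreate(&s2);
// Both kernels are now eligible to run concurrently; whether their blocks
// end up sharing SMs is up to the block scheduler.
kernelA<<<gridA, blockA, 0, s1>>>(/* args */);
kernelB<<<gridB, blockB, 0, s2>>>(/* args */);
cudaStreamSynchronize(s1);
cudaStreamSynchronize(s2);
cudaStreamDestroy(s1);
cudaStreamDestroy(s2);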
I think there's quite a lot of work on optimizing single-kernel performance, but not enough ML-systems research on the orchestration of kernels as a whole. It could well be that a kernel performs optimally and approaches the roofline when considered in isolation, yet such kernels (I'm thinking of matrix multiplication and flash attention) are very resource-intensive, and can slow down the system as a whole once other kernels are factored into the equation.