Abstract:
This technical presentation by Michael Lazos (Meta) details the integration of CUDA stream semantics into the torch.compile stack to enable high-performance asynchronous execution. Historically, PyTorch's compiler focused on single-stream optimizations; this work extends that capability by allowing the compiler to respect and optimize across multiple execution queues.
The implementation spans three critical layers: Torch Dynamo (symbolic stream tracking and graph annotation), AOT Autograd (parallelizing the backward pass and enforcing synchronization boundaries), and Torch Inductor (preventing invalid kernel fusions and ensuring cross-stream memory safety). Key technical hurdles addressed include the handling of non-tensor arguments through global object tables, the use of "control dependencies" to prevent the reordering of operations across event boundaries, and the management of input mutations during functionalization. Practical applications, such as microbatch communication/compute overlapping and activation offloading, demonstrate significant peak memory savings with minimal runtime overhead in large-scale transformer models.
Technical Synthesis: Asynchronous Execution in Torch.compile with CUDA Streams
- 0:00 Core Objectives: The initiative aims to enable state-of-the-art asynchronous execution within the torch.compile workflow, specifically utilizing CUDA streams for concurrent kernel execution and compute/communication overlapping.
- 0:51 Fundamentals of Streams and Events: Streams function as execution queues that allow concurrent operations, hiding of memory transfers behind compute, and cross-device synchronization. Synchronization is managed via "events" that record stream progress and block dependent streams until the recorded work is finalized (see the eager-mode sketch after this list).
- 2:11 Architecture Overview: The system integration involves three phases:
- Torch Dynamo: Tracks the current stream symbolically and annotates FX graph nodes.
- AOT Autograd: Generates a synchronized backward pass and preserves stream ordering.
- Torch Inductor: Handles code generation (codegen), restricts cross-stream fusions, and manages memory safety.
- 3:37 Stream Tracking in Dynamo: Dynamo utilizes a symbolic stack to match eager-mode context manager semantics (e.g., cuda.stream(s1)), ensuring that captured FX graph nodes are tagged with metadata for the correct stream index (see the tracing sketch after this list).
- 4:26 Handling Non-Tensor Graph Inputs: Because AOT Autograd natively only supports tensor arguments, streams and events are managed via a global object table. The compiler rewrites bytecode to look up these objects by index, avoiding a massive refactor of the autograd engine (a lookup sketch follows the list).
- 6:48 AOT Autograd Synchronization: The backward pass is designed to be faithful to eager-mode parallelization. Stream indices are propagated from forward nodes to their backward counterparts, and synchronization points (record/wait) are automatically inserted when a kernel consumes an argument produced on a different stream (a toy version of this pass is sketched after the list).
- 8:55 Preserving Execution Order: To prevent the compiler from reordering nodes across event boundaries (which would cause race conditions), "fake dependencies" are introduced using a control_deps operator. This explicitly links tensors to event records and waits in the graph IR (see the fake-dependency sketch after this list).
- 10:43 Functionalization and Mutation Challenges: Input mutations in torch.compile are typically moved to a "copy epilogue" at the end of the graph. If this move crosses a stream synchronization boundary, it can introduce data hazards, so the compiler currently raises an error for these specific cases to preserve correctness (an illustration follows the list).
- 13:47 Inductor Fusion and Memory Safety: Torch Inductor is modified to prohibit fusing kernels assigned to different streams. Additionally, the caching allocator's behavior is respected so that memory buffers are not reused until all side-stream operations on them have completed (see the record_stream sketch after this list).
- 17:12 Application: Microbatch Overlap: By running communication collectives (e.g., AllReduce, AllGather) on side streams while compute-intensive operations execute, the system hides communication latency and maximizes GPU utilization (an overlap sketch appears after this list).
- 18:07 Application: Activation Offloading: This technique hides the latency of moving activation tensors between GPU and CPU during the forward and backward passes. It enables significant peak-memory reduction, especially in large language models (LLMs), where the compute volume is sufficient to mask the memory-transfer time (an offload/reload sketch appears after this list).
- 19:19 Performance Metrics: Benchmarking on transformer architectures shows that as model size increases, the runtime overhead of activation offloading decreases (moving toward 0%) while memory savings increase (up to 30%+), provided there is enough compute to hide D2H/H2D transfers.
- 20:15 Availability: Support for device-agnostic streams (including AMD via torch.stream) is integrated into the PyTorch 2.12 release.
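
Illustrative Sketches: The code blocks below are reconstructions intended to illustrate the mechanisms summarized above; they are eager-mode analogues or toy models, not the actual torch.compile internals. First, the basic stream/event semantics from 0:51: work enqueued on a side stream can run concurrently with the default stream, and an event recorded after the producer blocks the consumer's stream until that work is done.

```python
import torch

# Two execution queues on the same device: kernels on s1 may run
# concurrently with kernels on the default stream.
s1 = torch.cuda.Stream()
ev = torch.cuda.Event()

x = torch.randn(1024, 1024, device="cuda")

# Make s1 wait for x to be produced on the default stream.
s1.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s1):
    y = x @ x       # enqueued on s1
    ev.record(s1)   # mark the point at which y is ready

# The default stream blocks (on the GPU, not the host) until the work
# recorded on s1 has finished, so this consumer cannot race the matmul.
torch.cuda.current_stream().wait_event(ev)
z = y + 1
```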
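A minimal sketch of the user-level pattern the talk says Dynamo now captures (3:37): the stream context manager is tracked on a symbolic stack rather than executed at trace time, and the operations traced inside it are associated with that stream. The exact metadata key attached to FX nodes is internal and not shown here.

```python
import torch

s1 = torch.cuda.Stream()

@torch.compile
def f(x):
    a = torch.relu(x)             # runs on the ambient (default) stream
    with torch.cuda.stream(s1):   # tracked via Dynamo's symbolic stream stack
        b = torch.sin(x)          # this node is associated with s1
    return a + b                  # consumer back on the default stream

out = f(torch.randn(64, 64, device="cuda"))
```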
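A hedged sketch of the global object table from 4:26. The names below (_SIDE_OBJECTS, register_object, lookup_object) are illustrative inventions; the real mechanism lives in rewritten bytecode, but the shape is the same: non-tensor stream/event objects are stashed in a process-wide table and referenced by integer index so they never have to be passed as graph inputs.

```python
import torch
from typing import Union

StreamOrEvent = Union[torch.cuda.Stream, torch.cuda.Event]

# Process-wide registry; compiled code refers to entries by index
# instead of taking stream/event objects as tensor-like arguments.
_SIDE_OBJECTS: dict[int, StreamOrEvent] = {}

def register_object(obj: StreamOrEvent) -> int:
    idx = len(_SIDE_OBJECTS)
    _SIDE_OBJECTS[idx] = obj
    return idx

def lookup_object(idx: int) -> StreamOrEvent:
    return _SIDE_OBJECTS[idx]
```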
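A toy model of the synchronization rule from 6:48: whenever a node consumes a value produced on a different stream, record an event after the producer and make the consumer's stream wait on it. The real pass operates on the FX graph produced by AOT Autograd; this version just emits a readable schedule.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    name: str
    stream: int                               # stream index assigned to the node
    inputs: list["Node"] = field(default_factory=list)

def insert_syncs(nodes: list[Node]) -> list[str]:
    schedule = []
    for n in nodes:
        for src in n.inputs:
            if src.stream != n.stream:
                # Producer and consumer live on different streams:
                # record after the producer, wait before the consumer.
                schedule.append(f"record_event(after={src.name}, on_stream={src.stream})")
                schedule.append(f"wait_event(before={n.name}, on_stream={n.stream})")
        schedule.append(f"launch({n.name}, stream={n.stream})")
    return schedule

# Example: a matmul on stream 1 feeding an add on stream 0.
mm = Node("matmul", stream=1)
add = Node("add", stream=0, inputs=[mm])
print("\n".join(insert_syncs([mm, add])))
```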
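A hedged sketch of the "fake dependency" idea from 8:55. The operator name control_deps comes from the talk, but this signature and body are assumptions: the op returns its value unchanged while taking the event-ordered tensors as extra inputs, so passes that only follow data edges cannot hoist the value across the record/wait it depends on.

```python
import torch

def control_deps(value: torch.Tensor, *deps: torch.Tensor) -> torch.Tensor:
    # The deps are consumed only to create graph edges; the data flowing
    # through `value` is returned untouched.
    return value

# Usage sketch: y may not be reordered before the tensor guarded by the
# event record/wait pair, even though y never reads marker's data.
marker = torch.zeros(1)
y = torch.randn(4)
y = control_deps(y, marker)
```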
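An illustration, under assumptions, of the 10:43 restriction. The talk does not spell out the exact triggering pattern; the sketch shows the general shape: an input mutated on a side stream whose copy-back would be deferred to the graph epilogue, past a stream synchronization boundary.

```python
import torch

s1 = torch.cuda.Stream()

@torch.compile
def step(buf: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    with torch.cuda.stream(s1):
        buf.add_(x)        # input mutation issued on the side stream
    return buf * 2         # consumer back on the default stream
# Functionalization would normally replay the mutation as a copy into
# `buf` at the end of the graph; per the talk, when that deferred copy
# would cross the record/wait boundary between s1 and the default
# stream, the compiler raises rather than risk a data hazard.
```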
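The eager-mode memory-safety rule Inductor has to respect (13:47), shown with the public Tensor.record_stream API: the caching allocator must not hand a buffer's memory to a new allocation while a side stream still has pending work on it.

```python
import torch

s1 = torch.cuda.Stream()

x = torch.empty(1 << 20, device="cuda")      # allocated on the default stream
s1.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s1):
    y = x * 2                                # x is consumed on the side stream

# Tell the caching allocator that x is in use on s1; even if x is freed
# on the default stream now, its block is not reused until s1 catches up.
x.record_stream(s1)
del x
```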
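A hedged eager-mode sketch of the microbatch overlap pattern from 17:12, assuming an initialized torch.distributed process group (e.g., NCCL). The collective for one microbatch is launched on a communication stream while compute for the next microbatch proceeds on the default stream; a stream wait re-joins the two before the reduced tensor is read.

```python
import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()

def overlapped_step(grad_bucket: torch.Tensor, x: torch.Tensor, w: torch.Tensor):
    # Make sure the collective sees the fully-produced gradients.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        work = dist.all_reduce(grad_bucket, async_op=True)
        grad_bucket.record_stream(comm_stream)

    y = x @ w  # compute for the next microbatch overlaps the all-reduce

    # Re-join before anything reads the reduced gradients.
    torch.cuda.current_stream().wait_stream(comm_stream)
    work.wait()
    return y, grad_bucket
```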
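A hedged sketch of the activation offloading pattern from 18:07: the D2H copy to pinned host memory runs on a side stream during the forward pass, and the H2D prefetch runs on the same stream ahead of the backward consumer, so both transfers hide behind compute. The helper names are illustrative, not the compiler's generated code.

```python
import torch

offload_stream = torch.cuda.Stream()

def offload(act: torch.Tensor) -> torch.Tensor:
    """Copy an activation to pinned host memory on the side stream (D2H)."""
    cpu_buf = torch.empty(act.shape, dtype=act.dtype, device="cpu", pin_memory=True)
    offload_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(offload_stream):
        cpu_buf.copy_(act, non_blocking=True)
        act.record_stream(offload_stream)   # keep act's memory alive for the copy
    return cpu_buf

def reload(cpu_buf: torch.Tensor) -> torch.Tensor:
    """Prefetch the activation back to the GPU (H2D) before backward needs it."""
    gpu = torch.empty(cpu_buf.shape, dtype=cpu_buf.dtype, device="cuda")
    with torch.cuda.stream(offload_stream):
        gpu.copy_(cpu_buf, non_blocking=True)
        gpu.record_stream(offload_stream)
    torch.cuda.current_stream().wait_stream(offload_stream)
    return gpu
```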