diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip index d7633521d8eef..1907ffe875d9f 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip @@ -7,8 +7,110 @@ #include #include +#include +#include +#include + namespace pytorch_flash { +// SFINAE for newer composable_kernel `fmha_bwd.hpp` vs older CK (see mha_fwd_ck.hip). +// `fmha_bwd_args` / `aiter::mha_bwd_args` may gain `sink_ptr` / `d_sink_ptr` after `dq_acc_ptr`. +// Same pattern as `mha_fwd_ck.hip`: value-init, assign fields, then `set_fmha_bwd_sink_ptr_fields`. +template +struct has_fmha_bwd_args_sink_fields : std::false_type {}; + +template +struct has_fmha_bwd_args_sink_fields().sink_ptr)>> + : std::true_type {}; + +template +void set_fmha_bwd_sink_ptr_fields([[maybe_unused]] Args &args) +{ + if constexpr(has_fmha_bwd_args_sink_fields::value) { + args.sink_ptr = nullptr; + args.d_sink_ptr = nullptr; + } +} + +// SFINAE for an older composable_kernel `fmha_bwd.hpp` snapshot (see mha_fwd_ck.hip). +// In older CK, `aiter::mha_bwd_args` carried a `ck_mask_type` field in its `fmha_bwd_traits` +// block (alongside the aiter-level `mask_type`); newer CK renamed/collapsed it into `mask_type`. +template +struct has_fmha_bwd_args_ck_mask_type : std::false_type {}; + +template +struct has_fmha_bwd_args_ck_mask_type().ck_mask_type)>> + : std::true_type {}; + +template +void set_fmha_bwd_ck_mask_type([[maybe_unused]] Args &args, [[maybe_unused]] int ck_mask_type) +{ + if constexpr(has_fmha_bwd_args_ck_mask_type::value) { + args.ck_mask_type = ck_mask_type; + } +} + +// SFINAE for the dq accumulation interface. Older `aiter::mha_bwd_args` consumes an +// explicit dq_acc buffer (pointer + strides); newer AITER removed these fields and +// instead lets the kernel manage scratch internally via a `workspace_alloc` callback. +// Exactly one regime is present in a given submodule snapshot, so gate both. +template +struct has_fmha_bwd_args_dq_acc : std::false_type {}; + +template +struct has_fmha_bwd_args_dq_acc().dq_acc_ptr)>> + : std::true_type {}; + +template +void set_fmha_bwd_dq_acc_fields([[maybe_unused]] Args &args, + [[maybe_unused]] void *dq_acc_ptr, + [[maybe_unused]] int64_t stride_dq_acc, + [[maybe_unused]] int64_t nhead_stride_dq_acc, + [[maybe_unused]] int64_t batch_stride_dq_acc, + [[maybe_unused]] int64_t split_stride_dq_acc) +{ + if constexpr(has_fmha_bwd_args_dq_acc::value) { + args.dq_acc_ptr = dq_acc_ptr; + args.stride_dq_acc = stride_dq_acc; + args.nhead_stride_dq_acc = nhead_stride_dq_acc; + args.batch_stride_dq_acc = batch_stride_dq_acc; + args.split_stride_dq_acc = split_stride_dq_acc; + } +} + +template +struct has_fmha_bwd_args_workspace_alloc : std::false_type {}; + +template +struct has_fmha_bwd_args_workspace_alloc().workspace_alloc)>> + : std::true_type {}; + +template +void set_fmha_bwd_workspace_alloc([[maybe_unused]] Args &args, [[maybe_unused]] Fn &&workspace_alloc) +{ + if constexpr(has_fmha_bwd_args_workspace_alloc::value) { + args.workspace_alloc = std::forward(workspace_alloc); + } +} + +// Even in batch mode the CK backward launcher stages dq-accumulation metadata through a +// host workspace (host_ws_size_ is padded to 4 KiB for the non-QrQtrDor pipeline), so +// prepare_workspace_async requires a pinned-host allocator. Gate it like workspace_alloc. +template +struct has_fmha_bwd_args_pinned_host_alloc : std::false_type {}; + +template +struct has_fmha_bwd_args_pinned_host_alloc().pinned_host_alloc)>> + : std::true_type {}; + +template +void set_fmha_bwd_pinned_host_alloc([[maybe_unused]] Args &args, [[maybe_unused]] Fn &&pinned_host_alloc) +{ + if constexpr(has_fmha_bwd_args_pinned_host_alloc::value) { + args.pinned_host_alloc = std::forward(pinned_host_alloc); + } +} + aiter::mha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, std::string dtype, bool has_dropout, @@ -86,12 +188,6 @@ aiter::mha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, ck_tile::index_t stride_dv = dv.stride(1); ck_tile::index_t nhead_stride_dv = dv.stride(2); - // dq_acc: (split, batch_size, nheads, seqlen_q, hdim) - ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0); - ck_tile::long_index_t batch_stride_dq_acc = dq_acc.stride(1); - ck_tile::index_t stride_dq_acc = dq_acc.stride(3); - ck_tile::long_index_t nhead_stride_dq_acc = dq_acc.stride(2); - // bias: (batch_size, nheads, seqlen_q, seqlen_k) void *attn_bias_ptr = nullptr; ck_tile::index_t nhead_stride_bias = 0; @@ -144,80 +240,88 @@ aiter::mha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, false, // is_store_randval deterministic, // is_deterministic - // From ck fmha_bwd_args - q.data_ptr(), - k.data_ptr(), - v.data_ptr(), - attn_bias_ptr, - out.data_ptr(), // o_ptr - softmax_lse.data_ptr(), // lse_ptr - dout.data_ptr(), // do_ptr - d.data_ptr(), - nullptr, // rand_val_ptr - dq.data_ptr(), - dk.data_ptr(), - dv.data_ptr(), - dbias_ptr, - dq_acc.data_ptr(), // dq_acc_ptr - nullptr, // seqstart_q_ptr - nullptr, // seqstart_k_ptr - nullptr, // seqlen_q_ptr - nullptr, // seqlen_k_ptr - nullptr, // cu_seqlen_q_ptr - nullptr, // cu_seqlen_k_ptr - seqlen_q, - seqlen_k, - b, // batch - seqlen_q, // max_seqlen_q - seqlen_k, // max_seqlen_k - h, // nhead_q - h_k, // nhead_k - softmax_scale, // scale - stride_q, - stride_k, - stride_v, - stride_attn_bias, // stride_bias - stride_o, - 0, // stride_randval - stride_do, - stride_dq_acc, - stride_dq, - stride_dk, - stride_dv, - stride_dbias, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_o, - 0, // nhead_stride_randval - nhead_stride_do, - nhead_stride_lse, - nhead_stride_dq_acc, - nhead_stride_dq, - nhead_stride_dk, - nhead_stride_dv, - nhead_stride_dbias, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_o, - 0, // batch_stride_randval - batch_stride_do, - batch_stride_lse, - batch_stride_dq_acc, - batch_stride_dq, - batch_stride_dk, - batch_stride_dv, - batch_stride_dbias, - split_stride_dq_acc, - mask.left, // window_size_left - mask.right, // window_size_right - p_dropout, // p_drop - p_undrop, - drop_seed_offset - }; + args.q_ptr = q.data_ptr(); + args.k_ptr = k.data_ptr(); + args.v_ptr = v.data_ptr(); + args.bias_ptr = attn_bias_ptr; + args.o_ptr = out.data_ptr(); + args.lse_ptr = softmax_lse.data_ptr(); + args.do_ptr = dout.data_ptr(); + args.d_ptr = d.data_ptr(); + args.rand_val_ptr = nullptr; + args.dq_ptr = dq.data_ptr(); + args.dk_ptr = dk.data_ptr(); + args.dv_ptr = dv.data_ptr(); + args.dbias_ptr = dbias_ptr; + set_fmha_bwd_sink_ptr_fields(args); + args.seqstart_q_ptr = nullptr; + args.seqstart_k_ptr = nullptr; + args.seqlen_q_ptr = nullptr; + args.seqlen_k_ptr = nullptr; + args.cu_seqlen_q_ptr = nullptr; + args.cu_seqlen_k_ptr = nullptr; + args.seqlen_q = seqlen_q; + args.seqlen_k = seqlen_k; + args.batch = b; + args.max_seqlen_q = seqlen_q; + args.max_seqlen_k = seqlen_k; + args.nhead_q = h; + args.nhead_k = h_k; + args.scale = softmax_scale; + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.stride_bias = stride_attn_bias; + args.stride_o = stride_o; + args.stride_randval = 0; + args.stride_do = stride_do; + args.stride_dq = stride_dq; + args.stride_dk = stride_dk; + args.stride_dv = stride_dv; + args.stride_dbias = stride_dbias; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.nhead_stride_bias = nhead_stride_bias; + args.nhead_stride_o = nhead_stride_o; + args.nhead_stride_randval = 0; + args.nhead_stride_do = nhead_stride_do; + args.nhead_stride_lsed = nhead_stride_lse; + args.nhead_stride_dq = nhead_stride_dq; + args.nhead_stride_dk = nhead_stride_dk; + args.nhead_stride_dv = nhead_stride_dv; + args.nhead_stride_dbias = nhead_stride_dbias; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + args.batch_stride_bias = batch_stride_bias; + args.batch_stride_o = batch_stride_o; + args.batch_stride_randval = 0; + args.batch_stride_do = batch_stride_do; + args.batch_stride_lsed = batch_stride_lse; + args.batch_stride_dq = batch_stride_dq; + args.batch_stride_dk = batch_stride_dk; + args.batch_stride_dv = batch_stride_dv; + args.batch_stride_dbias = batch_stride_dbias; + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.p_drop = p_dropout; + args.p_undrop = p_undrop; + args.drop_seed_offset = drop_seed_offset; + + // Older AITER consumes an explicit dq accumulation buffer; gate so this compiles + // against newer AITER where the kernel manages dq scratch via workspace_alloc. + if constexpr(has_fmha_bwd_args_dq_acc::value) { + // dq_acc: (split, batch_size, nheads, seqlen_q, hdim) + set_fmha_bwd_dq_acc_fields( + args, + dq_acc.data_ptr(), + dq_acc.stride(3), // stride_dq_acc + dq_acc.stride(2), // nhead_stride_dq_acc + dq_acc.stride(1), // batch_stride_dq_acc + dq_acc.stride(0)); // split_stride_dq_acc + } + return args; } std::tuple @@ -358,12 +462,16 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor dq_accum; - if (!deterministic) { - dq_accum = at::zeros({1, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat)); - } else { - const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64; - const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0); - dq_accum = at::zeros({nsplits, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat)); + // Newer AITER manages dq scratch internally (via workspace_alloc); only allocate the + // explicit dq accumulation buffer when the struct still exposes the dq_acc fields. + if constexpr(has_fmha_bwd_args_dq_acc::value) { + if (!deterministic) { + dq_accum = at::zeros({1, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat)); + } else { + const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64; + const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0); + dq_accum = at::zeros({nsplits, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat)); + } } at::Tensor dk_expanded, dv_expanded; @@ -421,6 +529,29 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x p_dropout, drop_seed_offset); + // Device scratch allocator for newer AITER (kernel-managed dq workspace). The + // returned pointer must stay valid until aiter::mha_bwd returns, so keep the + // backing tensor alive in this scope. No-op against older AITER (field absent). + at::Tensor workspace; + auto workspace_alloc = [&workspace, opts](size_t bytes, bool zero_init) -> void* { + workspace = zero_init + ? at::zeros({static_cast(bytes)}, opts.dtype(at::kByte)) + : at::empty({static_cast(bytes)}, opts.dtype(at::kByte)); + return workspace.data_ptr(); + }; + // Required even in batch mode: the launcher stages dq-accumulation metadata via a + // host workspace and H2D-copies it, so prepare_workspace_async needs a pinned + // buffer. The returned shared_ptr owns the lifetime (aiter extends it across the + // stream). No-op against older AITER (field absent). + auto pinned_host_alloc = [](size_t bytes) -> std::shared_ptr { + auto t = std::make_shared(at::empty( + {static_cast(bytes)}, + at::TensorOptions().dtype(at::kByte).pinned_memory(true))); + return std::shared_ptr(t, t->data_ptr()); + }; + set_fmha_bwd_workspace_alloc(args, workspace_alloc); + set_fmha_bwd_pinned_host_alloc(args, pinned_host_alloc); + float t = aiter::mha_bwd(args, stream_config); TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip index 896afc7320e96..ddbf8408d2c65 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip @@ -3,34 +3,117 @@ ******************************************************************************/ #include +#include #include #include +#include +#include +#include namespace pytorch_flash { +// SFINAE for version-gated `aiter::mha_bwd_args` members (mirrors mha_bwd_ck.hip, the +// batch-mode counterpart). The struct evolved across AITER snapshots: +// - `sink_ptr` / `d_sink_ptr` were added, +// - older CK carried a separate `ck_mask_type`, +// - the explicit dq accumulation buffer (`dq_acc_ptr` + strides) was replaced by +// kernel-managed scratch via `workspace_alloc` / `pinned_host_alloc` callbacks. +// Value-init the struct, assign the always-present fields, then route the optional +// members through these helpers so a single source compiles against every snapshot. +template +struct has_fmha_bwd_args_sink_fields : std::false_type {}; + +template +struct has_fmha_bwd_args_sink_fields().sink_ptr)>> + : std::true_type {}; + +template +void set_fmha_bwd_sink_ptr_fields([[maybe_unused]] Args &args) +{ + if constexpr(has_fmha_bwd_args_sink_fields::value) { + args.sink_ptr = nullptr; + args.d_sink_ptr = nullptr; + } +} + +template +struct has_fmha_bwd_args_ck_mask_type : std::false_type {}; + +template +struct has_fmha_bwd_args_ck_mask_type().ck_mask_type)>> + : std::true_type {}; + +template +void set_fmha_bwd_ck_mask_type([[maybe_unused]] Args &args, [[maybe_unused]] int ck_mask_type) +{ + if constexpr(has_fmha_bwd_args_ck_mask_type::value) { + args.ck_mask_type = ck_mask_type; + } +>>>>>>> 0f509fb4651 (Add new workspace args) +} + +// Older `aiter::mha_bwd_args` consumes an explicit dq_acc buffer (pointer + strides); +// newer AITER removed these and lets the kernel manage scratch via `workspace_alloc` +// (+ `pinned_host_alloc`, required in group mode). Exactly one regime is present. +template +struct has_fmha_bwd_args_dq_acc : std::false_type {}; + +template +struct has_fmha_bwd_args_dq_acc().dq_acc_ptr)>> + : std::true_type {}; + +template +void set_fmha_bwd_dq_acc_fields([[maybe_unused]] Args &args, + [[maybe_unused]] void *dq_acc_ptr, + [[maybe_unused]] int64_t stride_dq_acc, + [[maybe_unused]] int64_t nhead_stride_dq_acc, + [[maybe_unused]] int64_t batch_stride_dq_acc, + [[maybe_unused]] int64_t split_stride_dq_acc) +{ + if constexpr(has_fmha_bwd_args_dq_acc::value) { + args.dq_acc_ptr = dq_acc_ptr; + args.stride_dq_acc = stride_dq_acc; + args.nhead_stride_dq_acc = nhead_stride_dq_acc; + args.batch_stride_dq_acc = batch_stride_dq_acc; + args.split_stride_dq_acc = split_stride_dq_acc; + } +} -fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask, - std::string dtype, - int head_size, - bool has_dropout, - bool enable_bias, - bool deterministic, - bool bias_requires_grad) +template +struct has_fmha_bwd_args_workspace_alloc : std::false_type {}; + +template +struct has_fmha_bwd_args_workspace_alloc().workspace_alloc)>> + : std::true_type {}; + +template +void set_fmha_bwd_workspace_alloc([[maybe_unused]] Args &args, [[maybe_unused]] Fn &&workspace_alloc) { - return fmha_bwd_traits{head_size, - head_size, - dtype, - true, // is_group_mode - mask.type, - enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias, - bias_requires_grad, // has_dbias - has_dropout, - false, // s_randval - deterministic}; + if constexpr(has_fmha_bwd_args_workspace_alloc::value) { + args.workspace_alloc = std::forward(workspace_alloc); + } } -fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, +template +struct has_fmha_bwd_args_pinned_host_alloc : std::false_type {}; + +template +struct has_fmha_bwd_args_pinned_host_alloc().pinned_host_alloc)>> + : std::true_type {}; + +template +void set_fmha_bwd_pinned_host_alloc([[maybe_unused]] Args &args, [[maybe_unused]] Fn &&pinned_host_alloc) +{ + if constexpr(has_fmha_bwd_args_pinned_host_alloc::value) { + args.pinned_host_alloc = std::forward(pinned_host_alloc); + } +} + +aiter::mha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, + std::string dtype, + bool has_dropout, + bool deterministic, // sizes const int b, const int max_seqlen_q, @@ -110,12 +193,6 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, ck_tile::index_t stride_dv = dv.stride(0); ck_tile::index_t nhead_stride_dv = dv.stride(1); - // dq_acc: (split, total_q, nheads, hdim) - ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0); - ck_tile::long_index_t batch_stride_dq_acc = 0; - ck_tile::index_t stride_dq_acc = dq_acc.stride(1); - ck_tile::long_index_t nhead_stride_dq_acc = dq_acc.stride(2); - float p_undrop = 1.0 - p_dropout; // bias: (batch_size, nheads, seqlen_q, seqlen_k) @@ -148,81 +225,110 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, batch_stride_dbias = dbias.stride(0); } - return fmha_bwd_args{q.data_ptr(), - k.data_ptr(), - v.data_ptr(), - attn_bias_ptr, // bias - out.data_ptr(), - softmax_lse.data_ptr(), - dout.data_ptr(), - d.data_ptr(), - nullptr, // rand_val - dq.data_ptr(), - dk.data_ptr(), - dv.data_ptr(), - dbias_ptr, // dbias - dq_acc.data_ptr(), // dq_acc - seqlens_q.data_ptr(), // seqstart_q - seqlens_k.data_ptr(), // seqstart_k - nullptr, // seqlen_q_ptr - nullptr, // seqlen_k_ptr - nullptr, // cu_seqlen_q_ptr - nullptr, // cu_seqlen_k_ptr - total_q, - total_k, - b, - max_seqlen_q, // max_seqlen_q - max_seqlen_k, // max_seqlen_k - hdim, // hdim_q - hdim, // hdim_v - h, // nhead - h_k, // nhead_k - softmax_scale, - stride_q, - stride_k, - stride_v, - stride_attn_bias, - stride_o, - 0, // stride_randval - stride_do, - stride_dq_acc, - stride_dq, - stride_dk, - stride_dv, - stride_dbias, // stride_dbias - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, // nhead_stride_bias - nhead_stride_o, - 0, // nhead_stride_randval - nhead_stride_do, - nhead_stride_lse, - nhead_stride_dq_acc, - nhead_stride_dq, - nhead_stride_dk, - nhead_stride_dv, - nhead_stride_dbias, // nhead_stride_dbias - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, // batch_stride_bias - batch_stride_o, - 0, // batch_stride_randval - batch_stride_do, - batch_stride_lse, - batch_stride_dq_acc, - batch_stride_dq, - batch_stride_dk, - batch_stride_dv, - batch_stride_dbias, // batch_stride_dbias - split_stride_dq_acc, - mask.left, - mask.right, - static_cast(mask.type), - p_dropout, - p_undrop, - drop_seed_offset}; + bool enable_bias = attn_bias_.has_value(); + + aiter::mha_bwd_args args{}; + args.use_asm_v3 = hdim <= 192; // ASM v3 only supports head dim <= 192 + args.v3_atomic_fp32 = true; + args.v3_bf16_cvt = 1; + args.v3_api_check = false; + + args.hdim_q = hdim; + args.hdim_v = hdim; + args.data_type = dtype; + args.is_group_mode = true; + args.mask_type = static_cast(mask.type); + set_fmha_bwd_ck_mask_type(args, static_cast(mask.type)); + args.bias_type = + enable_bias ? static_cast(bias_enum::elementwise_bias) + : static_cast(bias_enum::no_bias); + args.has_dbias = bias_requires_grad; + args.has_dropout = has_dropout; + args.is_store_randval = false; + args.is_deterministic = deterministic; + + args.q_ptr = q.data_ptr(); + args.k_ptr = k.data_ptr(); + args.v_ptr = v.data_ptr(); + args.bias_ptr = attn_bias_ptr; + args.o_ptr = out.data_ptr(); + args.lse_ptr = softmax_lse.data_ptr(); + args.do_ptr = dout.data_ptr(); + args.d_ptr = d.data_ptr(); + args.rand_val_ptr = nullptr; + args.dq_ptr = dq.data_ptr(); + args.dk_ptr = dk.data_ptr(); + args.dv_ptr = dv.data_ptr(); + args.dbias_ptr = dbias_ptr; + set_fmha_bwd_sink_ptr_fields(args); + args.seqstart_q_ptr = seqlens_q.data_ptr(); + args.seqstart_k_ptr = seqlens_k.data_ptr(); + args.seqlen_q_ptr = nullptr; + args.seqlen_k_ptr = nullptr; + args.cu_seqlen_q_ptr = nullptr; + args.cu_seqlen_k_ptr = nullptr; + args.seqlen_q = total_q; + args.seqlen_k = total_k; + args.batch = b; + args.max_seqlen_q = max_seqlen_q; + args.max_seqlen_k = max_seqlen_k; + args.nhead_q = h; + args.nhead_k = h_k; + args.scale = softmax_scale; + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.stride_bias = stride_attn_bias; + args.stride_o = stride_o; + args.stride_randval = 0; + args.stride_do = stride_do; + args.stride_dq = stride_dq; + args.stride_dk = stride_dk; + args.stride_dv = stride_dv; + args.stride_dbias = stride_dbias; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.nhead_stride_bias = nhead_stride_bias; + args.nhead_stride_o = nhead_stride_o; + args.nhead_stride_randval = 0; + args.nhead_stride_do = nhead_stride_do; + args.nhead_stride_lsed = nhead_stride_lse; + args.nhead_stride_dq = nhead_stride_dq; + args.nhead_stride_dk = nhead_stride_dk; + args.nhead_stride_dv = nhead_stride_dv; + args.nhead_stride_dbias = nhead_stride_dbias; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + args.batch_stride_bias = batch_stride_bias; + args.batch_stride_o = batch_stride_o; + args.batch_stride_randval = 0; + args.batch_stride_do = batch_stride_do; + args.batch_stride_lsed = batch_stride_lse; + args.batch_stride_dq = batch_stride_dq; + args.batch_stride_dk = batch_stride_dk; + args.batch_stride_dv = batch_stride_dv; + args.batch_stride_dbias = batch_stride_dbias; + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.p_drop = p_dropout; + args.p_undrop = p_undrop; + args.drop_seed_offset = drop_seed_offset; + + // Older AITER consumes an explicit dq accumulation buffer; gate so this compiles + // against newer AITER where the kernel manages dq scratch via workspace_alloc. + if constexpr(has_fmha_bwd_args_dq_acc::value) { + // dq_acc: (split, total_q, nheads, hdim) + set_fmha_bwd_dq_acc_fields( + args, + dq_acc.data_ptr(), + dq_acc.stride(1), // stride_dq_acc + dq_acc.stride(2), // nhead_stride_dq_acc + 0, // batch_stride_dq_acc (varlen) + dq_acc.stride(0)); // split_stride_dq_acc + } + return args; } std::tuple @@ -375,12 +481,16 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); at::Tensor dq_accum; - if (!deterministic) { - dq_accum = at::zeros({1, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); - } else { - const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64; - const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0); - dq_accum = at::zeros({nsplits, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + // Newer AITER manages dq scratch internally (via workspace_alloc); only allocate the + // explicit dq accumulation buffer when the struct still exposes the dq_acc fields. + if constexpr(has_fmha_bwd_args_dq_acc::value) { + if (!deterministic) { + dq_accum = at::zeros({1, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } else { + const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64; + const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0); + dq_accum = at::zeros({nsplits, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } } at::Tensor dk_expanded, dv_expanded; @@ -410,18 +520,13 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea if (max_seqlen_q > 0) { ck_tile::stream_config stream_config{stream}; dq.zero_(); // ck use atomic operation on dq - auto traits = - get_ck_fmha_varlen_bwd_traits(mask, - q_dtype_str, - head_size_8x, - is_dropout, - attn_bias_.has_value(), - deterministic, - bias_requires_grad); auto args = get_ck_fmha_varlen_bwd_args( mask, + q_dtype_str, + is_dropout, + deterministic, batch_size, max_seqlen_q, max_seqlen_k, @@ -447,7 +552,29 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea softmax_scale, p_dropout, drop_seed_offset); - float t = fmha_bwd(traits, args, stream_config); + + // Group mode requires both callbacks on newer AITER (kernel-managed dq scratch + // and pinned-host staging for the async metadata pipeline). The device workspace + // pointer must stay valid until aiter::mha_bwd returns, so keep its backing + // tensor alive in this scope; the pinned buffer's lifetime is owned by the + // returned shared_ptr (aiter extends it across the stream). No-ops on older AITER. + at::Tensor workspace; + auto workspace_alloc = [&workspace, opts](size_t bytes, bool zero_init) -> void* { + workspace = zero_init + ? at::zeros({static_cast(bytes)}, opts.dtype(at::kByte)) + : at::empty({static_cast(bytes)}, opts.dtype(at::kByte)); + return workspace.data_ptr(); + }; + auto pinned_host_alloc = [](size_t bytes) -> std::shared_ptr { + auto t = std::make_shared(at::empty( + {static_cast(bytes)}, + at::TensorOptions().dtype(at::kByte).pinned_memory(true))); + return std::shared_ptr(t, t->data_ptr()); + }; + set_fmha_bwd_workspace_alloc(args, workspace_alloc); + set_fmha_bwd_pinned_host_alloc(args, pinned_host_alloc); + + float t = aiter::mha_bwd(args, stream_config); TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. diff --git a/third_party/aiter b/third_party/aiter index 9a469a608b2c1..971d98b8ed400 160000 --- a/third_party/aiter +++ b/third_party/aiter @@ -1 +1 @@ -Subproject commit 9a469a608b2c10b7157df573a38d31e5bf4038b4 +Subproject commit 971d98b8ed4003639486134a435df7e2ef9c0475 diff --git a/third_party/composable_kernel b/third_party/composable_kernel index 9f4f62da466d0..2c0b7cbb0a618 160000 --- a/third_party/composable_kernel +++ b/third_party/composable_kernel @@ -1 +1 @@ -Subproject commit 9f4f62da466d092a5e62a0af062cabf8a4577417 +Subproject commit 2c0b7cbb0a618ed52f7d2f5263baaffa640b6500