Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
303 changes: 217 additions & 86 deletions aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,110 @@
#include <fmha_bwd.hpp>
#include <mask.hpp>

#include <memory>
#include <type_traits>
#include <utility>

namespace pytorch_flash {

// SFINAE for newer composable_kernel `fmha_bwd.hpp` vs older CK (see mha_fwd_ck.hip).
// `fmha_bwd_args` / `aiter::mha_bwd_args` may gain `sink_ptr` / `d_sink_ptr` after `dq_acc_ptr`.
// Same pattern as `mha_fwd_ck.hip`: value-init, assign fields, then `set_fmha_bwd_sink_ptr_fields`.
/// Detects whether `T` exposes a data member named `sink_ptr`.
/// Primary template: chosen when the member-access expression below is ill-formed.
template <typename T, typename = void>
struct has_fmha_bwd_args_sink_fields : std::false_type {};

/// Specialization: chosen when `std::declval<T&>().sink_ptr` is a valid expression.
template <typename T>
struct has_fmha_bwd_args_sink_fields<T, std::void_t<decltype(std::declval<T&>().sink_ptr)>>
    : std::true_type {};

/// Detects whether `T` exposes a data member named `d_sink_ptr`.
/// Kept separate from the `sink_ptr` probe so that a struct carrying only one of the
/// two members neither fails to compile nor has its present member left untouched.
template <typename T, typename = void>
struct has_fmha_bwd_args_d_sink_field : std::false_type {};

template <typename T>
struct has_fmha_bwd_args_d_sink_field<T, std::void_t<decltype(std::declval<T&>().d_sink_ptr)>>
    : std::true_type {};

/// Nulls out the sink-related pointer members of `args`, if they exist.
///
/// @tparam Args  Argument struct type; may or may not declare `sink_ptr` / `d_sink_ptr`.
/// @param  args  Struct to update in place. Left untouched when it has neither member.
///
/// Each member is probed and assigned on its own: the assignments live in discarded
/// `if constexpr` branches, so a missing member is never named during instantiation.
template <typename Args>
void set_fmha_bwd_sink_ptr_fields([[maybe_unused]] Args &args)
{
    if constexpr(has_fmha_bwd_args_sink_fields<Args>::value) {
        args.sink_ptr = nullptr;
    }
    if constexpr(has_fmha_bwd_args_d_sink_field<Args>::value) {
        args.d_sink_ptr = nullptr;
    }
}

// SFINAE for an older composable_kernel `fmha_bwd.hpp` snapshot (see mha_fwd_ck.hip).
// In older CK, `aiter::mha_bwd_args` carried a `ck_mask_type` field in its `fmha_bwd_traits`
// block (alongside the aiter-level `mask_type`); newer CK renamed/collapsed it into `mask_type`.
/// Type of `T`'s `ck_mask_type` member; substitution fails when `T` has no such member.
template <typename T>
using fmha_bwd_args_ck_mask_type_member_t = decltype(std::declval<T&>().ck_mask_type);

/// Member-detection trait: `value` is true exactly when `T` has a `ck_mask_type` member.
/// The primary template is the "member absent" fallback.
template <typename T, typename = void>
struct has_fmha_bwd_args_ck_mask_type : std::false_type {};

/// Preferred whenever the member-type alias above can be formed for `T`.
template <typename T>
struct has_fmha_bwd_args_ck_mask_type<T, std::void_t<fmha_bwd_args_ck_mask_type_member_t<T>>>
    : std::true_type {};

/// Stores `ck_mask_type` into `args.ck_mask_type` when that member exists.
///
/// @tparam Args          Argument struct type, with or without a `ck_mask_type` member.
/// @param  args          Struct to update in place; unchanged if the member is absent.
/// @param  ck_mask_type  Value to store; ignored if the member is absent.
template <typename Args>
void set_fmha_bwd_ck_mask_type([[maybe_unused]] Args &args, [[maybe_unused]] int ck_mask_type)
{
    // Value-dependent constant: keeps the assignment in a discarded branch so it is
    // never instantiated for structs lacking the member.
    constexpr bool kHasCkMaskType = has_fmha_bwd_args_ck_mask_type<Args>::value;
    if constexpr(kHasCkMaskType) {
        args.ck_mask_type = ck_mask_type;
    }
}

// SFINAE for the dq accumulation interface. Older `aiter::mha_bwd_args` consumes an
// explicit dq_acc buffer (pointer + strides); newer AITER removed these fields and
// instead lets the kernel manage scratch internally via a `workspace_alloc` callback.
// Exactly one regime is present in a given submodule snapshot, so gate both.
/// Type of `T`'s `dq_acc_ptr` member; substitution fails when `T` has no such member.
template <typename T>
using fmha_bwd_args_dq_acc_ptr_member_t = decltype(std::declval<T&>().dq_acc_ptr);

/// Member-detection trait: `value` is true exactly when `T` has a `dq_acc_ptr` member.
/// Presence of `dq_acc_ptr` is used as the marker for the whole group of dq_acc
/// fields written by `set_fmha_bwd_dq_acc_fields`.
template <typename T, typename = void>
struct has_fmha_bwd_args_dq_acc : std::false_type {};

/// Preferred whenever the member-type alias above can be formed for `T`.
template <typename T>
struct has_fmha_bwd_args_dq_acc<T, std::void_t<fmha_bwd_args_dq_acc_ptr_member_t<T>>>
    : std::true_type {};

/// Copies the dq accumulation pointer and its four strides into `args`, provided
/// `Args` declares `dq_acc_ptr`; otherwise does nothing.
///
/// @tparam Args                 Argument struct type, with or without the dq_acc fields.
/// @param  args                 Struct to update in place.
/// @param  dq_acc_ptr           Stored into `args.dq_acc_ptr`.
/// @param  stride_dq_acc        Stored into `args.stride_dq_acc`.
/// @param  nhead_stride_dq_acc  Stored into `args.nhead_stride_dq_acc`.
/// @param  batch_stride_dq_acc  Stored into `args.batch_stride_dq_acc`.
/// @param  split_stride_dq_acc  Stored into `args.split_stride_dq_acc`.
///
/// Strides arrive as `int64_t` and are implicitly converted to whatever type the
/// destination member has.
template <typename Args>
void set_fmha_bwd_dq_acc_fields([[maybe_unused]] Args &args,
                                [[maybe_unused]] void *dq_acc_ptr,
                                [[maybe_unused]] int64_t stride_dq_acc,
                                [[maybe_unused]] int64_t nhead_stride_dq_acc,
                                [[maybe_unused]] int64_t batch_stride_dq_acc,
                                [[maybe_unused]] int64_t split_stride_dq_acc)
{
    // Value-dependent constant: the assignments below sit in a discarded branch and
    // are never instantiated for structs without the dq_acc fields.
    constexpr bool kHasDqAcc = has_fmha_bwd_args_dq_acc<Args>::value;
    if constexpr(kHasDqAcc) {
        // Buffer pointer first, then strides from outermost (split) to innermost (row).
        args.dq_acc_ptr          = dq_acc_ptr;
        args.split_stride_dq_acc = split_stride_dq_acc;
        args.batch_stride_dq_acc = batch_stride_dq_acc;
        args.nhead_stride_dq_acc = nhead_stride_dq_acc;
        args.stride_dq_acc       = stride_dq_acc;
    }
}

/// Type of `T`'s `workspace_alloc` member; substitution fails when `T` has no such member.
template <typename T>
using fmha_bwd_args_workspace_alloc_member_t = decltype(std::declval<T&>().workspace_alloc);

/// Member-detection trait: `value` is true exactly when `T` has a `workspace_alloc` member.
template <typename T, typename = void>
struct has_fmha_bwd_args_workspace_alloc : std::false_type {};

/// Preferred whenever the member-type alias above can be formed for `T`.
template <typename T>
struct has_fmha_bwd_args_workspace_alloc<T, std::void_t<fmha_bwd_args_workspace_alloc_member_t<T>>>
    : std::true_type {};

/// Installs `workspace_alloc` as `args.workspace_alloc` when that member exists.
///
/// @tparam Args             Argument struct type, with or without a `workspace_alloc` member.
/// @tparam Fn               Callable type; must be assignable to the member when present.
/// @param  args             Struct to update in place; unchanged if the member is absent.
/// @param  workspace_alloc  Callable to install. Perfect-forwarded, so an lvalue is
///                          copied and an rvalue is moved into the member.
template <typename Args, typename Fn>
void set_fmha_bwd_workspace_alloc([[maybe_unused]] Args &args, [[maybe_unused]] Fn &&workspace_alloc)
{
    // Value-dependent constant: the assignment is discarded (not instantiated) for
    // structs that lack the callback slot.
    constexpr bool kHasWorkspaceAlloc = has_fmha_bwd_args_workspace_alloc<Args>::value;
    if constexpr(kHasWorkspaceAlloc) {
        args.workspace_alloc = std::forward<Fn>(workspace_alloc);
    }
}

// Even in batch mode the CK backward launcher stages dq-accumulation metadata through a
// host workspace (host_ws_size_ is padded to 4 KiB for the non-QrQtrDor pipeline), so
// prepare_workspace_async requires a pinned-host allocator. Gate it like workspace_alloc.
/// Type of `T`'s `pinned_host_alloc` member; substitution fails when `T` has no such member.
template <typename T>
using fmha_bwd_args_pinned_host_alloc_member_t = decltype(std::declval<T&>().pinned_host_alloc);

/// Member-detection trait: `value` is true exactly when `T` has a `pinned_host_alloc` member.
template <typename T, typename = void>
struct has_fmha_bwd_args_pinned_host_alloc : std::false_type {};

/// Preferred whenever the member-type alias above can be formed for `T`.
template <typename T>
struct has_fmha_bwd_args_pinned_host_alloc<T, std::void_t<fmha_bwd_args_pinned_host_alloc_member_t<T>>>
    : std::true_type {};

/// Installs `pinned_host_alloc` as `args.pinned_host_alloc` when that member exists.
///
/// @tparam Args               Argument struct type, with or without a `pinned_host_alloc` member.
/// @tparam Fn                 Callable type; must be assignable to the member when present.
/// @param  args               Struct to update in place; unchanged if the member is absent.
/// @param  pinned_host_alloc  Callable to install. Perfect-forwarded, so an lvalue is
///                            copied and an rvalue is moved into the member.
template <typename Args, typename Fn>
void set_fmha_bwd_pinned_host_alloc([[maybe_unused]] Args &args, [[maybe_unused]] Fn &&pinned_host_alloc)
{
    // Value-dependent constant: the assignment is discarded (not instantiated) for
    // structs that lack the callback slot.
    constexpr bool kHasPinnedHostAlloc = has_fmha_bwd_args_pinned_host_alloc<Args>::value;
    if constexpr(kHasPinnedHostAlloc) {
        args.pinned_host_alloc = std::forward<Fn>(pinned_host_alloc);
    }
}

aiter::mha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
std::string dtype,
bool has_dropout,
Expand Down Expand Up @@ -86,12 +188,6 @@ aiter::mha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
ck_tile::index_t stride_dv = dv.stride(1);
ck_tile::index_t nhead_stride_dv = dv.stride(2);

// dq_acc: (split, batch_size, nheads, seqlen_q, hdim)
ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
ck_tile::long_index_t batch_stride_dq_acc = dq_acc.stride(1);
ck_tile::index_t stride_dq_acc = dq_acc.stride(3);
ck_tile::long_index_t nhead_stride_dq_acc = dq_acc.stride(2);

// bias: (batch_size, nheads, seqlen_q, seqlen_k)
void *attn_bias_ptr = nullptr;
ck_tile::index_t nhead_stride_bias = 0;
Expand Down Expand Up @@ -144,80 +240,88 @@ aiter::mha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
false, // is_store_randval
deterministic, // is_deterministic

// From ck fmha_bwd_args
q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
attn_bias_ptr,
out.data_ptr(), // o_ptr
softmax_lse.data_ptr(), // lse_ptr
dout.data_ptr(), // do_ptr
d.data_ptr(),
nullptr, // rand_val_ptr
dq.data_ptr(),
dk.data_ptr(),
dv.data_ptr(),
dbias_ptr,
dq_acc.data_ptr(), // dq_acc_ptr
nullptr, // seqstart_q_ptr
nullptr, // seqstart_k_ptr
nullptr, // seqlen_q_ptr
nullptr, // seqlen_k_ptr
nullptr, // cu_seqlen_q_ptr
nullptr, // cu_seqlen_k_ptr
seqlen_q,
seqlen_k,
b, // batch
seqlen_q, // max_seqlen_q
seqlen_k, // max_seqlen_k
h, // nhead_q
h_k, // nhead_k
softmax_scale, // scale
stride_q,
stride_k,
stride_v,
stride_attn_bias, // stride_bias
stride_o,
0, // stride_randval
stride_do,
stride_dq_acc,
stride_dq,
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_o,
0, // nhead_stride_randval
nhead_stride_do,
nhead_stride_lse,
nhead_stride_dq_acc,
nhead_stride_dq,
nhead_stride_dk,
nhead_stride_dv,
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_o,
0, // batch_stride_randval
batch_stride_do,
batch_stride_lse,
batch_stride_dq_acc,
batch_stride_dq,
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
mask.left, // window_size_left
mask.right, // window_size_right
p_dropout, // p_drop
p_undrop,
drop_seed_offset
};
args.q_ptr = q.data_ptr();
args.k_ptr = k.data_ptr();
args.v_ptr = v.data_ptr();
args.bias_ptr = attn_bias_ptr;
args.o_ptr = out.data_ptr();
args.lse_ptr = softmax_lse.data_ptr();
args.do_ptr = dout.data_ptr();
args.d_ptr = d.data_ptr();
args.rand_val_ptr = nullptr;
args.dq_ptr = dq.data_ptr();
args.dk_ptr = dk.data_ptr();
args.dv_ptr = dv.data_ptr();
args.dbias_ptr = dbias_ptr;
set_fmha_bwd_sink_ptr_fields(args);
args.seqstart_q_ptr = nullptr;
args.seqstart_k_ptr = nullptr;
args.seqlen_q_ptr = nullptr;
args.seqlen_k_ptr = nullptr;
args.cu_seqlen_q_ptr = nullptr;
args.cu_seqlen_k_ptr = nullptr;
args.seqlen_q = seqlen_q;
args.seqlen_k = seqlen_k;
args.batch = b;
args.max_seqlen_q = seqlen_q;
args.max_seqlen_k = seqlen_k;
args.nhead_q = h;
args.nhead_k = h_k;
args.scale = softmax_scale;
args.stride_q = stride_q;
args.stride_k = stride_k;
args.stride_v = stride_v;
args.stride_bias = stride_attn_bias;
args.stride_o = stride_o;
args.stride_randval = 0;
args.stride_do = stride_do;
args.stride_dq = stride_dq;
args.stride_dk = stride_dk;
args.stride_dv = stride_dv;
args.stride_dbias = stride_dbias;
args.nhead_stride_q = nhead_stride_q;
args.nhead_stride_k = nhead_stride_k;
args.nhead_stride_v = nhead_stride_v;
args.nhead_stride_bias = nhead_stride_bias;
args.nhead_stride_o = nhead_stride_o;
args.nhead_stride_randval = 0;
args.nhead_stride_do = nhead_stride_do;
args.nhead_stride_lsed = nhead_stride_lse;
args.nhead_stride_dq = nhead_stride_dq;
args.nhead_stride_dk = nhead_stride_dk;
args.nhead_stride_dv = nhead_stride_dv;
args.nhead_stride_dbias = nhead_stride_dbias;
args.batch_stride_q = batch_stride_q;
args.batch_stride_k = batch_stride_k;
args.batch_stride_v = batch_stride_v;
args.batch_stride_bias = batch_stride_bias;
args.batch_stride_o = batch_stride_o;
args.batch_stride_randval = 0;
args.batch_stride_do = batch_stride_do;
args.batch_stride_lsed = batch_stride_lse;
args.batch_stride_dq = batch_stride_dq;
args.batch_stride_dk = batch_stride_dk;
args.batch_stride_dv = batch_stride_dv;
args.batch_stride_dbias = batch_stride_dbias;
args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.p_drop = p_dropout;
args.p_undrop = p_undrop;
args.drop_seed_offset = drop_seed_offset;

// Older AITER consumes an explicit dq accumulation buffer; gate so this compiles
// against newer AITER where the kernel manages dq scratch via workspace_alloc.
if constexpr(has_fmha_bwd_args_dq_acc<aiter::mha_bwd_args>::value) {
// dq_acc: (split, batch_size, nheads, seqlen_q, hdim)
set_fmha_bwd_dq_acc_fields(
args,
dq_acc.data_ptr(),
dq_acc.stride(3), // stride_dq_acc
dq_acc.stride(2), // nhead_stride_dq_acc
dq_acc.stride(1), // batch_stride_dq_acc
dq_acc.stride(0)); // split_stride_dq_acc
}
return args;
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
Expand Down Expand Up @@ -358,12 +462,16 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor dq_accum;

if (!deterministic) {
dq_accum = at::zeros({1, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat));
} else {
const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64;
const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
dq_accum = at::zeros({nsplits, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat));
// Newer AITER manages dq scratch internally (via workspace_alloc); only allocate the
// explicit dq accumulation buffer when the struct still exposes the dq_acc fields.
if constexpr(has_fmha_bwd_args_dq_acc<aiter::mha_bwd_args>::value) {
if (!deterministic) {
dq_accum = at::zeros({1, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat));
} else {
const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64;
const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
dq_accum = at::zeros({nsplits, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat));
}
}

at::Tensor dk_expanded, dv_expanded;
Expand Down Expand Up @@ -421,6 +529,29 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
p_dropout,
drop_seed_offset);

// Device scratch allocator for newer AITER (kernel-managed dq workspace). The
// returned pointer must stay valid until aiter::mha_bwd returns, so keep the
// backing tensor alive in this scope. No-op against older AITER (field absent).
at::Tensor workspace;
auto workspace_alloc = [&workspace, opts](size_t bytes, bool zero_init) -> void* {
workspace = zero_init
? at::zeros({static_cast<int64_t>(bytes)}, opts.dtype(at::kByte))
: at::empty({static_cast<int64_t>(bytes)}, opts.dtype(at::kByte));
return workspace.data_ptr();
};
// Required even in batch mode: the launcher stages dq-accumulation metadata via a
// host workspace and H2D-copies it, so prepare_workspace_async needs a pinned
// buffer. The returned shared_ptr owns the lifetime (aiter extends it across the
// stream). No-op against older AITER (field absent).
auto pinned_host_alloc = [](size_t bytes) -> std::shared_ptr<void> {
auto t = std::make_shared<at::Tensor>(at::empty(
{static_cast<int64_t>(bytes)},
at::TensorOptions().dtype(at::kByte).pinned_memory(true)));
return std::shared_ptr<void>(t, t->data_ptr());
};
set_fmha_bwd_workspace_alloc(args, workspace_alloc);
set_fmha_bwd_pinned_host_alloc(args, pinned_host_alloc);

float t = aiter::mha_bwd(args, stream_config);

TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
Expand Down
Loading