Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions include/cutlass/fast_math.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "cutlass/uint128.h"
#include "cutlass/coord.h"
#include "cutlass/half.h"
#include "cutlass/bfloat16.h"

/**
* \file
Expand Down Expand Up @@ -900,6 +901,15 @@ half_t fast_exp(half_t x) {
#endif
}

CUTLASS_HOST_DEVICE
bfloat16_t fast_exp(bfloat16_t x) {
#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 800)
return bfloat16_t(::hexp(reinterpret_cast<__nv_bfloat16 const &>(x.storage)));
#else
return bfloat16_t(fast_exp(float(x)));
#endif
}

CUTLASS_HOST_DEVICE
float fast_log(float x) {
#if defined(__CUDA_ARCH__)
Expand Down Expand Up @@ -954,6 +964,17 @@ half_t fast_tanh(half_t x) {
#endif
}

CUTLASS_HOST_DEVICE
bfloat16_t fast_tanh(bfloat16_t x) {
#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDA_ARCH__ >= 900)
uint16_t bits = x.storage;
asm volatile("tanh.approx.bf16 %0, %1;" : "=h"(bits) : "h"(bits));
return bfloat16_t::bitcast(bits);
#else
return bfloat16_t(fast_tanh(float(x)));
#endif
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
Expand Down Expand Up @@ -992,6 +1013,34 @@ struct fast_exp_op<Array<half_t, N>> {
};
#endif // #if defined(__CUDA_ARCH__)

#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 800)
template <int N>
struct fast_exp_op<Array<bfloat16_t, N>> {
CUTLASS_DEVICE
Array<bfloat16_t, N> operator()(Array<bfloat16_t, N> const &rhs) const {

Array<bfloat16_t, N> result;

// use x2 specialization
__nv_bfloat162 const *in = reinterpret_cast<__nv_bfloat162 const *>(&rhs);
__nv_bfloat162 *out = reinterpret_cast<__nv_bfloat162 *>(&result);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
out[i] = ::h2exp(in[i]);
}

// residual
if (N % 2) {
bfloat16_t last = rhs[N - 1];
result[N - 1] = bfloat16_t(::hexp(last.to_nv_bfloat16()));
}

return result;
}
};
#endif // #if defined(__CUDA_ARCH__)

template <typename T, int N>
struct fast_exp_op<Array<T, N>> {
CUTLASS_HOST_DEVICE
Expand Down Expand Up @@ -1048,6 +1097,35 @@ struct fast_tanh_op<Array<half_t, N>> {
};
#endif // #if defined(__CUDA_ARCH__)

#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDA_ARCH__ >= 900)
template <int N>
struct fast_tanh_op<Array<bfloat16_t, N>> {
CUTLASS_DEVICE
Array<bfloat16_t, N> operator()(Array<bfloat16_t, N> const &rhs) const {

Array<bfloat16_t, N> result;

// use x2 specialization
uint32_t const *in = reinterpret_cast<uint32_t const *>(&rhs);
uint32_t *out = reinterpret_cast<uint32_t *>(&result);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 2; ++i) {
asm volatile("tanh.approx.bf16x2 %0, %1;" : "=r"(out[i]) : "r"(in[i]));
}

// residual — use distinct names (in_raw, out_raw) to avoid shadowing the uint32_t pointers above
if (N % 2) {
uint16_t const *in_raw = reinterpret_cast<uint16_t const *>(&rhs);
uint16_t *out_raw = reinterpret_cast<uint16_t *>(&result);
asm volatile("tanh.approx.bf16 %0, %1;" : "=h"(out_raw[N - 1]) : "h"(in_raw[N - 1]));
}

return result;
}
};
#endif // #if defined(__CUDA_ARCH__)

template <typename T, int N>
struct fast_tanh_op<Array<T, N>> {
CUTLASS_HOST_DEVICE
Expand Down
173 changes: 172 additions & 1 deletion test/unit/epilogue/thread/activation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

#include "cutlass/layout/layout.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/fast_math.h"

#include "cutlass/util/host_tensor.h"

Expand Down Expand Up @@ -644,9 +645,179 @@ TEST(Epilogue_thread_gelu_taylor, device_f16) {
case 207: tolerance_override = 0.15; break;
}

EXPECT_LT(std::abs(rel_error), tolerance_override)
EXPECT_LT(std::abs(rel_error), tolerance_override)
<< "Input[" << i << "]: " << input << ", Got: " << got << ", expected: " << expected;
}
}

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Regression tests: bfloat16_t fast_exp and fast_tanh overloads
//
/////////////////////////////////////////////////////////////////////////////////////////////////

TEST(Fast_math_bfloat16, fast_exp_zero) {
cutlass::bfloat16_t x(0.0f);
cutlass::bfloat16_t y = cutlass::fast_exp(x);
EXPECT_NEAR(float(y), 1.0f, 1e-3f);
}

TEST(Fast_math_bfloat16, fast_exp_one) {
cutlass::bfloat16_t x(1.0f);
cutlass::bfloat16_t y = cutlass::fast_exp(x);
float ref = std::exp(1.0f);
EXPECT_NEAR(float(y), ref, ref * 2e-2f);
}

TEST(Fast_math_bfloat16, fast_exp_neg_one) {
cutlass::bfloat16_t x(-1.0f);
cutlass::bfloat16_t y = cutlass::fast_exp(x);
float ref = std::exp(-1.0f);
EXPECT_NEAR(float(y), ref, std::abs(ref) * 2e-2f);
}

TEST(Fast_math_bfloat16, fast_tanh_zero) {
cutlass::bfloat16_t x(0.0f);
cutlass::bfloat16_t y = cutlass::fast_tanh(x);
EXPECT_NEAR(float(y), 0.0f, 1e-3f);
}

TEST(Fast_math_bfloat16, fast_tanh_one) {
cutlass::bfloat16_t x(1.0f);
cutlass::bfloat16_t y = cutlass::fast_tanh(x);
float ref = std::tanh(1.0f);
EXPECT_NEAR(float(y), ref, std::abs(ref) * 2e-2f);
}

// On host, the bfloat16-specific device specialization is #ifdef'd out; the generic
// Array<T,N> fallback is selected, invoking fast_exp(bfloat16_t) element-wise.
TEST(Fast_math_bfloat16, fast_exp_op_array4_host) {
float inputs[4] = {0.0f, 1.0f, -1.0f, 2.0f};
float refs[4] = {std::exp(0.0f), std::exp(1.0f), std::exp(-1.0f), std::exp(2.0f)};

cutlass::fast_exp_op<cutlass::Array<cutlass::bfloat16_t, 4>> op;
cutlass::Array<cutlass::bfloat16_t, 4> in_arr, out_arr;
for (int i = 0; i < 4; ++i) in_arr[i] = cutlass::bfloat16_t(inputs[i]);
out_arr = op(in_arr);

for (int i = 0; i < 4; ++i)
EXPECT_NEAR(float(out_arr[i]), refs[i], std::abs(refs[i]) * 2e-2f) << "element " << i;
}

// On host, the bfloat16-specific device specialization is #ifdef'd out; the generic
// Array<T,N> fallback is selected, invoking fast_tanh(bfloat16_t) element-wise.
TEST(Fast_math_bfloat16, fast_tanh_op_array4_host) {
float inputs[4] = {0.0f, 1.0f, -1.0f, 2.0f};
float refs[4] = {std::tanh(0.0f), std::tanh(1.0f), std::tanh(-1.0f), std::tanh(2.0f)};

cutlass::fast_tanh_op<cutlass::Array<cutlass::bfloat16_t, 4>> op;
cutlass::Array<cutlass::bfloat16_t, 4> in_arr, out_arr;
for (int i = 0; i < 4; ++i) in_arr[i] = cutlass::bfloat16_t(inputs[i]);
out_arr = op(in_arr);

for (int i = 0; i < 4; ++i)
EXPECT_NEAR(float(out_arr[i]), refs[i], std::abs(refs[i]) * 2e-2f + 1e-6f) << "element " << i;
}

// N=1 exercises the generic scalar fallback on host (device specialization requires __CUDA_ARCH__).
// Device-side residual asm path coverage requires a device-side unit test.
TEST(Fast_math_bfloat16, fast_exp_op_array1_odd_residual) {
cutlass::fast_exp_op<cutlass::Array<cutlass::bfloat16_t, 1>> op;
cutlass::Array<cutlass::bfloat16_t, 1> in_arr, out_arr;
in_arr[0] = cutlass::bfloat16_t(1.0f);
out_arr = op(in_arr);
EXPECT_NEAR(float(out_arr[0]), std::exp(1.0f), std::exp(1.0f) * 2e-2f);
}

TEST(Fast_math_bfloat16, fast_tanh_op_array1_odd_residual) {
cutlass::fast_tanh_op<cutlass::Array<cutlass::bfloat16_t, 1>> op;
cutlass::Array<cutlass::bfloat16_t, 1> in_arr, out_arr;
in_arr[0] = cutlass::bfloat16_t(1.0f);
out_arr = op(in_arr);
float ref = std::tanh(1.0f);
EXPECT_NEAR(float(out_arr[0]), ref, std::abs(ref) * 2e-2f);
}

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Device tests: bfloat16_t fast_exp_op and fast_tanh_op (PTX specialization paths)
//
/////////////////////////////////////////////////////////////////////////////////////////////////

TEST(Fast_math_bfloat16, device_fast_exp_op_array8_sm80) {
int dev;
cudaDeviceProp prop;
cudaGetDevice(&dev);
cudaGetDeviceProperties(&prop, dev);
if (prop.major < 8) {
GTEST_SKIP() << "Requires SM80+ for ::h2exp(__nv_bfloat162) path";
}

int const kN = 128;
int const kV = 8;

using Element = cutlass::bfloat16_t;
using Func = cutlass::fast_exp_op<cutlass::Array<Element, kV>>;

cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Destination({1, kN});
cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Source({1, kN});

for (int i = 0; i < kN; ++i)
tensor_Source.host_data(i) = Element(-2.0f + 4.0f * float(i) / float(kN - 1));

tensor_Destination.sync_device();
tensor_Source.sync_device();

dim3 grid(1, 1, 1);
dim3 block(kN / kV, 1, 1);
test_Epilogue_thread_activation<Element, kV, Func><<<grid, block>>>(
tensor_Destination.device_data(), tensor_Source.device_data());
tensor_Destination.sync_host();

for (int i = 0; i < kN; ++i) {
float v = -2.0f + 4.0f * float(i) / float(kN - 1);
float got = float(tensor_Destination.host_data(i));
float ref = std::exp(v);
EXPECT_NEAR(got, ref, std::abs(ref) * 2e-2f) << "element " << i;
}
}

TEST(Fast_math_bfloat16, device_fast_tanh_op_array8_sm90) {
int dev;
cudaDeviceProp prop;
cudaGetDevice(&dev);
cudaGetDeviceProperties(&prop, dev);
if (prop.major < 9) {
GTEST_SKIP() << "Requires SM90+ for tanh.approx.bf16x2 path";
}

int const kN = 128;
int const kV = 8;

using Element = cutlass::bfloat16_t;
using Func = cutlass::fast_tanh_op<cutlass::Array<Element, kV>>;

cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Destination({1, kN});
cutlass::HostTensor<Element, cutlass::layout::RowMajor> tensor_Source({1, kN});

for (int i = 0; i < kN; ++i)
tensor_Source.host_data(i) = Element(-2.0f + 4.0f * float(i) / float(kN - 1));

tensor_Destination.sync_device();
tensor_Source.sync_device();

dim3 grid(1, 1, 1);
dim3 block(kN / kV, 1, 1);
test_Epilogue_thread_activation<Element, kV, Func><<<grid, block>>>(
tensor_Destination.device_data(), tensor_Source.device_data());
tensor_Destination.sync_host();

for (int i = 0; i < kN; ++i) {
float v = -2.0f + 4.0f * float(i) / float(kN - 1);
float got = float(tensor_Destination.host_data(i));
float ref = std::tanh(v);
EXPECT_NEAR(got, ref, std::abs(ref) * 2e-2f + 1e-4f) << "element " << i;
}
}

/////////////////////////////////////////////////////////////////////////////////////////////////