From d7ae9b3b4571fd45aba981320d7c030c8305ed82 Mon Sep 17 00:00:00 2001 From: Vittoria Lanzo Date: Fri, 15 May 2026 23:04:21 +0200 Subject: [PATCH 1/2] [fast_math] Add bfloat16_t PTX specializations for fast_exp and fast_tanh --- include/cutlass/fast_math.h | 78 +++++++++++++++++++++ test/unit/epilogue/thread/activation.cu | 91 ++++++++++++++++++++++++- 2 files changed, 168 insertions(+), 1 deletion(-) diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index 8fa30f925f..20823236b6 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -45,6 +45,7 @@ #include "cutlass/uint128.h" #include "cutlass/coord.h" #include "cutlass/half.h" +#include "cutlass/bfloat16.h" /** * \file @@ -900,6 +901,15 @@ half_t fast_exp(half_t x) { #endif } +CUTLASS_HOST_DEVICE +bfloat16_t fast_exp(bfloat16_t x) { + #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(::hexp(reinterpret_cast<__nv_bfloat16 const &>(x.storage))); + #else + return bfloat16_t(fast_exp(float(x))); + #endif +} + CUTLASS_HOST_DEVICE float fast_log(float x) { #if defined(__CUDA_ARCH__) @@ -954,6 +964,17 @@ half_t fast_tanh(half_t x) { #endif } +CUTLASS_HOST_DEVICE +bfloat16_t fast_tanh(bfloat16_t x) { + #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDA_ARCH__ >= 900) + uint16_t bits = x.storage; + asm volatile("tanh.approx.bf16 %0, %1;" : "=h"(bits) : "h"(bits)); + return bfloat16_t::bitcast(bits); + #else + return bfloat16_t(fast_tanh(float(x))); + #endif +} + ///////////////////////////////////////////////////////////////////////////////////////////////// template @@ -992,6 +1013,34 @@ struct fast_exp_op> { }; #endif // #if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 800) +template +struct fast_exp_op> { + CUTLASS_DEVICE + Array operator()(Array const &rhs) const { + + Array result; + + // use x2 specialization + __nv_bfloat162 const *in = 
reinterpret_cast<__nv_bfloat162 const *>(&rhs); + __nv_bfloat162 *out = reinterpret_cast<__nv_bfloat162 *>(&result); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + out[i] = ::h2exp(in[i]); + } + + // residual: odd N leaves one trailing element, handled with scalar ::hexp + if (N % 2) { + bfloat16_t last = rhs[N - 1]; + result[N - 1] = bfloat16_t(::hexp(last.to_nv_bfloat16())); + } + + return result; + } +}; +#endif // #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 800) + template struct fast_exp_op> { CUTLASS_HOST_DEVICE @@ -1048,6 +1097,35 @@ struct fast_tanh_op> { }; #endif // #if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDA_ARCH__ >= 900) +template +struct fast_tanh_op> { + CUTLASS_DEVICE + Array operator()(Array const &rhs) const { + + Array result; + + // process element pairs as 32-bit words with the packed tanh.approx.bf16x2 instruction + uint32_t const *in = reinterpret_cast(&rhs); + uint32_t *out = reinterpret_cast(&result); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm volatile("tanh.approx.bf16x2 %0, %1;" : "=r"(out[i]) : "r"(in[i])); + } + + // residual: odd N leaves one trailing element; view the arrays as 16-bit lanes (in_raw/out_raw, named so they do not shadow in/out above) + if (N % 2) { + uint16_t const *in_raw = reinterpret_cast(&rhs); + uint16_t *out_raw = reinterpret_cast(&result); + asm volatile("tanh.approx.bf16 %0, %1;" : "=h"(out_raw[N - 1]) : "h"(in_raw[N - 1])); + } + + return result; + } +}; +#endif // #if defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDA_ARCH__ >= 900) + template struct fast_tanh_op> { CUTLASS_HOST_DEVICE diff --git a/test/unit/epilogue/thread/activation.cu b/test/unit/epilogue/thread/activation.cu index e747d003cf..3ae88a912c 100644 --- a/test/unit/epilogue/thread/activation.cu +++ b/test/unit/epilogue/thread/activation.cu @@ -36,6 +36,7 @@ #include "cutlass/layout/layout.h" #include "cutlass/epilogue/thread/activation.h" +#include "cutlass/fast_math.h" #include "cutlass/util/host_tensor.h" @@ -644,9 +645,97 @@ TEST(Epilogue_thread_gelu_taylor, device_f16) { case 207: tolerance_override = 0.15; break; } - EXPECT_LT(std::abs(rel_error), tolerance_override) + 
EXPECT_LT(std::abs(rel_error), tolerance_override) << "Input[" << i << "]: " << input << ", Got: " << got << ", expected: " << expected; } } ///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Regression tests: bfloat16_t fast_exp and fast_tanh overloads +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Fast_math_bfloat16, fast_exp_zero) { + cutlass::bfloat16_t x(0.0f); + cutlass::bfloat16_t y = cutlass::fast_exp(x); + EXPECT_NEAR(float(y), 1.0f, 1e-3f); +} + +TEST(Fast_math_bfloat16, fast_exp_one) { + cutlass::bfloat16_t x(1.0f); + cutlass::bfloat16_t y = cutlass::fast_exp(x); + float ref = std::exp(1.0f); + EXPECT_NEAR(float(y), ref, ref * 2e-2f); +} + +TEST(Fast_math_bfloat16, fast_exp_neg_one) { + cutlass::bfloat16_t x(-1.0f); + cutlass::bfloat16_t y = cutlass::fast_exp(x); + float ref = std::exp(-1.0f); + EXPECT_NEAR(float(y), ref, std::abs(ref) * 2e-2f); +} + +TEST(Fast_math_bfloat16, fast_tanh_zero) { + cutlass::bfloat16_t x(0.0f); + cutlass::bfloat16_t y = cutlass::fast_tanh(x); + EXPECT_NEAR(float(y), 0.0f, 1e-3f); +} + +TEST(Fast_math_bfloat16, fast_tanh_one) { + cutlass::bfloat16_t x(1.0f); + cutlass::bfloat16_t y = cutlass::fast_tanh(x); + float ref = std::tanh(1.0f); + EXPECT_NEAR(float(y), ref, std::abs(ref) * 2e-2f); +} + +// On host, the bfloat16-specific device specialization is #ifdef'd out; the generic +// Array fallback is selected, invoking fast_exp(bfloat16_t) element-wise. 
+TEST(Fast_math_bfloat16, fast_exp_op_array4_host) { + float inputs[4] = {0.0f, 1.0f, -1.0f, 2.0f}; + float refs[4] = {std::exp(0.0f), std::exp(1.0f), std::exp(-1.0f), std::exp(2.0f)}; + + cutlass::fast_exp_op> op; + cutlass::Array in_arr, out_arr; + for (int i = 0; i < 4; ++i) in_arr[i] = cutlass::bfloat16_t(inputs[i]); + out_arr = op(in_arr); + + for (int i = 0; i < 4; ++i) + EXPECT_NEAR(float(out_arr[i]), refs[i], std::abs(refs[i]) * 2e-2f) << "element " << i; +} + +// On host, the bfloat16-specific device specialization is #ifdef'd out; the generic +// Array fallback is selected, invoking fast_tanh(bfloat16_t) element-wise. +TEST(Fast_math_bfloat16, fast_tanh_op_array4_host) { + float inputs[4] = {0.0f, 1.0f, -1.0f, 2.0f}; + float refs[4] = {std::tanh(0.0f), std::tanh(1.0f), std::tanh(-1.0f), std::tanh(2.0f)}; + + cutlass::fast_tanh_op> op; + cutlass::Array in_arr, out_arr; + for (int i = 0; i < 4; ++i) in_arr[i] = cutlass::bfloat16_t(inputs[i]); + out_arr = op(in_arr); + + for (int i = 0; i < 4; ++i) + EXPECT_NEAR(float(out_arr[i]), refs[i], std::abs(refs[i]) * 2e-2f + 1e-6f) << "element " << i; +} + +// N=1 exercises the generic scalar fallback on host (device specialization requires __CUDA_ARCH__). +// Device-side residual asm path coverage requires a device-side unit test. 
+TEST(Fast_math_bfloat16, fast_exp_op_array1_odd_residual) { + cutlass::fast_exp_op> op; + cutlass::Array in_arr, out_arr; + in_arr[0] = cutlass::bfloat16_t(1.0f); + out_arr = op(in_arr); + EXPECT_NEAR(float(out_arr[0]), std::exp(1.0f), std::exp(1.0f) * 2e-2f); +} + +TEST(Fast_math_bfloat16, fast_tanh_op_array1_odd_residual) { + cutlass::fast_tanh_op> op; + cutlass::Array in_arr, out_arr; + in_arr[0] = cutlass::bfloat16_t(1.0f); + out_arr = op(in_arr); + float ref = std::tanh(1.0f); + EXPECT_NEAR(float(out_arr[0]), ref, std::abs(ref) * 2e-2f); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// From 73a4b05ef452893540a063cf98634815c90329e1 Mon Sep 17 00:00:00 2001 From: Vittoria Lanzo Date: Sat, 16 May 2026 07:22:41 +0200 Subject: [PATCH 2/2] [fast_math] Add device tests for bfloat16_t fast_exp_op and fast_tanh_op Two new GTest cases in Fast_math_bfloat16: - device_fast_exp_op_array8_sm80: runs fast_exp_op> on device via test_Epilogue_thread_activation; skips below SM80 - device_fast_tanh_op_array8_sm90: runs fast_tanh_op> on device; skips below SM90 (tanh.approx.bf16x2 path) --- test/unit/epilogue/thread/activation.cu | 82 +++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/test/unit/epilogue/thread/activation.cu b/test/unit/epilogue/thread/activation.cu index 3ae88a912c..09427e9df8 100644 --- a/test/unit/epilogue/thread/activation.cu +++ b/test/unit/epilogue/thread/activation.cu @@ -739,3 +739,85 @@ TEST(Fast_math_bfloat16, fast_tanh_op_array1_odd_residual) { } ///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Device tests: bfloat16_t fast_exp_op and fast_tanh_op (PTX specialization paths) +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Fast_math_bfloat16, device_fast_exp_op_array8_sm80) { + int dev; + cudaDeviceProp prop; + cudaGetDevice(&dev); + 
cudaGetDeviceProperties(&prop, dev); + if (prop.major < 8) { + GTEST_SKIP() << "Requires SM80+ for ::h2exp(__nv_bfloat162) path"; + } + + int const kN = 128; + int const kV = 8; + + using Element = cutlass::bfloat16_t; + using Func = cutlass::fast_exp_op>; + + cutlass::HostTensor tensor_Destination({1, kN}); + cutlass::HostTensor tensor_Source({1, kN}); + + for (int i = 0; i < kN; ++i) + tensor_Source.host_data(i) = Element(-2.0f + 4.0f * float(i) / float(kN - 1)); + + tensor_Destination.sync_device(); + tensor_Source.sync_device(); + + dim3 grid(1, 1, 1); + dim3 block(kN / kV, 1, 1); + test_Epilogue_thread_activation<<>>( + tensor_Destination.device_data(), tensor_Source.device_data()); + tensor_Destination.sync_host(); + + for (int i = 0; i < kN; ++i) { + float v = -2.0f + 4.0f * float(i) / float(kN - 1); + float got = float(tensor_Destination.host_data(i)); + float ref = std::exp(v); + EXPECT_NEAR(got, ref, std::abs(ref) * 2e-2f) << "element " << i; + } +} + +TEST(Fast_math_bfloat16, device_fast_tanh_op_array8_sm90) { + int dev; + cudaDeviceProp prop; + cudaGetDevice(&dev); + cudaGetDeviceProperties(&prop, dev); + if (prop.major < 9) { + GTEST_SKIP() << "Requires SM90+ for tanh.approx.bf16x2 path"; + } + + int const kN = 128; + int const kV = 8; + + using Element = cutlass::bfloat16_t; + using Func = cutlass::fast_tanh_op>; + + cutlass::HostTensor tensor_Destination({1, kN}); + cutlass::HostTensor tensor_Source({1, kN}); + + for (int i = 0; i < kN; ++i) + tensor_Source.host_data(i) = Element(-2.0f + 4.0f * float(i) / float(kN - 1)); + + tensor_Destination.sync_device(); + tensor_Source.sync_device(); + + dim3 grid(1, 1, 1); + dim3 block(kN / kV, 1, 1); + test_Epilogue_thread_activation<<>>( + tensor_Destination.device_data(), tensor_Source.device_data()); + tensor_Destination.sync_host(); + + for (int i = 0; i < kN; ++i) { + float v = -2.0f + 4.0f * float(i) / float(kN - 1); + float got = float(tensor_Destination.host_data(i)); + float ref = std::tanh(v); + 
EXPECT_NEAR(got, ref, std::abs(ref) * 2e-2f + 1e-4f) << "element " << i; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////