diff --git a/rust/lance-linalg/src/distance/cosine.rs b/rust/lance-linalg/src/distance/cosine.rs index 995191b77eb..5326780a6ab 100644 --- a/rust/lance-linalg/src/distance/cosine.rs +++ b/rust/lance-linalg/src/distance/cosine.rs @@ -168,7 +168,7 @@ impl Cosine for f16 { kernel::cosine_f16_avx512(x.as_ptr(), x_norm, y.as_ptr(), y.len() as u32) }, #[cfg(all(feature = "fp16kernels", target_arch = "x86_64"))] - SimdSupport::Avx2 => unsafe { + SimdSupport::Avx2 | SimdSupport::Avx512 => unsafe { kernel::cosine_f16_avx2(x.as_ptr(), x_norm, y.as_ptr(), y.len() as u32) }, #[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))] diff --git a/rust/lance-linalg/src/distance/dot.rs b/rust/lance-linalg/src/distance/dot.rs index 5903d24e0e5..8cf5406b9d4 100644 --- a/rust/lance-linalg/src/distance/dot.rs +++ b/rust/lance-linalg/src/distance/dot.rs @@ -203,7 +203,7 @@ impl Dot for f16 { kernel::dot_f16_avx512(x.as_ptr(), y.as_ptr(), x.len() as u32) }, #[cfg(all(feature = "fp16kernels", target_arch = "x86_64"))] - SimdSupport::Avx2 => unsafe { + SimdSupport::Avx2 | SimdSupport::Avx512 => unsafe { kernel::dot_f16_avx2(x.as_ptr(), y.as_ptr(), x.len() as u32) }, #[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))] diff --git a/rust/lance-linalg/src/distance/l2.rs b/rust/lance-linalg/src/distance/l2.rs index c47aedd749f..b479adb9458 100644 --- a/rust/lance-linalg/src/distance/l2.rs +++ b/rust/lance-linalg/src/distance/l2.rs @@ -233,7 +233,7 @@ impl L2 for f16 { kernel::l2_f16_avx512(x.as_ptr(), y.as_ptr(), x.len() as u32) }, #[cfg(all(feature = "fp16kernels", target_arch = "x86_64"))] - SimdSupport::Avx2 => unsafe { + SimdSupport::Avx2 | SimdSupport::Avx512 => unsafe { kernel::l2_f16_avx2(x.as_ptr(), y.as_ptr(), x.len() as u32) }, #[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))] diff --git a/rust/lance-linalg/src/distance/norm_l2.rs b/rust/lance-linalg/src/distance/norm_l2.rs index b1daf85ab3b..08e62860931 100644 --- a/rust/lance-linalg/src/distance/norm_l2.rs +++ b/rust/lance-linalg/src/distance/norm_l2.rs @@ -64,7 +64,7 @@ impl Normalize for f16 { kernel::norm_l2_f16_avx512(vector.as_ptr(), vector.len() as u32) }, #[cfg(all(feature = "fp16kernels", target_arch = "x86_64"))] - SimdSupport::Avx2 => unsafe { + SimdSupport::Avx2 | SimdSupport::Avx512 => unsafe { kernel::norm_l2_f16_avx2(vector.as_ptr(), vector.len() as u32) }, #[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]