Skip to content

Commit 5ed4003

Browse files
authored
Add support for bfloat16 and float8 tensors
Add the following extensions: GL_ARM_tensors_bfloat16 GL_ARM_tensors_float_e5m2 GL_ARM_tensors_float_e4m3 Also add GL_ARM_tensors extension definitions to preamble
1 parent c2ebbbd commit 5ed4003

10 files changed

Lines changed: 250 additions & 2 deletions

Test/baseResults/spv.tensorARM.all_accesses.comp.out

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
spv.tensorARM.all_accesses.comp
22
// Module Version 10000
33
// Generated by (magic number): 8000b
4-
// Id's are bound by 199
4+
// Id's are bound by 276
55

66
Capability Shader
77
Capability Float16
@@ -10,13 +10,23 @@ spv.tensorARM.all_accesses.comp
1010
Capability Int16
1111
Capability Int8
1212
Capability TensorsARM
13+
Capability Float8EXT
14+
Capability BFloat16TypeKHR
1315
Extension "SPV_ARM_tensors"
16+
Extension "SPV_EXT_float8"
17+
Extension "SPV_KHR_bfloat16"
1418
1: ExtInstImport "GLSL.std.450"
1519
MemoryModel Logical GLSL450
1620
EntryPoint GLCompute 4 "main"
1721
ExecutionMode 4 LocalSize 1 1 1
1822
Source GLSL 460
1923
SourceExtension "GL_ARM_tensors"
24+
SourceExtension "GL_ARM_tensors_bfloat16"
25+
SourceExtension "GL_ARM_tensors_float_e4m3"
26+
SourceExtension "GL_ARM_tensors_float_e5m2"
27+
SourceExtension "GL_EXT_bfloat16"
28+
SourceExtension "GL_EXT_float_e4m3"
29+
SourceExtension "GL_EXT_float_e5m2"
2030
SourceExtension "GL_EXT_shader_explicit_arithmetic_types"
2131
Name 4 "main"
2232
Name 10 "coords"
@@ -59,6 +69,22 @@ spv.tensorARM.all_accesses.comp
5969
Name 189 "fr"
6070
Name 193 "fr64"
6171
Name 197 "frv"
72+
Name 201 "bf16w"
73+
Name 205 "fe5m2w"
74+
Name 209 "fe4m3w"
75+
Name 213 "bft16"
76+
Name 219 "fte5m2"
77+
Name 225 "fte4m3"
78+
Name 231 "bf16wv"
79+
Name 238 "fe5m2wv"
80+
Name 243 "fe4m3one"
81+
Name 246 "fe4m3wv"
82+
Name 254 "bf16r"
83+
Name 258 "fe5m2r"
84+
Name 262 "fe4m3r"
85+
Name 266 "bf16rv"
86+
Name 270 "fe5m2rv"
87+
Name 274 "fe4m3rv"
6288
Decorate 31(it8) Binding 0
6389
Decorate 31(it8) DescriptorSet 0
6490
Decorate 37(it16) Binding 1
@@ -81,6 +107,12 @@ spv.tensorARM.all_accesses.comp
81107
Decorate 166(ft32) DescriptorSet 0
82108
Decorate 172(ft64) Binding 10
83109
Decorate 172(ft64) DescriptorSet 0
110+
Decorate 213(bft16) Binding 11
111+
Decorate 213(bft16) DescriptorSet 0
112+
Decorate 219(fte5m2) Binding 12
113+
Decorate 219(fte5m2) DescriptorSet 0
114+
Decorate 225(fte4m3) Binding 13
115+
Decorate 225(fte4m3) DescriptorSet 0
84116
2: TypeVoid
85117
3: TypeFunction 2
86118
6: TypeInt 32 0
@@ -163,6 +195,32 @@ spv.tensorARM.all_accesses.comp
163195
176: TypeArray 146(float) 53
164196
177: TypePointer Function 176
165197
179: 176 ConstantComposite 149 149 149 149
198+
199: TypeFloat 16 0
199+
200: TypePointer Function 199(bfloat16_t)
200+
202:199(bfloat16_t) Constant 16256
201+
203: TypeFloat 8 4215
202+
204: TypePointer Function 203(floate5m2_t)
203+
206:203(floate5m2_t) Constant 60
204+
207: TypeFloat 8 4214
205+
208: TypePointer Function 207(floate4m3_t)
206+
210:207(floate4m3_t) Constant 56
207+
211: TypeTensorARM 199(bfloat16_t) 7
208+
212: TypePointer UniformConstant 211
209+
213(bft16): 212(ptr) Variable UniformConstant
210+
217: TypeTensorARM 203(floate5m2_t) 7
211+
218: TypePointer UniformConstant 217
212+
219(fte5m2): 218(ptr) Variable UniformConstant
213+
223: TypeTensorARM 207(floate4m3_t) 7
214+
224: TypePointer UniformConstant 223
215+
225(fte4m3): 224(ptr) Variable UniformConstant
216+
229: TypeArray 199(bfloat16_t) 53
217+
230: TypePointer Function 229
218+
232: 229 ConstantComposite 202 202 202 202
219+
236: TypeArray 203(floate5m2_t) 53
220+
237: TypePointer Function 236
221+
239: 236 ConstantComposite 206 206 206 206
222+
244: TypeArray 207(floate4m3_t) 53
223+
245: TypePointer Function 244
166224
4(main): 2 Function None 3
167225
5: Label
168226
10(coords): 9(ptr) Variable Function
@@ -194,6 +252,19 @@ spv.tensorARM.all_accesses.comp
194252
189(fr): 147(ptr) Variable Function
195253
193(fr64): 155(ptr) Variable Function
196254
197(frv): 177(ptr) Variable Function
255+
201(bf16w): 200(ptr) Variable Function
256+
205(fe5m2w): 204(ptr) Variable Function
257+
209(fe4m3w): 208(ptr) Variable Function
258+
231(bf16wv): 230(ptr) Variable Function
259+
238(fe5m2wv): 237(ptr) Variable Function
260+
243(fe4m3one): 208(ptr) Variable Function
261+
246(fe4m3wv): 245(ptr) Variable Function
262+
254(bf16r): 200(ptr) Variable Function
263+
258(fe5m2r): 204(ptr) Variable Function
264+
262(fe4m3r): 208(ptr) Variable Function
265+
266(bf16rv): 230(ptr) Variable Function
266+
270(fe5m2rv): 237(ptr) Variable Function
267+
274(fe4m3rv): 245(ptr) Variable Function
197268
Store 10(coords) 12
198269
Store 15(iw8) 16
199270
Store 19(iw16) 20
@@ -321,5 +392,62 @@ spv.tensorARM.all_accesses.comp
321392
196: 8 Load 10(coords)
322393
198: 176 TensorReadARM 195 196
323394
Store 197(frv) 198
395+
Store 201(bf16w) 202
396+
Store 205(fe5m2w) 206
397+
Store 209(fe4m3w) 210
398+
214: 211 Load 213(bft16)
399+
215: 8 Load 10(coords)
400+
216:199(bfloat16_t) Load 201(bf16w)
401+
TensorWriteARM 214 215 216
402+
220: 217 Load 219(fte5m2)
403+
221: 8 Load 10(coords)
404+
222:203(floate5m2_t) Load 205(fe5m2w)
405+
TensorWriteARM 220 221 222
406+
226: 223 Load 225(fte4m3)
407+
227: 8 Load 10(coords)
408+
228:207(floate4m3_t) Load 209(fe4m3w)
409+
TensorWriteARM 226 227 228
410+
Store 231(bf16wv) 232
411+
233: 211 Load 213(bft16)
412+
234: 8 Load 10(coords)
413+
235: 229 Load 231(bf16wv)
414+
TensorWriteARM 233 234 235
415+
Store 238(fe5m2wv) 239
416+
240: 217 Load 219(fte5m2)
417+
241: 8 Load 10(coords)
418+
242: 236 Load 238(fe5m2wv)
419+
TensorWriteARM 240 241 242
420+
Store 243(fe4m3one) 210
421+
247:207(floate4m3_t) Load 243(fe4m3one)
422+
248: 244 CompositeConstruct 247 247 247 247
423+
Store 246(fe4m3wv) 248
424+
249: 223 Load 225(fte4m3)
425+
250: 8 Load 10(coords)
426+
251: 244 Load 246(fe4m3wv)
427+
TensorWriteARM 249 250 251
428+
252: 211 Load 213(bft16)
429+
253: 8 Load 10(coords)
430+
255:199(bfloat16_t) TensorReadARM 252 253
431+
Store 254(bf16r) 255
432+
256: 217 Load 219(fte5m2)
433+
257: 8 Load 10(coords)
434+
259:203(floate5m2_t) TensorReadARM 256 257
435+
Store 258(fe5m2r) 259
436+
260: 223 Load 225(fte4m3)
437+
261: 8 Load 10(coords)
438+
263:207(floate4m3_t) TensorReadARM 260 261
439+
Store 262(fe4m3r) 263
440+
264: 211 Load 213(bft16)
441+
265: 8 Load 10(coords)
442+
267: 229 TensorReadARM 264 265
443+
Store 266(bf16rv) 267
444+
268: 217 Load 219(fte5m2)
445+
269: 8 Load 10(coords)
446+
271: 236 TensorReadARM 268 269
447+
Store 270(fe5m2rv) 271
448+
272: 223 Load 225(fte4m3)
449+
273: 8 Load 10(coords)
450+
275: 244 TensorReadARM 272 273
451+
Store 274(fe4m3rv) 275
324452
Return
325453
FunctionEnd
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
spv.tensorARM.unrequested_extension_types.comp
2+
ERROR: 0:8: 'tensor with bfloat16_t type' : required extension not requested: GL_ARM_tensors_bfloat16
3+
ERROR: 0:9: 'tensor with floate5m2_t type' : required extension not requested: GL_ARM_tensors_float_e5m2
4+
ERROR: 0:10: 'tensor with floate4m3_t type' : required extension not requested: GL_ARM_tensors_float_e4m3
5+
ERROR: 3 compilation errors. No code generated.
6+
7+
8+
SPIR-V is not generated for failed compile or link

Test/spv.tensorARM.all_accesses.comp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
#extension GL_ARM_tensors : enable
44
#extension GL_EXT_shader_explicit_arithmetic_types : enable
5+
#extension GL_EXT_bfloat16 : enable
6+
#extension GL_EXT_float_e5m2 : enable
7+
#extension GL_EXT_float_e4m3 : enable
8+
#extension GL_ARM_tensors_bfloat16 : enable
9+
#extension GL_ARM_tensors_float_e5m2 : enable
10+
#extension GL_ARM_tensors_float_e4m3 : enable
511

612
// Int types
713
uniform tensorARM<int8_t, 1> it8;
@@ -20,6 +26,11 @@ uniform tensorARM<float16_t, 1> ft16;
2026
uniform tensorARM<float, 1> ft32;
2127
uniform tensorARM<float64_t, 1> ft64;
2228

29+
// bfloat16, e5m2 and e4m3 types
30+
uniform tensorARM<bfloat16_t, 1> bft16;
31+
uniform tensorARM<floate5m2_t, 1> fte5m2;
32+
uniform tensorARM<floate4m3_t, 1> fte4m3;
33+
2334
void main() {
2435
uint coords[] = {0};
2536

@@ -87,4 +98,38 @@ void main() {
8798
tensorReadARM(ft64, coords, fr64);
8899
float frv[4];
89100
tensorReadARM(ft32, coords, frv);
101+
102+
// bfloat16, e5m2 and e4m3 types
103+
bfloat16_t bf16w = bfloat16_t(1.0);
104+
floate5m2_t fe5m2w = floate5m2_t(1.0);
105+
floate4m3_t fe4m3w = floate4m3_t(1.0);
106+
tensorWriteARM(bft16, coords, bf16w);
107+
tensorWriteARM(fte5m2, coords, fe5m2w);
108+
tensorWriteARM(fte4m3, coords, fe4m3w);
109+
110+
const bfloat16_t bf16one = bfloat16_t(1.0);
111+
bfloat16_t bf16wv[4] = bfloat16_t[](bf16one, bf16one, bf16one, bf16one);
112+
tensorWriteARM(bft16, coords, bf16wv);
113+
114+
const floate5m2_t fe5m2one = floate5m2_t(1.0);
115+
floate5m2_t fe5m2wv[4] = floate5m2_t[](fe5m2one, fe5m2one, fe5m2one, fe5m2one);
116+
tensorWriteARM(fte5m2, coords, fe5m2wv);
117+
118+
floate4m3_t fe4m3one = floate4m3_t(1.0);
119+
floate4m3_t fe4m3wv[4] = floate4m3_t[](fe4m3one, fe4m3one, fe4m3one, fe4m3one);
120+
tensorWriteARM(fte4m3, coords, fe4m3wv);
121+
122+
bfloat16_t bf16r;
123+
floate5m2_t fe5m2r;
124+
floate4m3_t fe4m3r;
125+
tensorReadARM(bft16, coords, bf16r);
126+
tensorReadARM(fte5m2, coords, fe5m2r);
127+
tensorReadARM(fte4m3, coords, fe4m3r);
128+
129+
bfloat16_t bf16rv[4];
130+
floate5m2_t fe5m2rv[4];
131+
floate4m3_t fe4m3rv[4];
132+
tensorReadARM(bft16, coords, bf16rv);
133+
tensorReadARM(fte5m2, coords, fe5m2rv);
134+
tensorReadARM(fte4m3, coords, fe4m3rv);
90135
}

Test/spv.tensorARM.declare.comp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,30 @@
33
#extension GL_ARM_tensors : enable
44
#extension GL_EXT_shader_explicit_arithmetic_types : enable
55

6+
#if !defined GL_ARM_tensors
7+
# error GL_ARM_tensors is not defined
8+
#elif GL_ARM_tensors != 1
9+
# error GL_ARM_tensors is not equal to 1
10+
#endif
11+
12+
#if !defined GL_ARM_tensors_bfloat16
13+
# error GL_ARM_tensors_bfloat16 is not defined
14+
#elif GL_ARM_tensors_bfloat16 != 1
15+
# error GL_ARM_tensors_bfloat16 is not equal to 1
16+
#endif
17+
18+
#if !defined GL_ARM_tensors_float_e5m2
19+
# error GL_ARM_tensors_float_e5m2 is not defined
20+
#elif GL_ARM_tensors_float_e5m2 != 1
21+
# error GL_ARM_tensors_float_e5m2 is not equal to 1
22+
#endif
23+
24+
#if !defined GL_ARM_tensors_float_e4m3
25+
# error GL_ARM_tensors_float_e4m3 is not defined
26+
#elif GL_ARM_tensors_float_e4m3 != 1
27+
# error GL_ARM_tensors_float_e4m3 is not equal to 1
28+
#endif
29+
630
layout(binding = 0) uniform tensorARM<int32_t, 4> t;
731
layout(set = 0, binding = 1) uniform tensorARM<bool, 2> tb;
832

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#version 460 core
2+
3+
#extension GL_ARM_tensors : enable
4+
#extension GL_EXT_bfloat16 : enable
5+
#extension GL_EXT_float_e5m2 : enable
6+
#extension GL_EXT_float_e4m3 : enable
7+
8+
uniform tensorARM<bfloat16_t, 1> bfloat16;
9+
uniform tensorARM<floate5m2_t, 1> floate5m2;
10+
uniform tensorARM<floate4m3_t, 1> floate4m3;
11+
12+
void main() {}

glslang/MachineIndependent/Initialize.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5074,6 +5074,7 @@ void TBuiltIns::initialize(int version, EProfile profile, const SpvVersion& spvV
50745074
"int8_t", "int16_t", "int32_t", "int64_t",
50755075
"uint8_t", "uint16_t", "uint32_t", "uint64_t",
50765076
"float16_t", "float32_t", "float64_t",
5077+
"bfloat16_t", "floate5m2_t", "floate4m3_t",
50775078
};
50785079
std::ostringstream ostream;
50795080
for (auto t : tensorDataTypesARM) {

glslang/MachineIndependent/ParseHelper.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9262,10 +9262,24 @@ TIntermNode* TParseContext::declareVariable(const TSourceLoc& loc, TString& iden
92629262
publicType.typeParameters->basicType != EbtUint64 &&
92639263
publicType.typeParameters->basicType != EbtFloat16 &&
92649264
publicType.typeParameters->basicType != EbtFloat &&
9265-
publicType.typeParameters->basicType != EbtDouble) {
9265+
publicType.typeParameters->basicType != EbtDouble &&
9266+
publicType.typeParameters->basicType != EbtBFloat16 &&
9267+
publicType.typeParameters->basicType != EbtFloatE5M2 &&
9268+
publicType.typeParameters->basicType != EbtFloatE4M3) {
92669269
error(loc, "expected bool, integer or floating point type parameter", identifier.c_str(), "");
92679270
}
92689271

9272+
if (publicType.typeParameters->basicType == EbtBFloat16) {
9273+
requireExtensions(loc, 1, &E_GL_ARM_tensors_bfloat16, "tensor with bfloat16_t type");
9274+
}
9275+
9276+
if (publicType.typeParameters->basicType == EbtFloatE5M2) {
9277+
requireExtensions(loc, 1, &E_GL_ARM_tensors_float_e5m2, "tensor with floate5m2_t type");
9278+
}
9279+
9280+
if (publicType.typeParameters->basicType == EbtFloatE4M3) {
9281+
requireExtensions(loc, 1, &E_GL_ARM_tensors_float_e4m3, "tensor with floate4m3_t type");
9282+
}
92699283
}
92709284
} else {
92719285
if (publicType.typeParameters && publicType.typeParameters->arraySizes->getNumDims() != 0) {

glslang/MachineIndependent/Versions.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,9 @@ void TParseVersions::initializeExtensionBehavior()
323323
// ARM
324324
extensionBehavior[E_GL_ARM_shader_core_builtins] = EBhDisable;
325325
extensionBehavior[E_GL_ARM_tensors] = EBhDisable;
326+
extensionBehavior[E_GL_ARM_tensors_bfloat16] = EBhDisable;
327+
extensionBehavior[E_GL_ARM_tensors_float_e5m2] = EBhDisable;
328+
extensionBehavior[E_GL_ARM_tensors_float_e4m3] = EBhDisable;
326329

327330
// QCOM
328331
extensionBehavior[E_GL_QCOM_image_processing] = EBhDisable;
@@ -667,6 +670,15 @@ void TParseVersions::getPreamble(std::string& preamble)
667670
if (version >= 130) {
668671
preamble +="#define GL_FRAGMENT_PRECISION_HIGH 1\n";
669672
}
673+
674+
if (version >= 460) {
675+
preamble +=
676+
"#define GL_ARM_tensors 1\n"
677+
"#define GL_ARM_tensors_bfloat16 1\n"
678+
"#define GL_ARM_tensors_float_e5m2 1\n"
679+
"#define GL_ARM_tensors_float_e4m3 1\n"
680+
;
681+
}
670682
}
671683

672684
if ((!isEsProfile() && version >= 460) ||

glslang/MachineIndependent/Versions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,9 @@ const char* const E_GL_NV_explicit_typecast = "GL_NV_explici
301301
// ARM
302302
const char* const E_GL_ARM_shader_core_builtins = "GL_ARM_shader_core_builtins";
303303
const char* const E_GL_ARM_tensors = "GL_ARM_tensors";
304+
const char* const E_GL_ARM_tensors_bfloat16 = "GL_ARM_tensors_bfloat16";
305+
const char* const E_GL_ARM_tensors_float_e5m2 = "GL_ARM_tensors_float_e5m2";
306+
const char* const E_GL_ARM_tensors_float_e4m3 = "GL_ARM_tensors_float_e4m3";
304307

305308
// Arrays of extensions for the above viewportEXTs duplications
306309

gtests/Spv.FromFile.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,7 @@ INSTANTIATE_TEST_SUITE_P(
604604
"spv.tensorARM.params.comp",
605605
"spv.tensorARM.read.comp",
606606
"spv.tensorARM.size.comp",
607+
"spv.tensorARM.unrequested_extension_types.comp",
607608
})),
608609
FileNameAsCustomTestSuffix
609610
);

0 commit comments

Comments
 (0)