diff --git a/README.md b/README.md index 5b68e11a..7ca18c9e 100644 --- a/README.md +++ b/README.md @@ -139,13 +139,13 @@ global_var GV0 in spv.StorageClass.Output: s32 func F0() -> spv.OpTypeVoid { loop(v0: s32 <- 1s32, v1: s32 <- 1s32) { - v2 = spv.OpSLessThan(v1, 10s32): bool + v2 = s.lt(v1, 10s32): bool (v3: s32, v4: s32) = if v2 { - v5 = spv.OpIMul(v0, v1): s32 - v6 = spv.OpIAdd(v1, 1s32): s32 + v5 = i.mul(v0, v1): s32 + v6 = i.add(v1, 1s32): s32 (v5, v6) } else { - (spv.OpUndef: s32, spv.OpUndef: s32) + (undef: s32, undef: s32) } (v3, v4) -> (v0, v1) } while v2 diff --git a/src/cfg.rs b/src/cfg.rs index de8a6298..6a09cff8 100644 --- a/src/cfg.rs +++ b/src/cfg.rs @@ -1,15 +1,13 @@ //! Control-flow graph (CFG) abstractions and utilities. use crate::{ - spv, AttrSet, Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeDef, + scalar, spv, AttrSet, Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeDef, ControlNodeKind, ControlNodeOutputDecl, ControlRegion, ControlRegionDef, - EntityOrientedDenseMap, FuncDefBody, FxIndexMap, FxIndexSet, SelectionKind, Type, TypeKind, - Value, + EntityOrientedDenseMap, FuncDefBody, FxIndexMap, FxIndexSet, SelectionKind, Type, Value, }; use itertools::{Either, Itertools}; use smallvec::SmallVec; use std::mem; -use std::rc::Rc; /// The control-flow graph (CFG) of a function, as control-flow instructions /// ([`ControlInst`]s) attached to [`ControlRegion`]s, as an "action on exit", i.e. @@ -593,32 +591,9 @@ struct PartialControlRegion { impl<'a> Structurizer<'a> { pub fn new(cx: &'a Context, func_def_body: &'a mut FuncDefBody) -> Self { - // FIXME(eddyb) SPIR-T should have native booleans itself. 
- let wk = &spv::spec::Spec::get().well_known; - let type_bool = cx.intern(TypeKind::SpvInst { - spv_inst: wk.OpTypeBool.into(), - type_and_const_inputs: [].into_iter().collect(), - }); - let const_true = cx.intern(ConstDef { - attrs: AttrSet::default(), - ty: type_bool, - kind: ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new(( - wk.OpConstantTrue.into(), - [].into_iter().collect(), - )), - }, - }); - let const_false = cx.intern(ConstDef { - attrs: AttrSet::default(), - ty: type_bool, - kind: ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new(( - wk.OpConstantFalse.into(), - [].into_iter().collect(), - )), - }, - }); + let type_bool = cx.intern(scalar::Type::Bool); + let const_true = cx.intern(scalar::Const::TRUE); + let const_false = cx.intern(scalar::Const::FALSE); let (loop_header_to_exit_targets, incoming_edge_counts_including_loop_exits) = func_def_body @@ -1568,14 +1543,6 @@ impl<'a> Structurizer<'a> { /// Create an undefined constant (as a placeholder where a value needs to be /// present, but won't actually be used), of type `ty`. fn const_undef(&self, ty: Type) -> Const { - // FIXME(eddyb) SPIR-T should have native undef itself. - let wk = &spv::spec::Spec::get().well_known; - self.cx.intern(ConstDef { - attrs: AttrSet::default(), - ty, - kind: ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new((wk.OpUndef.into(), [].into_iter().collect())), - }, - }) + self.cx.intern(ConstDef { attrs: AttrSet::default(), ty, kind: ConstKind::Undef }) } } diff --git a/src/lib.rs b/src/lib.rs index 7723c951..88aeeca8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -168,7 +168,9 @@ pub mod passes { pub mod qptr; } pub mod qptr; +pub mod scalar; pub mod spv; +pub mod vector; use smallvec::SmallVec; use std::borrow::Cow; @@ -453,16 +455,30 @@ impl Ord for OrdAssertEq { pub use context::Type; /// Definition for a [`Type`]. -// -// FIXME(eddyb) maybe special-case some basic types like integers. 
#[derive(PartialEq, Eq, Hash)] pub struct TypeDef { pub attrs: AttrSet, pub kind: TypeKind, } -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, derive_more::From)] pub enum TypeKind { + /// Scalar (`bool`, integer, and floating-point) type, with limitations + /// on the supported bit-widths (power-of-two multiples of a byte). + /// + /// **Note**: pointers are never scalars (like SPIR-V, but unlike other IRs). + /// + /// See also the [`scalar`] module for more documentation and definitions. + #[from] + Scalar(scalar::Type), + + /// Vector (small array of [`scalar`]s) type, with some limitations on the + /// supported component counts (but all standard ones should be included). + /// + /// See also the [`vector`] module for more documentation and definitions. + #[from] + Vector(vector::Type), + /// "Quasi-pointer", an untyped pointer-like abstract scalar that can represent /// both memory locations (in any address space) and other kinds of locations /// (e.g. SPIR-V `OpVariable`s in non-memory "storage classes"). @@ -490,12 +506,18 @@ pub enum TypeKind { SpvStringLiteralForExtInst, } -// HACK(eddyb) this behaves like an implicit conversion for `cx.intern(...)`. -impl context::InternInCx for TypeKind { - fn intern_in_cx(self, cx: &Context) -> Type { - cx.intern(TypeDef { attrs: Default::default(), kind: self }) +// HACK(eddyb) this behaves like an implicit conversion for `cx.intern(...)`, +// and the macro is only used because coherence bans `impl<T: Into<TypeKind>>`. +macro_rules! impl_intern_type_kind { + ($($kind:ty),+ $(,)?) => { + $(impl context::InternInCx for $kind { + fn intern_in_cx(self, cx: &Context) -> Type { + cx.intern(TypeDef { attrs: Default::default(), kind: self.into() }) + } + })+ } } +impl_intern_type_kind!(TypeKind, scalar::Type, vector::Type); // HACK(eddyb) this is like `Either`, only used in `TypeKind::SpvInst`, // and only because SPIR-V type definitions can references both types and consts.
@@ -505,6 +527,22 @@ pub enum TypeOrConst { Const(Const), } +// HACK(eddyb) on `Type` instead of `TypeDef` for ergonomics reasons. +impl Type { + pub fn as_scalar(self, cx: &Context) -> Option { + match cx[self].kind { + TypeKind::Scalar(ty) => Some(ty), + _ => None, + } + } + pub fn as_vector(self, cx: &Context) -> Option { + match cx[self].kind { + TypeKind::Vector(ty) => Some(ty), + _ => None, + } + } +} + /// Interned handle for a [`ConstDef`](crate::ConstDef) (a constant value). pub use context::Const; @@ -518,8 +556,38 @@ pub struct ConstDef { pub kind: ConstKind, } -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, derive_more::From)] pub enum ConstKind { + /// Indeterminate value (i.e. SPIR-V `OpUndef`, LLVM `undef`). + // + // FIXME(eddyb) could it be possible to adopt LLVM's newer `poison`+`freeze` + // model, without being forced to never lift back to `OpUndef`? + Undef, + + /// Scalar (`bool`, integer, and floating-point) constant, which must have + /// a type of [`TypeKind::Scalar`] (of the same [`scalar::Type`]). + /// + /// See also the [`scalar`] module for more documentation and definitions. + // + // FIXME(eddyb) maybe document the 128-bit limitation? + // FIXME(eddyb) this technically makes the `scalar::Type` redundant, could + // it get out of sync? (perhaps "forced canonicalization" could be used to + // enforce that interning simply doesn't allow such scenarios?). + #[from] + Scalar(scalar::Const), + + /// Vector (small array of [`scalar`]s) constant, which must have + /// a type of [`TypeKind::Vector`] (of the same [`vector::Type`]). + /// + /// See also the [`vector`] module for more documentation and definitions. + // + // FIXME(eddyb) maybe document the 128-bit limitation inherited from `scalar::Const`? + // FIXME(eddyb) this technically makes the `vector::Type` redundant, could + // it get out of sync?
(perhaps "forced canonicalization" could be used to + // enforce that interning simply doesn't allow such scenarios?). + #[from] + Vector(vector::Const), + PtrToGlobalVar(GlobalVar), // HACK(eddyb) this is a fallback case that should become increasingly rare @@ -534,6 +602,40 @@ pub enum ConstKind { SpvStringLiteralForExtInst(InternedStr), } +// HACK(eddyb) this behaves like an implicit conversion for `cx.intern(...)`, +// like the `TypeKind` one, but this one is even weirder because it also interns +// the inherent type of the constant, as a `Type` (with empty attributes). +macro_rules! impl_intern_const_kind { + ($($kind:ty),+ $(,)?) => { + $(impl context::InternInCx for $kind { + fn intern_in_cx(self, cx: &Context) -> Const { + cx.intern(ConstDef { + attrs: Default::default(), + ty: cx.intern(self.ty()), + kind: self.into(), + }) + } + })+ + } +} +impl_intern_const_kind!(scalar::Const, vector::Const); + +// HACK(eddyb) on `Const` instead of `ConstDef` for ergonomics reasons. +impl Const { + pub fn as_scalar(self, cx: &Context) -> Option<&scalar::Const> { + match &cx[self].kind { + ConstKind::Scalar(ct) => Some(ct), + _ => None, + } + } + pub fn as_vector(self, cx: &Context) -> Option<&vector::Const> { + match &cx[self].kind { + ConstKind::Vector(ct) => Some(ct), + _ => None, + } + } +} + /// Declarations ([`GlobalVarDecl`], [`FuncDecl`]) can contain a full definition, /// or only be an import of a definition (e.g. from another module). #[derive(Clone)] @@ -825,13 +927,24 @@ pub enum ControlNodeKind { }, } +// FIXME(eddyb) consider interning this, perhaps in a similar vein to `DataInstForm`. #[derive(Clone)] pub enum SelectionKind { /// Two-case selection based on boolean condition, i.e. `if`-`else`, with /// the two cases being "then" and "else" (in that order). BoolCond, - SpvInst(spv::Inst), + /// `N+1`-case selection based on comparing an integer scrutinee against + /// `N` constants, i.e. 
`switch`, with the last case being the "default" + /// (making it the only case without a matching entry in `case_consts`). + Switch { + // FIXME(eddyb) avoid some of the `scalar::Const` overhead here, as there + // is only a single type and we shouldn't need to store more bits per case, + // than the actual width of the integer type. + // FIXME(eddyb) consider storing this more like sorted compressed keyset, + // as there can be no duplicates, and in many cases it may be contiguous. + case_consts: Vec, + }, } /// Entity handle for a [`DataInstDef`](crate::DataInstDef) (an SSA instruction). @@ -868,6 +981,18 @@ pub struct DataInstFormDef { #[derive(Clone, PartialEq, Eq, Hash, derive_more::From)] pub enum DataInstKind { + /// Scalar (`bool`, integer, and floating-point) pure operations. + /// + /// See also the [`scalar`] module for more documentation and definitions. + #[from] + Scalar(scalar::Op), + + /// Vector (small array of [`scalar`]s) pure operations. + /// + /// See also the [`vector`] module for more documentation and definitions. + #[from] + Vector(vector::Op), + // FIXME(eddyb) try to split this into recursive and non-recursive calls, // to avoid needing special handling for recursion where it's impossible. 
FuncCall(Func), diff --git a/src/print/mod.rs b/src/print/mod.rs index c37eca03..6b68c8d0 100644 --- a/src/print/mod.rs +++ b/src/print/mod.rs @@ -24,8 +24,8 @@ use crate::print::multiversion::Versions; use crate::qptr::{self, QPtrAttr, QPtrMemUsage, QPtrMemUsageKind, QPtrOp, QPtrUsage}; use crate::visit::{InnerVisit, Visit, Visitor}; use crate::{ - cfg, spv, AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, - ControlNode, ControlNodeDef, ControlNodeKind, ControlNodeOutputDecl, ControlRegion, + cfg, scalar, spv, vector, AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, + Context, ControlNode, ControlNodeDef, ControlNodeKind, ControlNodeOutputDecl, ControlRegion, ControlRegionDef, ControlRegionInputDecl, DataInst, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, Diag, DiagLevel, DiagMsgPart, EntityListIter, ExportKey, Exportee, Func, FuncDecl, FuncParam, FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDecl, GlobalVarDefBody, @@ -673,7 +673,6 @@ enum UseStyle { impl<'a> Printer<'a> { fn new(plan: &Plan<'a>) -> Self { let cx = plan.cx; - let wk = &spv::spec::Spec::get().well_known; // HACK(eddyb) move this elsewhere. enum SmallSet { @@ -813,53 +812,32 @@ impl<'a> Printer<'a> { CxInterned::Type(ty) => { let ty_def = &cx[ty]; - // FIXME(eddyb) remove the duplication between - // here and `TypeDef`'s `Print` impl. - let has_compact_print_or_is_leaf = match &ty_def.kind { - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => { - [ - wk.OpTypeBool, - wk.OpTypeInt, - wk.OpTypeFloat, - wk.OpTypeVector, - ] - .contains(&spv_inst.opcode) - || type_and_const_inputs.is_empty() + let is_leaf = match &ty_def.kind { + TypeKind::SpvInst { type_and_const_inputs, .. 
} => { + type_and_const_inputs.is_empty() } - TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => { - true - } + TypeKind::Scalar(_) + | TypeKind::Vector(_) + | TypeKind::QPtr + | TypeKind::SpvStringLiteralForExtInst => true, }; - ty_def.attrs == AttrSet::default() - && has_compact_print_or_is_leaf + ty_def.attrs == AttrSet::default() && is_leaf } CxInterned::Const(ct) => { let ct_def = &cx[ct]; - // FIXME(eddyb) remove the duplication between - // here and `ConstDef`'s `Print` impl. - let (has_compact_print, has_nested_consts) = match &ct_def.kind - { + let has_nested_consts = match &ct_def.kind { ConstKind::SpvInst { spv_inst_and_const_inputs } => { - let (spv_inst, const_inputs) = + let (_spv_inst, const_inputs) = &**spv_inst_and_const_inputs; - ( - [ - wk.OpConstantFalse, - wk.OpConstantTrue, - wk.OpConstant, - ] - .contains(&spv_inst.opcode), - !const_inputs.is_empty(), - ) + !const_inputs.is_empty() } - _ => (false, false), + _ => false, }; - ct_def.attrs == AttrSet::default() - && (has_compact_print || !has_nested_consts) + ct_def.attrs == AttrSet::default() && !has_nested_consts } } } @@ -2378,77 +2356,43 @@ impl Print for TypeDef { let wk = &spv::spec::Spec::get().well_known; - // FIXME(eddyb) should this be done by lowering SPIR-V types to SPIR-T? let kw = |kw| printer.declarative_keyword_style().apply(kw).into(); - let compact_def = if let &TypeKind::SpvInst { - spv_inst: spv::Inst { opcode, ref imms }, - ref type_and_const_inputs, - } = kind - { - if opcode == wk.OpTypeBool { - Some(kw("bool".into())) - } else if opcode == wk.OpTypeInt { - let (width, signed) = match imms[..] { - [spv::Imm::Short(_, width), spv::Imm::Short(_, signedness)] => { - (width, signedness != 0) - } - _ => unreachable!(), - }; - Some(if signed { kw(format!("s{width}")) } else { kw(format!("u{width}")) }) - } else if opcode == wk.OpTypeFloat { - let width = match imms[..] 
{ - [spv::Imm::Short(_, width)] => width, - _ => unreachable!(), - }; - - Some(kw(format!("f{width}"))) - } else if opcode == wk.OpTypeVector { - let (elem_ty, elem_count) = match (&imms[..], &type_and_const_inputs[..]) { - (&[spv::Imm::Short(_, elem_count)], &[TypeOrConst::Type(elem_ty)]) => { - (elem_ty, elem_count) - } - _ => unreachable!(), - }; - - Some(pretty::Fragment::new([ - elem_ty.print(printer), - "×".into(), - printer.numeric_literal_style().apply(format!("{elem_count}")).into(), - ])) - } else { - None + // FIXME(eddyb) should this just be `fmt::Display` on `scalar::Type`? + let print_scalar = |ty: scalar::Type| { + let width = ty.bit_width(); + match ty { + scalar::Type::Bool => "bool".into(), + scalar::Type::SInt(_) => format!("s{width}"), + scalar::Type::UInt(_) => format!("u{width}"), + scalar::Type::Float(_) => format!("f{width}"), } - } else { - None }; AttrsAndDef { attrs: attrs.print(printer), - def_without_name: if let Some(def) = compact_def { - def - } else { - match kind { - // FIXME(eddyb) should this be shortened to `qtr`? - TypeKind::QPtr => printer.declarative_keyword_style().apply("qptr").into(), - - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => printer - .pretty_spv_inst( - printer.spv_op_style(), - spv_inst.opcode, - &spv_inst.imms, - type_and_const_inputs.iter().map(|&ty_or_ct| match ty_or_ct { - TypeOrConst::Type(ty) => ty.print(printer), - TypeOrConst::Const(ct) => ct.print(printer), - }), - ), - TypeKind::SpvStringLiteralForExtInst => pretty::Fragment::new([ - printer.error_style().apply("type_of").into(), - "(".into(), - printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString), - ")".into(), - ]), - } + def_without_name: match kind { + &TypeKind::Scalar(ty) => kw(print_scalar(ty)), + &TypeKind::Vector(ty) => kw(format!("{}×{}", print_scalar(ty.elem), ty.elem_count)), + + // FIXME(eddyb) should this be shortened to `qtr`? 
+ TypeKind::QPtr => printer.declarative_keyword_style().apply("qptr").into(), + + TypeKind::SpvInst { spv_inst, type_and_const_inputs } => printer.pretty_spv_inst( + printer.spv_op_style(), + spv_inst.opcode, + &spv_inst.imms, + type_and_const_inputs.iter().map(|&ty_or_ct| match ty_or_ct { + TypeOrConst::Type(ty) => ty.print(printer), + TypeOrConst::Const(ct) => ct.print(printer), + }), + ), + TypeKind::SpvStringLiteralForExtInst => pretty::Fragment::new([ + printer.error_style().apply("type_of").into(), + "(".into(), + printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString), + ")".into(), + ]), }, } } @@ -2462,71 +2406,19 @@ impl Print for ConstDef { let wk = &spv::spec::Spec::get().well_known; let kw = |kw| printer.declarative_keyword_style().apply(kw).into(); - let literal_ty_suffix = |ty| { - pretty::Styles { - // HACK(eddyb) the exact type detracts from the value. - color_opacity: Some(0.4), - subscript: true, - ..printer.declarative_keyword_style() - } - .apply(ty) - }; - let compact_def = if let ConstKind::SpvInst { spv_inst_and_const_inputs } = kind { - let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs; - let &spv::Inst { opcode, ref imms } = spv_inst; - - if opcode == wk.OpConstantFalse { - Some(kw("false")) - } else if opcode == wk.OpConstantTrue { - Some(kw("true")) - } else if opcode == wk.OpConstant { - // HACK(eddyb) it's simpler to only handle a limited subset of - // integer/float bit-widths, for now. - let raw_bits = match imms[..] { - [spv::Imm::Short(_, x)] => Some(u64::from(x)), - [spv::Imm::LongStart(_, lo), spv::Imm::LongCont(_, hi)] => { - Some(u64::from(lo) | (u64::from(hi) << 32)) - } - _ => None, - }; - - if let ( - Some(raw_bits), - &TypeKind::SpvInst { - spv_inst: spv::Inst { opcode: ty_opcode, imms: ref ty_imms }, - .. - }, - ) = (raw_bits, &printer.cx[*ty].kind) - { - if ty_opcode == wk.OpTypeInt { - let (width, signed) = match ty_imms[..] 
{ - [spv::Imm::Short(_, width), spv::Imm::Short(_, signedness)] => { - (width, signedness != 0) - } - _ => unreachable!(), - }; - - if width <= 64 { - let (printed_value, ty) = if signed { - let sext_raw_bits = - (raw_bits as u128 as i128) << (128 - width) >> (128 - width); - (format!("{sext_raw_bits}"), format!("s{width}")) - } else { - (format!("{raw_bits}"), format!("u{width}")) - }; - Some(pretty::Fragment::new([ - printer.numeric_literal_style().apply(printed_value), - literal_ty_suffix(ty), - ])) - } else { - None - } - } else if ty_opcode == wk.OpTypeFloat { - let width = match ty_imms[..] { - [spv::Imm::Short(_, width)] => width, - _ => unreachable!(), - }; + // FIXME(eddyb) should this be a method on `scalar::Const` instead? + let print_scalar = |ct: scalar::Const, include_type_suffix: bool| match ct { + scalar::Const::FALSE => kw("false"), + scalar::Const::TRUE => kw("true"), + _ => { + let ty = ct.ty(); + let width = ty.bit_width(); + let (maybe_printed_value, ty_prefix) = match ty { + scalar::Type::Bool => unreachable!(), + scalar::Type::SInt(_) => (ct.int_as_i128().map(|x| x.to_string()), 's'), + scalar::Type::UInt(_) => (ct.int_as_u128().map(|x| x.to_string()), 'u'), + scalar::Type::Float(_) => { /// Check that parsing the result of printing produces /// the original bits of the floating-point value, and /// only return `Some` if that is the case. 
@@ -2546,64 +2438,95 @@ impl Print for ConstDef { }) } - let printed_value = match width { - 32 => bitwise_roundtrip_float_print( - raw_bits as u32, - f32::from_bits, - f32::to_bits, - ), - 64 => bitwise_roundtrip_float_print( - raw_bits, - f64::from_bits, - f64::to_bits, - ), - _ => None, - }; - printed_value.map(|s| { - pretty::Fragment::new([ - printer.numeric_literal_style().apply(s), - literal_ty_suffix(format!("f{width}")), - ]) - }) - } else { - None + ( + match width { + 32 => bitwise_roundtrip_float_print( + ct.bits() as u32, + f32::from_bits, + f32::to_bits, + ), + 64 => bitwise_roundtrip_float_print( + ct.bits() as u64, + f64::from_bits, + f64::to_bits, + ), + _ => None, + }, + 'f', + ) } - } else { - None + }; + match maybe_printed_value { + Some(printed_value) => { + let printed_value = printer.numeric_literal_style().apply(printed_value); + if include_type_suffix { + let literal_ty_suffix = pretty::Styles { + // HACK(eddyb) the exact type detracts from the value. + color_opacity: Some(0.4), + subscript: true, + ..printer.declarative_keyword_style() + } + .apply(format!("{ty_prefix}{width}")); + pretty::Fragment::new([printed_value, literal_ty_suffix]) + } else { + printed_value.into() + } + } + // HACK(eddyb) fallback using the bitwise representation. + None => pretty::Fragment::new([ + printer + .demote_style_for_namespace_prefix(printer.declarative_keyword_style()) + .apply(format!("{ty_prefix}{width}.")) + .into(), + printer.declarative_keyword_style().apply("from_bits").into(), + pretty::join_comma_sep( + "(", + [ + // FIXME(eddyb) consider padding this with enough + // leading zeroes for its respective width. 
+ printer.numeric_literal_style().apply(format!("0x{:x}", ct.bits())), + ], + ")", + ), + ]), } - } else { - None } - } else { - None }; - AttrsAndDef { - attrs: attrs.print(printer), - def_without_name: compact_def.unwrap_or_else(|| match kind { - &ConstKind::PtrToGlobalVar(gv) => { - pretty::Fragment::new(["&".into(), gv.print(printer)]) - } - ConstKind::SpvInst { spv_inst_and_const_inputs } => { - let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; - pretty::Fragment::new([ - printer.pretty_spv_inst( - printer.spv_op_style(), - spv_inst.opcode, - &spv_inst.imms, - const_inputs.iter().map(|ct| ct.print(printer)), - ), - printer.pretty_type_ascription_suffix(*ty), - ]) - } - &ConstKind::SpvStringLiteralForExtInst(s) => pretty::Fragment::new([ - printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString), - "(".into(), - printer.pretty_string_literal(&printer.cx[s]), - ")".into(), - ]), - }), - } + let def_without_name = match kind { + ConstKind::Undef => pretty::Fragment::new([ + printer.imperative_keyword_style().apply("undef").into(), + printer.pretty_type_ascription_suffix(*ty), + ]), + &ConstKind::Scalar(ct) => print_scalar(ct, true), + ConstKind::Vector(ct) => pretty::Fragment::new([ + ty.print(printer), + pretty::join_comma_sep("(", ct.elems().map(|elem| print_scalar(elem, false)), ")"), + ]), + &ConstKind::PtrToGlobalVar(gv) => { + pretty::Fragment::new(["&".into(), gv.print(printer)]) + } + + ConstKind::SpvInst { spv_inst_and_const_inputs } => { + let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; + pretty::Fragment::new([ + printer.pretty_spv_inst( + printer.spv_op_style(), + spv_inst.opcode, + &spv_inst.imms, + const_inputs.iter().map(|ct| ct.print(printer)), + ), + printer.pretty_type_ascription_suffix(*ty), + ]) + } + &ConstKind::SpvStringLiteralForExtInst(s) => pretty::Fragment::new([ + printer.pretty_spv_opcode(printer.spv_op_style(), wk.OpString), + "(".into(), + printer.pretty_string_literal(&printer.cx[s]), + ")".into(), + 
]), + }; + AttrsAndDef { attrs: attrs.print(printer), def_without_name } } } @@ -3010,7 +2933,7 @@ impl Print for FuncAt<'_, ControlNode> { ( pretty::join_comma_sep( "(", - input_decls_and_uses.clone().zip(initial_inputs).map( + input_decls_and_uses.clone().zip_eq(initial_inputs).map( |((input_decl, input_use), initial)| { pretty::Fragment::new([ input_decl.print(printer).insert_name_before_def( @@ -3100,7 +3023,65 @@ impl Print for FuncAt<'_, DataInst> { let mut output_type_to_print = *output_type; + // FIXME(eddyb) should this be a method on `scalar::Op` instead? + let print_scalar = |op: scalar::Op| { + let name = op.name(); + let (namespace_prefix, name) = name.split_at(name.find('.').unwrap() + 1); + pretty::Fragment::new([ + printer + .demote_style_for_namespace_prefix(printer.declarative_keyword_style()) + .apply(namespace_prefix), + printer.declarative_keyword_style().apply(name), + ]) + }; + let def_without_type = match kind { + &DataInstKind::Scalar(op) => pretty::Fragment::new([ + print_scalar(op), + pretty::join_comma_sep("(", inputs.iter().map(|v| v.print(printer)), ")"), + ]), + + &DataInstKind::Vector(op) => { + let (name, extra_last_input) = match op { + vector::Op::Distribute(_) => ("vec.distribute", None), + vector::Op::Reduce(op) => (op.name(), None), + vector::Op::Whole(op) => ( + op.name(), + match op { + vector::WholeOp::Extract { elem_idx } + | vector::WholeOp::Insert { elem_idx } => Some( + printer.numeric_literal_style().apply(elem_idx.to_string()).into(), + ), + vector::WholeOp::New + | vector::WholeOp::DynExtract + | vector::WholeOp::DynInsert + | vector::WholeOp::Mul => None, + }, + ), + }; + let (namespace_prefix, name) = name.split_at(name.find('.').unwrap() + 1); + let mut pretty_name = pretty::Fragment::new([ + printer + .demote_style_for_namespace_prefix(printer.declarative_keyword_style()) + .apply(namespace_prefix), + printer.declarative_keyword_style().apply(name), + ]); + if let vector::Op::Distribute(op) = op { + pretty_name = 
pretty::Fragment::new([ + pretty_name, + pretty::join_comma_sep("(", [print_scalar(op)], ")"), + ]); + } + pretty::Fragment::new([ + pretty_name, + pretty::join_comma_sep( + "(", + inputs.iter().map(|v| v.print(printer)).chain(extra_last_input), + ")", + ), + ]) + } + &DataInstKind::FuncCall(func) => pretty::Fragment::new([ printer.declarative_keyword_style().apply("call").into(), " ".into(), @@ -3294,21 +3275,19 @@ impl Print for FuncAt<'_, DataInst> { let pseudo_imm_from_value = |v: Value| { if let Value::Const(ct) = v { match &printer.cx[ct].kind { + ConstKind::Undef + | ConstKind::Vector(_) + | ConstKind::PtrToGlobalVar(_) + | ConstKind::SpvInst { .. } => {} + &ConstKind::SpvStringLiteralForExtInst(s) => { return Some(PseudoImm::Str(&printer.cx[s])); } - ConstKind::SpvInst { spv_inst_and_const_inputs } => { - let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs; - if spv_inst.opcode == wk.OpConstant { - if let [spv::Imm::Short(_, x)] = spv_inst.imms[..] { - // HACK(eddyb) only allow unambiguously positive values. - if i32::try_from(x).and_then(u32::try_from) == Ok(x) { - return Some(PseudoImm::U32(x)); - } - } - } + // HACK(eddyb) lossless roundtrip through `i32` is most conservative + // option (only `0..=i32::MAX`, i.e. `0 <= x < 2**32, is allowed). 
+ ConstKind::Scalar(ct) => { + return Some(PseudoImm::U32(u32::try_from(ct.int_as_i32()?).ok()?)); } - ConstKind::PtrToGlobalVar(_) => {} } } None @@ -3530,7 +3509,7 @@ impl SelectionKind { mut cases: impl ExactSizeIterator, ) -> pretty::Fragment { let kw = |kw| kw_style.apply(kw).into(); - match *self { + match self { SelectionKind::BoolCond => { assert_eq!(cases.len(), 2); let [then_case, else_case] = [cases.next().unwrap(), cases.next().unwrap()]; @@ -3547,27 +3526,36 @@ impl SelectionKind { "}".into(), ]) } - SelectionKind::SpvInst(spv::Inst { opcode, ref imms }) => { - let header = printer.pretty_spv_inst( - kw_style, - opcode, - imms, - [Some(scrutinee.print(printer))] - .into_iter() - .chain((0..cases.len()).map(|_| None)), - ); + SelectionKind::Switch { case_consts } => { + assert_eq!(cases.len(), case_consts.len() + 1); + + let case_patterns = case_consts + .iter() + .map(|&ct| { + let int_to_string = (ct.int_as_u128().map(|x| x.to_string())) + .or_else(|| ct.int_as_i128().map(|x| x.to_string())); + match int_to_string { + Some(v) => printer.numeric_literal_style().apply(v).into(), + None => { + let ct: Const = printer.cx.intern(ct); + ct.print(printer) + } + } + }) + .chain(["_".into()]); pretty::Fragment::new([ - header, + kw("switch"), + " ".into(), + scrutinee.print(printer), " {".into(), pretty::Node::IndentedBlock( - cases - .map(|case| { + case_patterns + .zip_eq(cases) + .map(|(case_pattern, case)| { pretty::Fragment::new([ pretty::Node::ForceLineSeparation.into(), - // FIXME(eddyb) this should pull information out - // of the instruction to be more precise. 
- kw("case"), + case_pattern, " => {".into(), pretty::Node::IndentedBlock(vec![case]).into(), "}".into(), diff --git a/src/qptr/analyze.rs b/src/qptr/analyze.rs index daaf1390..45183c1c 100644 --- a/src/qptr/analyze.rs +++ b/src/qptr/analyze.rs @@ -906,6 +906,8 @@ impl<'a> InferUsage<'a> { }); }; match &data_inst_form_def.kind { + DataInstKind::Scalar(_) | DataInstKind::Vector(_) => {} + &DataInstKind::FuncCall(callee) => { match self.infer_usage_in_func(module, callee) { FuncInferUsageState::Complete(callee_results) => { diff --git a/src/qptr/layout.rs b/src/qptr/layout.rs index 00def111..0617ddb7 100644 --- a/src/qptr/layout.rs +++ b/src/qptr/layout.rs @@ -2,7 +2,7 @@ use crate::qptr::shapes; use crate::{ - spv, AddrSpace, Attr, Const, ConstKind, Context, Diag, FxIndexMap, Type, TypeKind, TypeOrConst, + scalar, spv, AddrSpace, Attr, Const, Context, Diag, FxIndexMap, Type, TypeKind, TypeOrConst, }; use itertools::Either; use smallvec::SmallVec; @@ -182,18 +182,10 @@ impl<'a> LayoutCache<'a> { Self { cx, wk: &spv::spec::Spec::get().well_known, config, cache: Default::default() } } - // FIXME(eddyb) properly distinguish between zero-extension and sign-extension. fn const_as_u32(&self, ct: Const) -> Option { - if let ConstKind::SpvInst { spv_inst_and_const_inputs } = &self.cx[ct].kind { - let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs; - if spv_inst.opcode == self.wk.OpConstant && spv_inst.imms.len() == 1 { - match spv_inst.imms[..] { - [spv::Imm::Short(_, x)] => return Some(x), - _ => unreachable!(), - } - } - } - None + // HACK(eddyb) lossless roundtrip through `i32` is most conservative + // option (only `0..=i32::MAX`, i.e. `0 <= x < 2**31`, is allowed). + u32::try_from(ct.as_scalar(&self.cx)?.int_as_i32()?).ok() } /// Attempt to compute a `TypeLayout` for a given (SPIR-V) `Type`.
@@ -202,26 +194,16 @@ impl<'a> LayoutCache<'a> { return Ok(cached); } + let layout = self.layout_of_uncached(ty)?; + self.cache.borrow_mut().insert(ty, layout.clone()); + Ok(layout) + } + + fn layout_of_uncached(&self, ty: Type) -> Result { let cx = &self.cx; let wk = self.wk; let ty_def = &cx[ty]; - let (spv_inst, type_and_const_inputs) = match &ty_def.kind { - // FIXME(eddyb) treat `QPtr`s as scalars. - TypeKind::QPtr => { - return Err(LayoutError(Diag::bug( - ["`layout_of(qptr)` (already lowered?)".into()], - ))); - } - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => { - (spv_inst, type_and_const_inputs) - } - TypeKind::SpvStringLiteralForExtInst => { - return Err(LayoutError(Diag::bug([ - "`layout_of(type_of(OpString<\"...\">))`".into() - ]))); - } - }; let scalar_with_size_and_align = |(size, align)| { TypeLayout::Concrete(Rc::new(MemTypeLayout { @@ -340,34 +322,59 @@ impl<'a> LayoutCache<'a> { } } }; - let short_imm_at = |i| match spv_inst.imms[i] { - spv::Imm::Short(_, x) => x, - _ => unreachable!(), - }; // FIXME(eddyb) !!! what if... types had a min/max size and then... // that would allow surrounding offsets to limit their size... but... ugh... // ugh this doesn't make any sense. maybe if the front-end specifies // offsets with "abstract types", it must configure `qptr::layout`? - let layout = if spv_inst.opcode == wk.OpTypeBool { - // FIXME(eddyb) make this properly abstract instead of only configurable. - scalar_with_size_and_align(self.config.abstract_bool_size_align) - } else if spv_inst.opcode == wk.OpTypePointer { + + let (spv_inst, type_and_const_inputs) = match &ty_def.kind { + TypeKind::Scalar(scalar::Type::Bool) => { + // FIXME(eddyb) make this properly abstract instead of only configurable. 
+ return Ok(scalar_with_size_and_align(self.config.abstract_bool_size_align)); + } + TypeKind::Scalar(ty) => return Ok(scalar(ty.bit_width())), + + TypeKind::Vector(ty) => { + let len = u32::from(ty.elem_count.get()); + return array( + cx.intern(ty.elem), + ArrayParams { + fixed_len: Some(len), + known_stride: None, + + // NOTE(eddyb) this is specifically Vulkan "base alignment". + min_legacy_align: 1, + legacy_align_multiplier: if len <= 2 { 2 } else { 4 }, + }, + ); + } + + // FIXME(eddyb) treat `QPtr`s as scalars. + TypeKind::QPtr => { + return Err(LayoutError(Diag::bug( + ["`layout_of(qptr)` (already lowered?)".into()], + ))); + } + TypeKind::SpvInst { spv_inst, type_and_const_inputs } => { + (spv_inst, type_and_const_inputs) + } + TypeKind::SpvStringLiteralForExtInst => { + return Err(LayoutError(Diag::bug([ + "`layout_of(type_of(OpString<\"...\">))`".into() + ]))); + } + }; + let short_imm_at = |i| match spv_inst.imms[i] { + spv::Imm::Short(_, x) => x, + _ => unreachable!(), + }; + Ok(if spv_inst.opcode == wk.OpTypePointer { // FIXME(eddyb) make this properly abstract instead of only configurable. // FIXME(eddyb) categorize `OpTypePointer` by storage class and split on // logical vs physical here. scalar_with_size_and_align(self.config.logical_ptr_size_align) - } else if [wk.OpTypeInt, wk.OpTypeFloat].contains(&spv_inst.opcode) { - scalar(short_imm_at(0)) - } else if [wk.OpTypeVector, wk.OpTypeMatrix].contains(&spv_inst.opcode) { - let len = short_imm_at(0); - let (min_legacy_align, legacy_align_multiplier) = if spv_inst.opcode == wk.OpTypeVector - { - // NOTE(eddyb) this is specifically Vulkan "base alignment". - (1, if len <= 2 { 2 } else { 4 }) - } else { - (self.config.min_aggregate_legacy_align, 1) - }; + } else if spv_inst.opcode == wk.OpTypeMatrix { // NOTE(eddyb) `RowMajor` is disallowed on `OpTypeStruct` members below. array( match type_and_const_inputs[..] 
{ @@ -375,10 +382,10 @@ impl<'a> LayoutCache<'a> { _ => unreachable!(), }, ArrayParams { - fixed_len: Some(len), + fixed_len: Some(short_imm_at(0)), known_stride: None, - min_legacy_align, - legacy_align_multiplier, + min_legacy_align: self.config.min_aggregate_legacy_align, + legacy_align_multiplier: 1, }, )? } else if [wk.OpTypeArray, wk.OpTypeRuntimeArray].contains(&spv_inst.opcode) { @@ -642,8 +649,6 @@ impl<'a> LayoutCache<'a> { spv_inst.opcode.name() ) .into()]))); - }; - self.cache.borrow_mut().insert(ty, layout.clone()); - Ok(layout) + }) } } diff --git a/src/qptr/lift.rs b/src/qptr/lift.rs index f962ded5..f624d060 100644 --- a/src/qptr/lift.rs +++ b/src/qptr/lift.rs @@ -7,13 +7,12 @@ use crate::func_at::FuncAtMut; use crate::qptr::{shapes, QPtrAttr, QPtrMemUsage, QPtrMemUsageKind, QPtrOp, QPtrUsage}; use crate::transform::{InnerInPlaceTransform, InnerTransform, Transformed, Transformer}; use crate::{ - spv, AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, ControlNode, - ControlNodeKind, DataInst, DataInstDef, DataInstFormDef, DataInstKind, DeclDef, Diag, - DiagLevel, EntityDefs, EntityOrientedDenseMap, Func, FuncDecl, FxIndexMap, GlobalVar, + scalar, spv, AddrSpace, Attr, AttrSet, AttrSetDef, Const, ConstDef, ConstKind, Context, + ControlNode, ControlNodeKind, DataInst, DataInstDef, DataInstFormDef, DataInstKind, DeclDef, + Diag, DiagLevel, EntityDefs, EntityOrientedDenseMap, Func, FuncDecl, FxIndexMap, GlobalVar, GlobalVarDecl, Module, Type, TypeDef, TypeKind, TypeOrConst, Value, }; use smallvec::SmallVec; -use std::cell::Cell; use std::mem; use std::num::NonZeroU32; use std::rc::Rc; @@ -27,8 +26,6 @@ pub struct LiftToSpvPtrs<'a> { cx: Rc, wk: &'static spv::spec::WellKnown, layout_cache: LayoutCache<'a>, - - cached_u32_type: Cell>, } impl<'a> LiftToSpvPtrs<'a> { @@ -37,7 +34,6 @@ impl<'a> LiftToSpvPtrs<'a> { cx: cx.clone(), wk: &spv::spec::Spec::get().well_known, layout_cache: LayoutCache::new(cx, layout_config), - 
cached_u32_type: Default::default(), } } @@ -291,7 +287,9 @@ impl<'a> LiftToSpvPtrs<'a> { spv_inst: spv_opcode.into(), type_and_const_inputs: [TypeOrConst::Type(element_type)] .into_iter() - .chain(fixed_len.map(|len| TypeOrConst::Const(self.const_u32(len)))) + .chain(fixed_len.map(|len| { + TypeOrConst::Const(self.cx.intern(scalar::Const::from_u32(len))) + })) .collect(), }, })) @@ -329,48 +327,6 @@ impl<'a> LiftToSpvPtrs<'a> { })) } - /// Get the (likely cached) `u32` type. - fn u32_type(&self) -> Type { - if let Some(cached) = self.cached_u32_type.get() { - return cached; - } - let wk = self.wk; - let ty = self.cx.intern(TypeKind::SpvInst { - spv_inst: spv::Inst { - opcode: wk.OpTypeInt, - imms: [ - spv::Imm::Short(wk.LiteralInteger, 32), - spv::Imm::Short(wk.LiteralInteger, 0), - ] - .into_iter() - .collect(), - }, - type_and_const_inputs: [].into_iter().collect(), - }); - self.cached_u32_type.set(Some(ty)); - ty - } - - fn const_u32(&self, x: u32) -> Const { - let wk = self.wk; - - self.cx.intern(ConstDef { - attrs: AttrSet::default(), - ty: self.u32_type(), - kind: ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new(( - spv::Inst { - opcode: wk.OpConstant, - imms: [spv::Imm::Short(wk.LiteralContextDependentNumber, x)] - .into_iter() - .collect(), - }, - [].into_iter().collect(), - )), - }, - }) - } - /// Attempt to compute a `TypeLayout` for a given (SPIR-V) `Type`. 
fn layout_of(&self, ty: Type) -> Result { self.layout_cache.layout_of(ty).map_err(|LayoutError(err)| LiftError(err)) @@ -448,6 +404,8 @@ impl LiftToSpvPtrInstsInFunc<'_> { Ok((addr_space, self.lifter.layout_of(pointee_type)?)) }; let replacement_data_inst_def = match &data_inst_form_def.kind { + DataInstKind::Scalar(_) | DataInstKind::Vector(_) => return Ok(Transformed::Unchanged), + &DataInstKind::FuncCall(_callee) => { for &v in &data_inst_def.inputs { if self.lifter.as_spv_ptr_type(type_of_val(v)).is_some() { @@ -644,7 +602,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { ])) })?; access_chain_inputs - .push(Value::Const(self.lifter.const_u32(idx_as_i32 as u32))); + .push(Value::Const(cx.intern(scalar::Const::from_u32(idx_as_i32 as u32)))); match &layout.components { Components::Scalar => unreachable!(), @@ -757,7 +715,7 @@ impl LiftToSpvPtrInstsInFunc<'_> { ])) })?; access_chain_inputs - .push(Value::Const(self.lifter.const_u32(idx_as_i32 as u32))); + .push(Value::Const(cx.intern(scalar::Const::from_u32(idx_as_i32 as u32)))); layout = match &layout.components { Components::Scalar => unreachable!(), @@ -945,7 +903,8 @@ impl LiftToSpvPtrInstsInFunc<'_> { let mut access_chain_inputs: SmallVec<_> = [ptr].into_iter().collect(); if let TypeLayout::HandleArray(handle, _) = pointee_layout { - access_chain_inputs.push(Value::Const(self.lifter.const_u32(0))); + access_chain_inputs + .push(Value::Const(self.lifter.cx.intern(scalar::Const::from_u32(0)))); pointee_layout = TypeLayout::Handle(handle); } match (pointee_layout, access_layout) { @@ -1014,8 +973,9 @@ impl LiftToSpvPtrInstsInFunc<'_> { format!("{idx} not representable as a positive s32").into() ])) })?; - access_chain_inputs - .push(Value::Const(self.lifter.const_u32(idx_as_i32 as u32))); + access_chain_inputs.push(Value::Const( + self.lifter.cx.intern(scalar::Const::from_u32(idx_as_i32 as u32)), + )); pointee_layout = match &pointee_layout.components { Components::Scalar => unreachable!(), diff --git a/src/qptr/lower.rs 
b/src/qptr/lower.rs index 512b6856..dec482e6 100644 --- a/src/qptr/lower.rs +++ b/src/qptr/lower.rs @@ -171,18 +171,10 @@ impl<'a> LowerFromSpvPtrs<'a> { } } - // FIXME(eddyb) properly distinguish between zero-extension and sign-extension. fn const_as_u32(&self, ct: Const) -> Option { - if let ConstKind::SpvInst { spv_inst_and_const_inputs } = &self.cx[ct].kind { - let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs; - if spv_inst.opcode == self.wk.OpConstant && spv_inst.imms.len() == 1 { - match spv_inst.imms[..] { - [spv::Imm::Short(_, x)] => return Some(x), - _ => unreachable!(), - } - } - } - None + // HACK(eddyb) lossless roundtrip through `i32` is most conservative + // option (only `0..=i32::MAX`, i.e. `0 <= x < 2**31`, is allowed). + u32::try_from(ct.as_scalar(&self.cx)?.int_as_i32()?).ok() } /// Get the (likely cached) `QPtr` type. @@ -624,7 +616,10 @@ impl LowerFromSpvPtrInstsInFunc<'_> { match data_inst_form_def.kind { // Known semantics, no need to preserve SPIR-V pointer information. - DataInstKind::FuncCall(_) | DataInstKind::QPtr(_) => return, + DataInstKind::Scalar(_) + | DataInstKind::Vector(_) + | DataInstKind::FuncCall(_) + | DataInstKind::QPtr(_) => return, DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => {} } diff --git a/src/scalar.rs b/src/scalar.rs new file mode 100644 index 00000000..29de3a50 --- /dev/null +++ b/src/scalar.rs @@ -0,0 +1,469 @@ +//! Scalar (`bool`, integer, and floating-point) types and associated functionality. +//! +//! **Note**: pointers are never scalars (like SPIR-V, but unlike other IRs). + +// HACK(eddyb) this could be some `struct` with private fields, but this `enum` +// is only 2 bytes in size, and has better ergonomics overall. +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum Type { + Bool, + SInt(IntWidth), + UInt(IntWidth), + Float(FloatWidth), +} + +impl Type { + // HACK(eddyb) only common widths, as a convenience, expand as-needed.
+ pub const S32: Type = Type::SInt(IntWidth::I32); + pub const U32: Type = Type::UInt(IntWidth::I32); + pub const F32: Type = Type::Float(FloatWidth::F32); + pub const F64: Type = Type::Float(FloatWidth::F64); + + pub const fn bit_width(self) -> u32 { + match self { + Type::Bool => 1, + Type::SInt(w) | Type::UInt(w) => w.bits(), + Type::Float(w) => w.bits(), + } + } +} + +/// Bit-width of a supported integer type (only power-of-two multiples of a byte). +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct IntWidth { + // HACK(eddyb) this is so compact that only 3 bits of this byte are used + // to encode integer types from `i8` to `i128`, and so `Type` could all fit + // in one byte, but that'd need a new `enum` for `Bool`/`{S,U}Int`/`Float`. + log2_bytes: u8, +} + +impl IntWidth { + pub const I8: Self = Self::try_from_bits_unwrap(8); + pub const I16: Self = Self::try_from_bits_unwrap(16); + pub const I32: Self = Self::try_from_bits_unwrap(32); + pub const I64: Self = Self::try_from_bits_unwrap(64); + pub const I128: Self = Self::try_from_bits_unwrap(128); + + // FIXME(eddyb) remove when `Option::unwrap` is stabilized. + const fn try_from_bits_unwrap(bits: u32) -> Self { + match Self::try_from_bits(bits) { + Some(w) => w, + None => unreachable!(), + } + } + + pub const fn try_from_bits(bits: u32) -> Option { + // `ilog2` rounds down, so reject non-power-of-two byte counts (e.g. 24 bits). + if bits % 8 != 0 || !(bits / 8).is_power_of_two() { + return None; + } + match (bits / 8).checked_ilog2() { + Some(log2_bytes_u32) => { + let log2_bytes = log2_bytes_u32 as u8; + assert!(log2_bytes as u32 == log2_bytes_u32); + Some(Self { log2_bytes }) + } + None => None, + } + } + + pub const fn bits(self) -> u32 { + 8 * (1 << self.log2_bytes) + } +} + +/// Bit-width of a supported floating-point type (only power-of-two multiples of a byte).
+#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct FloatWidth(IntWidth); + +impl FloatWidth { + pub const F32: Self = Self::try_from_bits_unwrap(32); + pub const F64: Self = Self::try_from_bits_unwrap(64); + + // FIXME(eddyb) remove when `Option::unwrap` is stabilized. + const fn try_from_bits_unwrap(bits: u32) -> Self { + match Self::try_from_bits(bits) { + Some(w) => w, + None => unreachable!(), + } + } + + pub const fn try_from_bits(bits: u32) -> Option { + match IntWidth::try_from_bits(bits) { + Some(w) => Some(Self(w)), + None => None, + } + } + + pub const fn bits(self) -> u32 { + self.0.bits() + } +} + +// FIXME(eddyb) document the 128-bit limitations. +// HACK(eddyb) `(Type, u128)` would waste almost half its size on padding, and +// packing will only impact accessing the `bits`, while allowing e.g. being +// wrapped in an outer `enum`, before reaching the same size as `(u128, u128)`. +#[repr(packed)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct Const { + ty: Type, + bits: u128, +} + +impl Const { + pub const FALSE: Const = Const::from_bool(false); + pub const TRUE: Const = Const::from_bool(true); + + // FIXME(eddyb) document the panic conditions. + // FIXME(eddyb) make this public? + const fn from_bits_trunc(ty: Type, bits: u128) -> Const { + // FIXME(eddyb) this ensures `Const`s cannot be created when that could + // potentially need more than 128 bits for e.g. constant-folding. + let width = ty.bit_width(); + assert!(width <= 128); + + Const { ty, bits: bits & (!0u128 >> (128 - width)) } + } + + // FIXME(eddyb) document the panic conditions. 
+ pub const fn from_bits(ty: Type, bits: u128) -> Const { + let ct_trunc = Const::from_bits_trunc(ty, bits); + assert!(ct_trunc.bits == bits); + ct_trunc + } + + pub const fn try_from_bits(ty: Type, bits: u128) -> Option { + let ct_trunc = Const::from_bits_trunc(ty, bits); + if ct_trunc.bits == bits { Some(ct_trunc) } else { None } + } + + pub const fn from_bool(v: bool) -> Const { + Const::from_bits(Type::Bool, v as u128) + } + + pub const fn from_u32(v: u32) -> Const { + Const::from_bits(Type::U32, v as u128) + } + + /// Returns `Some(ct)` iff `ty` is `{S,U}Int` and can represent `v: i128` + /// (i.e. `ct` has the same sign and absolute value as `v` does). + pub fn int_try_from_i128(ty: Type, v: i128) -> Option { + let ct_trunc = Const::from_bits_trunc(ty, v as u128); + (ct_trunc.int_as_i128() == Some(v)).then_some(ct_trunc) + } + + pub const fn ty(&self) -> Type { + self.ty + } + + pub const fn bits(&self) -> u128 { + self.bits + } + + /// Returns `Some(v)` iff `self` is `{S,U}Int` and representable by `v: i128` + /// (i.e. `self` has the same sign and absolute value as `v` does). + pub fn int_as_i128(&self) -> Option { + match self.ty { + Type::Bool | Type::Float(_) => None, + Type::SInt(_) => { + let width = self.ty.bit_width(); + Some((self.bits as i128) << (128 - width) >> (128 - width)) + } + Type::UInt(_) => self.bits.try_into().ok(), + } + } + + /// Returns `Some(v)` iff `self` is `{S,U}Int` and representable by `v: u128` + /// (i.e. `self` is non-negative and has the same absolute value as `v` does). + pub fn int_as_u128(&self) -> Option { + match self.ty { + Type::Bool | Type::Float(_) => None, + Type::SInt(_) => self.int_as_i128()?.try_into().ok(), + Type::UInt(_) => Some(self.bits), + } + } + + /// Returns `Some(v)` iff `self` is `{S,U}Int` and representable by `v: i32` + /// (i.e. `self` has the same sign and absolute value as `v` does).
+ pub fn int_as_i32(&self) -> Option { + self.int_as_i128()?.try_into().ok() + } + + /// Returns `Some(v)` iff `self` is `{S,U}Int` and representable by `v: u32` + /// (i.e. `self` is non-negative and has the same absolute value as `v` does). + pub fn int_as_u32(&self) -> Option { + self.int_as_u128()?.try_into().ok() + } +} + +/// Pure operations with scalar inputs and outputs. +// +// FIXME(eddyb) these are not some "perfect" grouping, but allow for more +// flexibility in users of this `enum` (and its component `enum`s). +#[derive(Copy, Clone, PartialEq, Eq, Hash, derive_more::From)] +pub enum Op { + BoolUnary(BoolUnOp), + BoolBinary(BoolBinOp), + + IntUnary(IntUnOp), + IntBinary(IntBinOp), + + FloatUnary(FloatUnOp), + FloatBinary(FloatBinOp), +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum BoolUnOp { + Not, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum BoolBinOp { + Eq, + // FIXME(eddyb) should this be `Xor` instead? + Ne, + Or, + And, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum IntUnOp { + Neg, + Not, + CountOnes, + + // FIXME(eddyb) ideally `Trunc` should be separated and common. + TruncOrZeroExtend, + TruncOrSignExtend, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum IntBinOp { + // I×I→I + Add, + Sub, + Mul, + DivU, + DivS, + ModU, + RemS, + ModS, + ShrU, + ShrS, + Shl, + Or, + Xor, + And, + + // I×I→I×I + CarryingAdd, + BorrowingSub, + WideningMulU, + WideningMulS, + + // I×I→B + Eq, + Ne, + // FIXME(eddyb) deduplicate between signed and unsigned. + GtU, + GtS, + GeU, + GeS, + LtU, + LtS, + LeU, + LeS, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum FloatUnOp { + // F→F + Neg, + + // F→B + IsNan, + IsInf, + + // FIXME(eddyb) these are a complicated mix of signatures.
+ FromUInt, + FromSInt, + ToUInt, + ToSInt, + Convert, + QuantizeAsF16, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum FloatBinOp { + // F×F→F + Add, + Sub, + Mul, + Div, + Rem, + Mod, + + // F×F→B + Cmp(FloatCmp), + // FIXME(eddyb) this doesn't properly convey that this is effectively the + // boolean flip of the opposite comparison, e.g. `CmpOrUnord(Ge)` is really + // a fused version of `Not(Cmp(Lt))`, because `x < y` is never `true` for + // unordered `x` and `y` (i.e. `PartialOrd::partial_cmp(x, y) == None`), + // but that maps to `!(x < y)` always being `true` for unordered `x` and `y`, + // and thus `x >= y` is only equivalent for the ordered cases. + CmpOrUnord(FloatCmp), +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum FloatCmp { + Eq, + Ne, + Lt, + Gt, + Le, + Ge, +} + +impl Op { + pub fn output_count(self) -> usize { + match self { + Op::IntBinary(op) => op.output_count(), + _ => 1, + } + } + + pub fn name(self) -> &'static str { + match self { + Op::BoolUnary(op) => op.name(), + Op::BoolBinary(op) => op.name(), + + Op::IntUnary(op) => op.name(), + Op::IntBinary(op) => op.name(), + + Op::FloatUnary(op) => op.name(), + Op::FloatBinary(op) => op.name(), + } + } +} + +impl BoolUnOp { + pub fn name(self) -> &'static str { + match self { + BoolUnOp::Not => "bool.not", + } + } +} + +impl BoolBinOp { + pub fn name(self) -> &'static str { + match self { + BoolBinOp::Eq => "bool.eq", + BoolBinOp::Ne => "bool.ne", + BoolBinOp::Or => "bool.or", + BoolBinOp::And => "bool.and", + } + } +} + +impl IntUnOp { + pub fn name(self) -> &'static str { + match self { + IntUnOp::Neg => "i.neg", + IntUnOp::Not => "i.not", + IntUnOp::CountOnes => "i.count_ones", + + IntUnOp::TruncOrZeroExtend => "u.trunc_or_zext", + IntUnOp::TruncOrSignExtend => "s.trunc_or_sext", + } + } +} + +impl IntBinOp { + pub fn output_count(self) -> usize { + // FIXME(eddyb) should these 4 go into a different `enum`? 
+ match self { + IntBinOp::CarryingAdd + | IntBinOp::BorrowingSub + | IntBinOp::WideningMulU + | IntBinOp::WideningMulS => 2, + _ => 1, + } + } + + pub fn name(self) -> &'static str { + match self { + IntBinOp::Add => "i.add", + IntBinOp::Sub => "i.sub", + IntBinOp::Mul => "i.mul", + IntBinOp::DivU => "u.div", + IntBinOp::DivS => "s.div", + IntBinOp::ModU => "u.mod", + IntBinOp::RemS => "s.rem", + IntBinOp::ModS => "s.mod", + IntBinOp::ShrU => "u.shr", + IntBinOp::ShrS => "s.shr", + IntBinOp::Shl => "i.shl", + IntBinOp::Or => "i.or", + IntBinOp::Xor => "i.xor", + IntBinOp::And => "i.and", + IntBinOp::CarryingAdd => "i.carrying_add", + IntBinOp::BorrowingSub => "i.borrowing_sub", + IntBinOp::WideningMulU => "u.widening_mul", + IntBinOp::WideningMulS => "s.widening_mul", + IntBinOp::Eq => "i.eq", + IntBinOp::Ne => "i.ne", + IntBinOp::GtU => "u.gt", + IntBinOp::GtS => "s.gt", + IntBinOp::GeU => "u.ge", + IntBinOp::GeS => "s.ge", + IntBinOp::LtU => "u.lt", + IntBinOp::LtS => "s.lt", + IntBinOp::LeU => "u.le", + IntBinOp::LeS => "s.le", + } + } +} + +impl FloatUnOp { + pub fn name(self) -> &'static str { + match self { + FloatUnOp::Neg => "f.neg", + + FloatUnOp::IsNan => "f.is_nan", + FloatUnOp::IsInf => "f.is_inf", + + FloatUnOp::FromUInt => "f.from_uint", + FloatUnOp::FromSInt => "f.from_sint", + FloatUnOp::ToUInt => "f.to_uint", + FloatUnOp::ToSInt => "f.to_sint", + FloatUnOp::Convert => "f.convert", + FloatUnOp::QuantizeAsF16 => "f.quantize_as_f16", + } + } +} + +impl FloatBinOp { + pub fn name(self) -> &'static str { + match self { + FloatBinOp::Add => "f.add", + FloatBinOp::Sub => "f.sub", + FloatBinOp::Mul => "f.mul", + FloatBinOp::Div => "f.div", + FloatBinOp::Rem => "f.rem", + FloatBinOp::Mod => "f.mod", + FloatBinOp::Cmp(FloatCmp::Eq) => "f.eq", + FloatBinOp::Cmp(FloatCmp::Ne) => "f.ne", + FloatBinOp::Cmp(FloatCmp::Lt) => "f.lt", + FloatBinOp::Cmp(FloatCmp::Gt) => "f.gt", + FloatBinOp::Cmp(FloatCmp::Le) => "f.le", + FloatBinOp::Cmp(FloatCmp::Ge) => "f.ge", + 
FloatBinOp::CmpOrUnord(FloatCmp::Eq) => "f.eq_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Ne) => "f.ne_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Lt) => "f.lt_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Gt) => "f.gt_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Le) => "f.le_or_unord", + FloatBinOp::CmpOrUnord(FloatCmp::Ge) => "f.ge_or_unord", + } + } +} diff --git a/src/spv/canonical.rs b/src/spv/canonical.rs new file mode 100644 index 00000000..ea679873 --- /dev/null +++ b/src/spv/canonical.rs @@ -0,0 +1,529 @@ +//! Bidirectional (SPIR-V <-> SPIR-T) "canonical mappings". +//! +//! Both directions are defined close together as much as possible, to: +//! - limit code duplication, making it easy to add more mappings +//! - limit how much they could even go out of sync over time +//! - prevent naming e.g. SPIR-V opcodes, outside canonicalization +// +// FIXME(eddyb) should interning attempts check/apply these canonicalizations? + +use crate::spv::{self, spec}; +use crate::{scalar, vector, Const, ConstKind, Context, DataInstKind, Type, TypeKind, TypeOrConst}; +use lazy_static::lazy_static; +use smallvec::SmallVec; + +// FIXME(eddyb) these ones could maybe make use of build script generation. +macro_rules! def_mappable_ops { + ( + type { $($ty_op:ident),+ $(,)? } + const { $($ct_op:ident),+ $(,)? } + data_inst { $($di_op:ident),+ $(,)? } + $($enum_path:path { $($variant_op:ident <=> $variant:ident$(($($variant_args:tt)*))?),+ $(,)? })* + ) => { + #[allow(non_snake_case)] + struct MappableOps { + $($ty_op: spec::Opcode,)+ + $($ct_op: spec::Opcode,)+ + $($di_op: spec::Opcode,)+ + $($($variant_op: spec::Opcode,)+)* + } + impl MappableOps { + #[inline(always)] + #[must_use] + pub fn get() -> &'static MappableOps { + lazy_static! 
{ + static ref MAPPABLE_OPS: MappableOps = { + let spv_spec = spec::Spec::get(); + MappableOps { + $($ty_op: spv_spec.instructions.lookup(stringify!($ty_op)).unwrap(),)+ + $($ct_op: spv_spec.instructions.lookup(stringify!($ct_op)).unwrap(),)+ + $($di_op: spv_spec.instructions.lookup(stringify!($di_op)).unwrap(),)+ + $($($variant_op: spv_spec.instructions.lookup(stringify!($variant_op)).unwrap(),)+)* + } + }; + } + &MAPPABLE_OPS + } + } + // NOTE(eddyb) these should stay private, hence not implementing `TryFrom`. + $(impl $enum_path { + fn try_from_opcode(opcode: spec::Opcode) -> Option { + let mo = MappableOps::get(); + $(if opcode == mo.$variant_op { + return Some(Self::$variant$(($($variant_args)*))?); + })+ + None + } + fn to_opcode(self) -> spec::Opcode { + let mo = MappableOps::get(); + match self { + $(Self::$variant$(($($variant_args)*))? => mo.$variant_op,)+ + } + } + })* + }; +} +def_mappable_ops! { + // FIXME(eddyb) these categories don't actually do anything right now + type { + OpTypeBool, + OpTypeInt, + OpTypeFloat, + OpTypeVector, + } + const { + OpUndef, + OpConstantFalse, + OpConstantTrue, + OpConstant, + } + data_inst { + OpVectorExtractDynamic, + OpVectorInsertDynamic, + OpVectorTimesScalar, + } + scalar::BoolUnOp { + OpLogicalNot <=> Not, + } + scalar::BoolBinOp { + OpLogicalEqual <=> Eq, + OpLogicalNotEqual <=> Ne, + OpLogicalOr <=> Or, + OpLogicalAnd <=> And, + } + scalar::IntUnOp { + OpSNegate <=> Neg, + OpNot <=> Not, + OpBitCount <=> CountOnes, + + OpUConvert <=> TruncOrZeroExtend, + OpSConvert <=> TruncOrSignExtend, + } + scalar::IntBinOp { + // I×I→I + OpIAdd <=> Add, + OpISub <=> Sub, + OpIMul <=> Mul, + OpUDiv <=> DivU, + OpSDiv <=> DivS, + OpUMod <=> ModU, + OpSRem <=> RemS, + OpSMod <=> ModS, + OpShiftRightLogical <=> ShrU, + OpShiftRightArithmetic <=> ShrS, + OpShiftLeftLogical <=> Shl, + OpBitwiseOr <=> Or, + OpBitwiseXor <=> Xor, + OpBitwiseAnd <=> And, + + // I×I→I×I + OpIAddCarry <=> CarryingAdd, + OpISubBorrow <=> BorrowingSub, + 
OpUMulExtended <=> WideningMulU, + OpSMulExtended <=> WideningMulS, + + // I×I→B + OpIEqual <=> Eq, + OpINotEqual <=> Ne, + OpUGreaterThan <=> GtU, + OpSGreaterThan <=> GtS, + OpUGreaterThanEqual <=> GeU, + OpSGreaterThanEqual <=> GeS, + OpULessThan <=> LtU, + OpSLessThan <=> LtS, + OpULessThanEqual <=> LeU, + OpSLessThanEqual <=> LeS, + } + scalar::FloatUnOp { + // F→F + OpFNegate <=> Neg, + + // F→B + OpIsNan <=> IsNan, + OpIsInf <=> IsInf, + + OpConvertUToF <=> FromUInt, + OpConvertSToF <=> FromSInt, + OpConvertFToU <=> ToUInt, + OpConvertFToS <=> ToSInt, + OpFConvert <=> Convert, + OpQuantizeToF16 <=> QuantizeAsF16, + } + scalar::FloatBinOp { + // F×F→F + OpFAdd <=> Add, + OpFSub <=> Sub, + OpFMul <=> Mul, + OpFDiv <=> Div, + OpFRem <=> Rem, + OpFMod <=> Mod, + + // F×F→B + OpFOrdEqual <=> Cmp(scalar::FloatCmp::Eq), + OpFOrdNotEqual <=> Cmp(scalar::FloatCmp::Ne), + OpFOrdLessThan <=> Cmp(scalar::FloatCmp::Lt), + OpFOrdGreaterThan <=> Cmp(scalar::FloatCmp::Gt), + OpFOrdLessThanEqual <=> Cmp(scalar::FloatCmp::Le), + OpFOrdGreaterThanEqual <=> Cmp(scalar::FloatCmp::Ge), + OpFUnordEqual <=> CmpOrUnord(scalar::FloatCmp::Eq), + OpFUnordNotEqual <=> CmpOrUnord(scalar::FloatCmp::Ne), + OpFUnordLessThan <=> CmpOrUnord(scalar::FloatCmp::Lt), + OpFUnordGreaterThan <=> CmpOrUnord(scalar::FloatCmp::Gt), + OpFUnordLessThanEqual <=> CmpOrUnord(scalar::FloatCmp::Le), + OpFUnordGreaterThanEqual <=> CmpOrUnord(scalar::FloatCmp::Ge), + } + vector::ReduceOp { + OpDot <=> Dot, + OpAny <=> Any, + OpAll <=> All, + } +} + +impl scalar::Const { + // HACK(eddyb) this is not private so `spv::lower` can use it for `OpSwitch`. + pub(super) fn try_decode_from_spv_imms( + ty: scalar::Type, + imms: &[spv::Imm], + ) -> Option { + // FIXME(eddyb) don't hardcode the 128-bit limitation, + // but query `scalar::Const` somehow instead. 
+ if ty.bit_width() > 128 { + return None; + } + let imm_words = usize::try_from(ty.bit_width().div_ceil(32)).unwrap(); + if imms.len() != imm_words { + return None; + } + let mut bits = 0; + for (i, &imm) in imms.iter().enumerate() { + let w = match imm { + spv::Imm::Short(_, w) if imm_words == 1 => w, + spv::Imm::LongStart(_, w) if i == 0 && imm_words > 1 => w, + spv::Imm::LongCont(_, w) if i > 0 => w, + _ => return None, + }; + bits |= (w as u128) << (i * 32); + } + + // HACK(eddyb) signed integers are encoded sign-extended into immediates. + if let scalar::Type::SInt(_) = ty { + let imm_width = imm_words * 32; + scalar::Const::int_try_from_i128( + ty, + (bits as i128) << (128 - imm_width) >> (128 - imm_width), + ) + } else { + scalar::Const::try_from_bits(ty, bits) + } + } + + // HACK(eddyb) this is not private so `spv::lift` can use it for `OpSwitch`. + pub(super) fn encode_as_spv_imms(&self) -> impl Iterator { + let wk = &spec::Spec::get().well_known; + + let ty = self.ty(); + let imm_words = ty.bit_width().div_ceil(32); + + let bits = self.bits(); + + // HACK(eddyb) signed integers are encoded sign-extended into immediates. + let bits = if let scalar::Type::SInt(_) = ty { + let imm_width = imm_words * 32; + (self.int_as_i128().unwrap() as u128) & (!0 >> (128 - imm_width)) + } else { + bits + }; + + (0..imm_words).map(move |i| { + let k = wk.LiteralContextDependentNumber; + let w = (bits >> (i * 32)) as u32; + if imm_words == 1 { + spv::Imm::Short(k, w) + } else if i == 0 { + spv::Imm::LongStart(k, w) + } else { + spv::Imm::LongCont(k, w) + } + }) + } +} + +// FIXME(eddyb) decide on a visibility scope - `pub(super)` avoids some mistakes +// (using these methods outside of `spv::{lower,lift}`), but may be too restrictive. +impl spv::Inst { + // HACK(eddyb) exported only for `spv::read`/`LiteralContextDependentNumber`. + pub(super) fn int_or_float_type_bit_width(&self) -> Option { + let mo = MappableOps::get(); + + match self.imms[..] 
{ + [spv::Imm::Short(_, bit_width), _] if self.opcode == mo.OpTypeInt => Some(bit_width), + [spv::Imm::Short(_, bit_width)] if self.opcode == mo.OpTypeFloat => Some(bit_width), + _ => None, + } + } + + // FIXME(eddyb) automate bidirectional mappings more (although the need + // for conditional, i.e. "partial", mappings, adds a lot of complexity). + pub(super) fn as_canonical_type( + &self, + cx: &Context, + type_and_const_inputs: &[TypeOrConst], + ) -> Option { + let Self { opcode, imms } = self; + let (&opcode, imms) = (opcode, &imms[..]); + + let mo = MappableOps::get(); + + let int_width = || scalar::IntWidth::try_from_bits(self.int_or_float_type_bit_width()?); + match (imms, type_and_const_inputs) { + ([], []) if opcode == mo.OpTypeBool => Some(scalar::Type::Bool.into()), + (&[_, spv::Imm::Short(_, 0)], []) if opcode == mo.OpTypeInt => { + Some(scalar::Type::UInt(int_width()?).into()) + } + (&[_, spv::Imm::Short(_, 1)], []) if opcode == mo.OpTypeInt => { + Some(scalar::Type::SInt(int_width()?).into()) + } + ([_], []) if opcode == mo.OpTypeFloat => Some( + scalar::Type::Float(scalar::FloatWidth::try_from_bits( + self.int_or_float_type_bit_width()?, + )?) 
+ .into(), + ), + (&[spv::Imm::Short(_, elem_count)], &[TypeOrConst::Type(elem_type)]) + if opcode == mo.OpTypeVector => + { + Some( + vector::Type { + elem: elem_type.as_scalar(cx)?, + elem_count: u8::try_from(elem_count).ok()?.try_into().ok()?, + } + .into(), + ) + } + _ => None, + } + } + + pub(super) fn from_canonical_type( + cx: &Context, + type_kind: &TypeKind, + ) -> Option<(Self, SmallVec<[TypeOrConst; 2]>)> { + let wk = &spec::Spec::get().well_known; + let mo = MappableOps::get(); + + match type_kind { + &TypeKind::Scalar(ty) => Some(( + match ty { + scalar::Type::Bool => mo.OpTypeBool.into(), + scalar::Type::SInt(w) | scalar::Type::UInt(w) => spv::Inst { + opcode: mo.OpTypeInt, + imms: [ + spv::Imm::Short(wk.LiteralInteger, w.bits()), + spv::Imm::Short( + wk.LiteralInteger, + matches!(ty, scalar::Type::SInt(_)) as u32, + ), + ] + .into_iter() + .collect(), + }, + scalar::Type::Float(w) => spv::Inst { + opcode: mo.OpTypeFloat, + imms: [spv::Imm::Short(wk.LiteralInteger, w.bits())].into_iter().collect(), + }, + }, + [].into_iter().collect(), + )), + + TypeKind::Vector(ty) => Some(( + spv::Inst { + opcode: mo.OpTypeVector, + imms: [spv::Imm::Short(wk.LiteralInteger, ty.elem_count.get().into())] + .into_iter() + .collect(), + }, + [TypeOrConst::Type(cx.intern(ty.elem))].into_iter().collect(), + )), + + TypeKind::QPtr | TypeKind::SpvInst { .. } | TypeKind::SpvStringLiteralForExtInst => { + None + } + } + } + + // HACK(eddyb) this only exists as a helper for `spv::lower`. + pub(super) fn always_lower_as_const(&self) -> bool { + let mo = MappableOps::get(); + mo.OpUndef == self.opcode + } + + // FIXME(eddyb) automate bidirectional mappings more (although the need + // for conditional, i.e. "partial", mappings, adds a lot of complexity). 
+ pub(super) fn as_canonical_const( + &self, + cx: &Context, + ty: Type, + const_inputs: &[Const], + ) -> Option { + let Self { opcode, imms } = self; + let (&opcode, imms) = (opcode, &imms[..]); + + let wk = &spec::Spec::get().well_known; + let mo = MappableOps::get(); + + match (imms, const_inputs) { + ([], []) if opcode == mo.OpUndef => Some(ConstKind::Undef), + ([], []) if opcode == mo.OpConstantFalse => Some(scalar::Const::FALSE.into()), + ([], []) if opcode == mo.OpConstantTrue => Some(scalar::Const::TRUE.into()), + (_, []) if opcode == mo.OpConstant => { + Some(scalar::Const::try_decode_from_spv_imms(ty.as_scalar(cx)?, imms)?.into()) + } + _ if opcode == wk.OpConstantComposite => { + let ty = ty.as_vector(cx)?; + let elems = (const_inputs.len() == usize::from(ty.elem_count.get()) + && const_inputs.iter().all(|ct| ct.as_scalar(cx).is_some())) + .then(|| const_inputs.iter().map(|ct| *ct.as_scalar(cx).unwrap()))?; + Some(vector::Const::from_elems(ty, elems).into()) + } + _ => None, + } + } + + pub(super) fn from_canonical_const( + cx: &Context, + const_kind: &ConstKind, + ) -> Option<(Self, SmallVec<[Const; 4]>)> { + let wk = &spec::Spec::get().well_known; + let mo = MappableOps::get(); + + match const_kind { + ConstKind::Undef => Some((mo.OpUndef.into(), [].into_iter().collect())), + &ConstKind::Scalar(ct) => Some(( + match ct { + scalar::Const::FALSE => mo.OpConstantFalse.into(), + scalar::Const::TRUE => mo.OpConstantTrue.into(), + _ => { + spv::Inst { opcode: mo.OpConstant, imms: ct.encode_as_spv_imms().collect() } + } + }, + [].into_iter().collect(), + )), + + ConstKind::Vector(ct) => Some(( + wk.OpConstantComposite.into(), + ct.elems().map(|elem| cx.intern(elem)).collect(), + )), + + ConstKind::PtrToGlobalVar(_) + | ConstKind::SpvInst { .. 
} + | ConstKind::SpvStringLiteralForExtInst(_) => None, + } + } + + pub(super) fn as_canonical_data_inst_kind( + &self, + cx: &Context, + output_types: &[Type], + ) -> Option { + let Self { opcode, imms } = self; + let (&opcode, imms) = (opcode, &imms[..]); + + let scalar_op = (scalar::BoolUnOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::BoolBinOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::IntUnOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::IntBinOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::FloatUnOp::try_from_opcode(opcode).map(scalar::Op::from)) + .or_else(|| scalar::FloatBinOp::try_from_opcode(opcode).map(scalar::Op::from)); + if let Some(op) = scalar_op { + assert_eq!(imms.len(), 0); + + let (_scalar_type, vec_elem_count) = (output_types.len() == op.output_count()) + .then(|| { + output_types.iter().map(|&ty| match cx[ty].kind { + TypeKind::Scalar(ty) => Some((ty, None)), + TypeKind::Vector(ty) => Some((ty.elem, Some(ty.elem_count))), + _ => None, + }) + }) + .and_then(|mut outputs| { + let first = outputs.next().unwrap()?; + outputs.all(|x| x == Some(first)).then_some(first) + })?; + + Some(if vec_elem_count.is_some() { + vector::Op::Distribute(op).into() + } else { + op.into() + }) + } else if let Some(op) = vector::ReduceOp::try_from_opcode(opcode).map(vector::Op::from) { + assert_eq!(imms.len(), 0); + Some(op.into()) + } else { + let wk = &spec::Spec::get().well_known; + let mo = MappableOps::get(); + + // FIXME(eddyb) automate this by supporting immediates in the macro. + let v_whole = |op| Some(vector::Op::Whole(op).into()); + match imms { + [] if opcode == wk.OpCompositeConstruct => v_whole(vector::WholeOp::New), + &[spv::Imm::Short(_, elem_idx)] if opcode == wk.OpCompositeExtract => { + v_whole(vector::WholeOp::Extract { elem_idx: elem_idx.try_into().ok()? 
}) + } + &[spv::Imm::Short(_, elem_idx)] if opcode == wk.OpCompositeInsert => { + v_whole(vector::WholeOp::Insert { elem_idx: elem_idx.try_into().ok()? }) + } + [] if opcode == mo.OpVectorExtractDynamic => v_whole(vector::WholeOp::DynExtract), + [] if opcode == mo.OpVectorInsertDynamic => v_whole(vector::WholeOp::DynInsert), + [] if opcode == mo.OpVectorTimesScalar => v_whole(vector::WholeOp::Mul), + _ => None, + } + } + } + + pub(super) fn from_canonical_data_inst_kind(data_inst_kind: &DataInstKind) -> Option { + match data_inst_kind { + &DataInstKind::Scalar(op) => Some(match op { + scalar::Op::BoolUnary(op) => op.to_opcode().into(), + scalar::Op::BoolBinary(op) => op.to_opcode().into(), + scalar::Op::IntUnary(op) => op.to_opcode().into(), + scalar::Op::IntBinary(op) => op.to_opcode().into(), + scalar::Op::FloatUnary(op) => op.to_opcode().into(), + scalar::Op::FloatBinary(op) => op.to_opcode().into(), + }), + &DataInstKind::Vector(op) => Some(match op { + vector::Op::Distribute(op) => { + Self::from_canonical_data_inst_kind(&DataInstKind::Scalar(op)).unwrap() + } + vector::Op::Reduce(op) => op.to_opcode().into(), + vector::Op::Whole(op) => { + let wk = &spec::Spec::get().well_known; + let mo = MappableOps::get(); + + // FIXME(eddyb) automate this by supporting immediates in the macro. 
+ match op { + vector::WholeOp::New => wk.OpCompositeConstruct.into(), + vector::WholeOp::Extract { elem_idx } => spv::Inst { + opcode: wk.OpCompositeExtract, + imms: [spv::Imm::Short(wk.LiteralInteger, elem_idx.into())] + .into_iter() + .collect(), + }, + vector::WholeOp::Insert { elem_idx } => spv::Inst { + opcode: wk.OpCompositeInsert, + imms: [spv::Imm::Short(wk.LiteralInteger, elem_idx.into())] + .into_iter() + .collect(), + }, + vector::WholeOp::DynExtract => mo.OpVectorExtractDynamic.into(), + vector::WholeOp::DynInsert => mo.OpVectorInsertDynamic.into(), + vector::WholeOp::Mul => mo.OpVectorTimesScalar.into(), + } + } + }), + DataInstKind::FuncCall(_) + | DataInstKind::QPtr(_) + | DataInstKind::SpvInst(..) + | DataInstKind::SpvExtInst { .. } => None, + } + } +} diff --git a/src/spv/lift.rs b/src/spv/lift.rs index 690dd449..88170778 100644 --- a/src/spv/lift.rs +++ b/src/spv/lift.rs @@ -4,7 +4,7 @@ use crate::func_at::FuncAt; use crate::spv::{self, spec}; use crate::visit::{InnerVisit, Visitor}; use crate::{ - cfg, AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, ControlNode, + cfg, scalar, AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, ControlNode, ControlNodeKind, ControlNodeOutputDecl, ControlRegion, ControlRegionInputDecl, DataInst, DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, EntityList, ExportKey, Exportee, Func, FuncDecl, FuncParam, FxIndexMap, FxIndexSet, GlobalVar, GlobalVarDefBody, @@ -121,14 +121,29 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { return; } let ty_def = &self.cx[ty]; + + // HACK(eddyb) there isn't a great way to handle canonical types, but + // perhaps this result should be recorded in `self.globals`? 
+ if let Some((_spv_inst, type_and_const_inputs)) = + spv::Inst::from_canonical_type(self.cx, &ty_def.kind) + { + for ty_or_ct in type_and_const_inputs { + match ty_or_ct { + TypeOrConst::Type(ty) => self.visit_type_use(ty), + TypeOrConst::Const(ct) => self.visit_const_use(ct), + } + } + } + match ty_def.kind { + TypeKind::Scalar(_) | TypeKind::Vector(_) | TypeKind::SpvInst { .. } => {} + // FIXME(eddyb) this should be a proper `Result`-based error instead, // and/or `spv::lift` should mutate the module for legalization. TypeKind::QPtr => { unreachable!("`TypeKind::QPtr` should be legalized away before lifting"); } - TypeKind::SpvInst { .. } => {} TypeKind::SpvStringLiteralForExtInst => { unreachable!( "`TypeKind::SpvStringLiteralForExtInst` should not be used \ @@ -136,6 +151,7 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { ); } } + self.visit_type_def(ty_def); self.globals.insert(global); } @@ -145,8 +161,23 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { return; } let ct_def = &self.cx[ct]; + + // HACK(eddyb) there isn't a great way to handle canonical consts, but + // perhaps this result should be recorded in `self.globals`? + if let Some((_spv_inst, const_inputs)) = + spv::Inst::from_canonical_const(self.cx, &ct_def.kind) + { + for ct in const_inputs { + self.visit_const_use(ct); + } + } + match ct_def.kind { - ConstKind::PtrToGlobalVar(_) | ConstKind::SpvInst { .. } => { + ConstKind::Undef + | ConstKind::Scalar(_) + | ConstKind::Vector(_) + | ConstKind::PtrToGlobalVar(_) + | ConstKind::SpvInst { .. } => { self.visit_const_def(ct_def); self.globals.insert(global); } @@ -216,7 +247,6 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { } fn visit_data_inst_form_def(&mut self, data_inst_form_def: &DataInstFormDef) { - #[allow(clippy::match_same_arms)] match data_inst_form_def.kind { // FIXME(eddyb) this should be a proper `Result`-based error instead, // and/or `spv::lift` should mutate the module for legalization. 
@@ -224,9 +254,11 @@ impl Visitor<'_> for NeedsIdsCollector<'_> { unreachable!("`DataInstKind::QPtr` should be legalized away before lifting"); } - DataInstKind::FuncCall(_) => {} + DataInstKind::Scalar(_) + | DataInstKind::Vector(_) + | DataInstKind::FuncCall(_) + | DataInstKind::SpvInst(_) => {} - DataInstKind::SpvInst(_) => {} DataInstKind::SpvExtInst { ext_set, .. } => { self.ext_inst_imports.insert(&self.cx[ext_set]); } @@ -522,8 +554,6 @@ impl<'a> FuncLifting<'a> { func_decl: &'a FuncDecl, mut alloc_id: impl FnMut() -> Result, ) -> Result { - let wk = &spec::Spec::get().well_known; - let func_id = alloc_id()?; let param_ids = func_decl.params.iter().map(|_| alloc_id()).collect::>()?; @@ -758,15 +788,9 @@ impl<'a> FuncLifting<'a> { .collect(); let is_infinite_loop = match repeat_condition { - Value::Const(cond) => match &cx[cond].kind { - ConstKind::SpvInst { spv_inst_and_const_inputs } => { - let (spv_inst, _const_inputs) = - &**spv_inst_and_const_inputs; - spv_inst.opcode == wk.OpConstantTrue - } - _ => false, - }, - + Value::Const(cond) => { + matches!(cx[cond].kind, ConstKind::Scalar(scalar::Const::TRUE)) + } _ => false, }; if is_infinite_loop { @@ -1036,7 +1060,11 @@ impl LazyInst<'_, '_> { }; (gv_decl.attrs, import) } - ConstKind::SpvInst { .. } => (ct_def.attrs, None), + + ConstKind::Undef + | ConstKind::Scalar(_) + | ConstKind::Vector(_) + | ConstKind::SpvInst { .. } => (ct_def.attrs, None), // Not inserted into `globals` while visiting. 
ConstKind::SpvStringLiteralForExtInst(_) => unreachable!(), @@ -1102,29 +1130,72 @@ impl LazyInst<'_, '_> { let (result_id, attrs, _) = self.result_id_attrs_and_import(module, ids); let inst = match self { Self::Global(global) => match global { - Global::Type(ty) => match &cx[ty].kind { - TypeKind::SpvInst { spv_inst, type_and_const_inputs } => spv::InstWithIds { - without_ids: spv_inst.clone(), - result_type_id: None, - result_id, - ids: type_and_const_inputs - .iter() - .map(|&ty_or_ct| { - ids.globals[&match ty_or_ct { - TypeOrConst::Type(ty) => Global::Type(ty), - TypeOrConst::Const(ct) => Global::Const(ct), - }] - }) - .collect(), - }, + Global::Type(ty) => { + let ty_def = &cx[ty]; + match spv::Inst::from_canonical_type(cx, &ty_def.kind) + .as_ref() + .ok_or(&ty_def.kind) + { + Err(TypeKind::Scalar(_) | TypeKind::Vector(_)) => { + unreachable!("should've been handled as canonical") + } - // Not inserted into `globals` while visiting. - TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => unreachable!(), - }, + Ok((spv_inst, type_and_const_inputs)) + | Err(TypeKind::SpvInst { spv_inst, type_and_const_inputs }) => { + spv::InstWithIds { + without_ids: spv_inst.clone(), + result_type_id: None, + result_id, + ids: type_and_const_inputs + .iter() + .map(|&ty_or_ct| { + ids.globals[&match ty_or_ct { + TypeOrConst::Type(ty) => Global::Type(ty), + TypeOrConst::Const(ct) => Global::Const(ct), + }] + }) + .collect(), + } + } + + // Not inserted into `globals` while visiting. + Err(TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst) => { + unreachable!() + } + } + } Global::Const(ct) => { let ct_def = &cx[ct]; - match &ct_def.kind { - &ConstKind::PtrToGlobalVar(gv) => { + match spv::Inst::from_canonical_const(cx, &ct_def.kind).ok_or(&ct_def.kind) { + // FIXME(eddyb) this duplicates the `ConstKind::SpvInst` + // case, only due to an inability to pattern-match `Rc`. 
+ Ok((spv_inst, const_inputs)) => spv::InstWithIds { + without_ids: spv_inst, + result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]), + result_id, + ids: const_inputs + .iter() + .map(|&ct| ids.globals[&Global::Const(ct)]) + .collect(), + }, + Err(ConstKind::SpvInst { spv_inst_and_const_inputs }) => { + let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; + spv::InstWithIds { + without_ids: spv_inst.clone(), + result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]), + result_id, + ids: const_inputs + .iter() + .map(|&ct| ids.globals[&Global::Const(ct)]) + .collect(), + } + } + + Err(ConstKind::Undef | ConstKind::Scalar(_) | ConstKind::Vector(_)) => { + unreachable!("should've been handled as canonical") + } + + Err(&ConstKind::PtrToGlobalVar(gv)) => { assert!(ct_def.attrs == AttrSet::default()); let gv_decl = &module.global_vars[gv]; @@ -1157,21 +1228,8 @@ impl LazyInst<'_, '_> { } } - ConstKind::SpvInst { spv_inst_and_const_inputs } => { - let (spv_inst, const_inputs) = &**spv_inst_and_const_inputs; - spv::InstWithIds { - without_ids: spv_inst.clone(), - result_type_id: Some(ids.globals[&Global::Type(ct_def.ty)]), - result_id, - ids: const_inputs - .iter() - .map(|&ct| ids.globals[&Global::Const(ct)]) - .collect(), - } - } - // Not inserted into `globals` while visiting. - ConstKind::SpvStringLiteralForExtInst(_) => unreachable!(), + Err(ConstKind::SpvStringLiteralForExtInst(_)) => unreachable!(), } } }, @@ -1228,23 +1286,30 @@ impl LazyInst<'_, '_> { }, Self::DataInst { parent_func, result_id: _, data_inst_def } => { let DataInstFormDef { kind, output_type } = &cx[data_inst_def.form]; - let (inst, extra_initial_id_operand) = match kind { - // Disallowed while visiting. 
- DataInstKind::QPtr(_) => unreachable!(), + let (inst, extra_initial_id_operand) = + match spv::Inst::from_canonical_data_inst_kind(kind).ok_or(kind) { + Ok(spv_inst) => (spv_inst, None), - &DataInstKind::FuncCall(callee) => { - (wk.OpFunctionCall.into(), Some(ids.funcs[&callee].func_id)) - } - DataInstKind::SpvInst(inst) => (inst.clone(), None), - &DataInstKind::SpvExtInst { ext_set, inst } => ( - spv::Inst { - opcode: wk.OpExtInst, - imms: iter::once(spv::Imm::Short(wk.LiteralExtInstInteger, inst)) - .collect(), - }, - Some(ids.ext_inst_imports[&cx[ext_set]]), - ), - }; + Err(DataInstKind::Scalar(_) | DataInstKind::Vector(_)) => { + unreachable!("should've been handled as canonical") + } + + // Disallowed while visiting. + Err(DataInstKind::QPtr(_)) => unreachable!(), + + Err(&DataInstKind::FuncCall(callee)) => { + (wk.OpFunctionCall.into(), Some(ids.funcs[&callee].func_id)) + } + Err(DataInstKind::SpvInst(inst)) => (inst.clone(), None), + Err(&DataInstKind::SpvExtInst { ext_set, inst }) => ( + spv::Inst { + opcode: wk.OpExtInst, + imms: iter::once(spv::Imm::Short(wk.LiteralExtInstInteger, inst)) + .collect(), + }, + Some(ids.ext_inst_imports[&cx[ext_set]]), + ), + }; spv::InstWithIds { without_ids: inst, result_type_id: output_type.map(|ty| ids.globals[&Global::Type(ty)]), @@ -1277,6 +1342,14 @@ impl LazyInst<'_, '_> { ids: [merge_label_id, continue_label_id].into_iter().collect(), }, Self::Terminator { parent_func, terminator } => { + let mut ids: SmallVec<[_; 4]> = terminator + .inputs + .iter() + .map(|&v| value_to_id(parent_func, v)) + .chain(terminator.targets.iter().map(|&target| parent_func.label_ids[&target])) + .collect(); + + // FIXME(eddyb) move some of this to `spv::canonical`. 
let inst = match &*terminator.kind { cfg::ControlInstKind::Unreachable => wk.OpUnreachable.into(), cfg::ControlInstKind::Return => { @@ -1295,23 +1368,21 @@ impl LazyInst<'_, '_> { cfg::ControlInstKind::SelectBranch(SelectionKind::BoolCond) => { wk.OpBranchConditional.into() } - cfg::ControlInstKind::SelectBranch(SelectionKind::SpvInst(inst)) => { - inst.clone() + cfg::ControlInstKind::SelectBranch(SelectionKind::Switch { case_consts }) => { + // HACK(eddyb) move the default case from last back to first. + let default_target = ids.pop().unwrap(); + ids.insert(1, default_target); + + spv::Inst { + opcode: wk.OpSwitch, + imms: case_consts + .iter() + .flat_map(|ct| ct.encode_as_spv_imms()) + .collect(), + } } }; - spv::InstWithIds { - without_ids: inst, - result_type_id: None, - result_id: None, - ids: terminator - .inputs - .iter() - .map(|&v| value_to_id(parent_func, v)) - .chain( - terminator.targets.iter().map(|&target| parent_func.label_ids[&target]), - ) - .collect(), - } + spv::InstWithIds { without_ids: inst, result_type_id: None, result_id: None, ids } } Self::OpFunctionEnd => spv::InstWithIds { without_ids: wk.OpFunctionEnd.into(), diff --git a/src/spv/lower.rs b/src/spv/lower.rs index 1e62dc9e..b7a6077c 100644 --- a/src/spv/lower.rs +++ b/src/spv/lower.rs @@ -3,11 +3,11 @@ use crate::spv::{self, spec}; // FIXME(eddyb) import more to avoid `crate::` everywhere. 
use crate::{ - cfg, print, AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, ControlNodeDef, - ControlNodeKind, ControlRegion, ControlRegionDef, ControlRegionInputDecl, DataInstDef, - DataInstFormDef, DataInstKind, DeclDef, Diag, EntityDefs, EntityList, ExportKey, Exportee, - Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, GlobalVarDecl, GlobalVarDefBody, Import, - InternedStr, Module, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, + cfg, print, scalar, AddrSpace, Attr, AttrSet, Const, ConstDef, ConstKind, Context, + ControlNodeDef, ControlNodeKind, ControlRegion, ControlRegionDef, ControlRegionInputDecl, + DataInstDef, DataInstFormDef, DataInstKind, DeclDef, Diag, EntityDefs, EntityList, ExportKey, + Exportee, Func, FuncDecl, FuncDefBody, FuncParam, FxIndexMap, GlobalVarDecl, GlobalVarDefBody, + Import, InternedStr, Module, SelectionKind, Type, TypeDef, TypeKind, TypeOrConst, Value, }; use rustc_hash::FxHashMap; use smallvec::SmallVec; @@ -85,6 +85,20 @@ fn invalid(reason: &str) -> io::Error { io::Error::new(io::ErrorKind::InvalidData, format!("malformed SPIR-V ({reason})")) } +fn invalid_factory_for_spv_inst( + inst: &spv::Inst, + result_id: Option, + ids: &[spv::Id], +) -> impl Fn(&str) -> io::Error { + let opcode = inst.opcode; + let first_id_operand = ids.first().copied(); + move |msg: &str| { + let result_prefix = result_id.map(|id| format!("%{id} = ")).unwrap_or_default(); + let operand_suffix = first_id_operand.map(|id| format!(" %{id} ...")).unwrap_or_default(); + invalid(&format!("in {result_prefix}{}{operand_suffix}: {msg}", opcode.name())) + } +} + // FIXME(eddyb) provide more information about any normalization that happened: // * stats about deduplication that occured through interning // * sets of unused global vars and functions (and types+consts only they use) @@ -195,7 +209,7 @@ impl Module { while let Some(mut inst) = spv_insts.next().transpose()? 
{ let opcode = inst.opcode; - let invalid = |msg: &str| invalid(&format!("in {}: {}", opcode.name(), msg)); + let invalid = invalid_factory_for_spv_inst(&inst, inst.result_id, &inst.ids); // Handle line debuginfo early, as it doesn't have its own section, // but rather can go almost anywhere among globals and functions. @@ -557,7 +571,7 @@ impl Module { } else if inst_category == spec::InstructionCategory::Type { assert!(inst.result_type_id.is_none()); let id = inst.result_id.unwrap(); - let type_and_const_inputs = inst + let type_and_const_inputs: SmallVec<_> = inst .ids .iter() .map(|&id| match id_defs.get(&id) { @@ -575,14 +589,20 @@ impl Module { let ty = cx.intern(TypeDef { attrs: mem::take(&mut attrs), - kind: TypeKind::SpvInst { spv_inst: inst.without_ids, type_and_const_inputs }, + kind: inst.as_canonical_type(&cx, &type_and_const_inputs).unwrap_or( + TypeKind::SpvInst { spv_inst: inst.without_ids, type_and_const_inputs }, + ), }); id_defs.insert(id, IdDef::Type(ty)); Seq::TypeConstOrGlobalVar - } else if inst_category == spec::InstructionCategory::Const || opcode == wk.OpUndef { + } else if inst_category == spec::InstructionCategory::Const + || inst.always_lower_as_const() + { let id = inst.result_id.unwrap(); - let const_inputs = inst + let ty = result_type.unwrap(); + + let const_inputs: SmallVec<_> = inst .ids .iter() .map(|&id| match id_defs.get(&id) { @@ -599,14 +619,16 @@ impl Module { let ct = cx.intern(ConstDef { attrs: mem::take(&mut attrs), - ty: result_type.unwrap(), - kind: ConstKind::SpvInst { - spv_inst_and_const_inputs: Rc::new((inst.without_ids, const_inputs)), - }, + ty, + kind: inst.as_canonical_const(&cx, ty, &const_inputs).unwrap_or_else(|| { + ConstKind::SpvInst { + spv_inst_and_const_inputs: Rc::new((inst.without_ids, const_inputs)), + } + }), }); id_defs.insert(id, IdDef::Const(ct)); - if opcode == wk.OpUndef { + if inst_category != spec::InstructionCategory::Const { // `OpUndef` can appear either among constants, or in a // 
function, so at most advance `seq` to globals. seq.max(Some(Seq::TypeConstOrGlobalVar)).unwrap() @@ -843,7 +865,7 @@ impl Module { #[derive(Copy, Clone)] enum LocalIdDef { - Value(Value), + Value(Type, Value), BlockLabel(ControlRegion), } @@ -871,6 +893,7 @@ impl Module { let IntraFuncInst { without_ids: spv::Inst { opcode, ref imms }, result_id, + result_type, .. } = *raw_inst; @@ -885,10 +908,10 @@ impl Module { DeclDef::Present(def) => def.body, }; - LocalIdDef::Value(Value::ControlRegionInput { - region: body, - input_idx: idx, - }) + LocalIdDef::Value( + result_type.unwrap(), + Value::ControlRegionInput { region: body, input_idx: idx }, + ) } else { let is_entry_block = !has_blocks; has_blocks = true; @@ -939,10 +962,13 @@ impl Module { .push(value_id); } - LocalIdDef::Value(Value::ControlRegionInput { - region: current_block, - input_idx: phi_idx, - }) + LocalIdDef::Value( + result_type.unwrap(), + Value::ControlRegionInput { + region: current_block, + input_idx: phi_idx, + }, + ) } else { // HACK(eddyb) can't get a `DataInst` without // defining it (as a dummy) first. @@ -956,7 +982,7 @@ impl Module { } .into(), ); - LocalIdDef::Value(Value::DataInstOutput(inst)) + LocalIdDef::Value(result_type.unwrap(), Value::DataInstOutput(inst)) } }; local_id_defs.insert(id, local_id_def); @@ -1005,50 +1031,52 @@ impl Module { ref ids, } = *raw_inst; - let invalid = |msg: &str| invalid(&format!("in {}: {}", opcode.name(), msg)); + let invalid = invalid_factory_for_spv_inst(&raw_inst.without_ids, result_id, ids); // FIXME(eddyb) find a more compact name and/or make this a method. // FIXME(eddyb) this returns `LocalIdDef` even for global values. 
- let lookup_global_or_local_id_for_data_or_control_inst_input = - |id| match id_defs.get(&id) { - Some(&IdDef::Const(ct)) => Ok(LocalIdDef::Value(Value::Const(ct))), - Some(id_def @ IdDef::Type(_)) => Err(invalid(&format!( - "unsupported use of {} as an operand for \ + let lookup_global_or_local_id_for_data_or_control_inst_input = |id| match id_defs + .get(&id) + { + Some(&IdDef::Const(ct)) => Ok(LocalIdDef::Value(cx[ct].ty, Value::Const(ct))), + Some(id_def @ IdDef::Type(_)) => Err(invalid(&format!( + "unsupported use of {} as an operand for \ an instruction in a function", - id_def.descr(&cx), - ))), - Some(id_def @ IdDef::Func(_)) => Err(invalid(&format!( - "unsupported use of {} outside `OpFunctionCall`", - id_def.descr(&cx), - ))), - Some(id_def @ IdDef::SpvDebugString(s)) => { - if opcode == wk.OpExtInst { - // HACK(eddyb) intern `OpString`s as `Const`s on - // the fly, as it's a less likely usage than the - // `OpLine` one. - let ct = cx.intern(ConstDef { - attrs: AttrSet::default(), - ty: cx.intern(TypeKind::SpvStringLiteralForExtInst), - kind: ConstKind::SpvStringLiteralForExtInst(*s), - }); - Ok(LocalIdDef::Value(Value::Const(ct))) - } else { - Err(invalid(&format!( - "unsupported use of {} outside `OpSource`, \ + id_def.descr(&cx), + ))), + Some(id_def @ IdDef::Func(_)) => Err(invalid(&format!( + "unsupported use of {} outside `OpFunctionCall`", + id_def.descr(&cx), + ))), + Some(id_def @ IdDef::SpvDebugString(s)) => { + if opcode == wk.OpExtInst { + // HACK(eddyb) intern `OpString`s as `Const`s on + // the fly, as it's a less likely usage than the + // `OpLine` one. 
+ let ty = cx.intern(TypeKind::SpvStringLiteralForExtInst); + let ct = cx.intern(ConstDef { + attrs: AttrSet::default(), + ty, + kind: ConstKind::SpvStringLiteralForExtInst(*s), + }); + Ok(LocalIdDef::Value(ty, Value::Const(ct))) + } else { + Err(invalid(&format!( + "unsupported use of {} outside `OpSource`, \ `OpLine`, or `OpExtInst`", - id_def.descr(&cx), - ))) - } + id_def.descr(&cx), + ))) } - Some(id_def @ IdDef::SpvExtInstImport(_)) => Err(invalid(&format!( - "unsupported use of {} outside `OpExtInst`", - id_def.descr(&cx), - ))), - None => local_id_defs - .get(&id) - .copied() - .ok_or_else(|| invalid(&format!("undefined ID %{id}",))), - }; + } + Some(id_def @ IdDef::SpvExtInstImport(_)) => Err(invalid(&format!( + "unsupported use of {} outside `OpExtInst`", + id_def.descr(&cx), + ))), + None => local_id_defs + .get(&id) + .copied() + .ok_or_else(|| invalid(&format!("undefined ID %{id}",))), + }; if opcode == wk.OpFunctionParameter { if current_block_control_region_and_details.is_some() { @@ -1086,7 +1114,7 @@ impl Module { // to be able to have an entry in `local_id_defs`. let control_region = match local_id_defs[&result_id.unwrap()] { LocalIdDef::BlockLabel(control_region) => control_region, - LocalIdDef::Value(_) => unreachable!(), + LocalIdDef::Value(..) => unreachable!(), }; let current_block_details = &block_details[&control_region]; assert_eq!(current_block_details.label_id, result_id.unwrap()); @@ -1122,7 +1150,7 @@ impl Module { }; let phi_value_id_to_value = |phi_key: &PhiKey, id| { match lookup_global_or_local_id_for_data_or_control_inst_input(id)? { - LocalIdDef::Value(v) => Ok(v), + LocalIdDef::Value(_, v) => Ok(v), LocalIdDef::BlockLabel { .. } => Err(invalid(&format!( "unsupported use of block label as the value for {}", descr_phi_case(phi_key) @@ -1173,10 +1201,11 @@ impl Module { // Split the operands into value inputs (e.g. a branch's // condition or an `OpSwitch`'s selector) and target blocks. 
let mut inputs = SmallVec::new(); + let mut input_types = SmallVec::<[_; 2]>::new(); let mut targets = SmallVec::new(); for &id in ids { match lookup_global_or_local_id_for_data_or_control_inst_input(id)? { - LocalIdDef::Value(v) => { + LocalIdDef::Value(ty, v) => { if !targets.is_empty() { return Err(invalid( "out of order: value operand \ @@ -1184,6 +1213,7 @@ impl Module { )); } inputs.push(v); + input_types.push(ty); } LocalIdDef::BlockLabel(target) => { record_cfg_edge(target)?; @@ -1192,6 +1222,7 @@ impl Module { } } + // FIXME(eddyb) move some of this to `spv::canonical`. let kind = if opcode == wk.OpUnreachable { assert!(targets.is_empty() && inputs.is_empty()); cfg::ControlInstKind::Unreachable @@ -1209,9 +1240,62 @@ impl Module { assert_eq!((targets.len(), inputs.len()), (2, 1)); cfg::ControlInstKind::SelectBranch(SelectionKind::BoolCond) } else if opcode == wk.OpSwitch { - cfg::ControlInstKind::SelectBranch(SelectionKind::SpvInst( - raw_inst.without_ids.clone(), - )) + assert_eq!(inputs.len(), 1); + + // HACK(eddyb) `spv::read` has to "redundantly" validate + // that such a type is `OpTypeInt`/`OpTypeFloat`, but + // there is still a limitation when it comes to `scalar::Const`. + // FIXME(eddyb) don't hardcode the 128-bit limitation, + // but query `scalar::Const` somehow instead. + let scrutinee_type = input_types[0]; + let scrutinee_type = scrutinee_type + .as_scalar(&cx) + .filter(|ty| { + matches!(ty, scalar::Type::UInt(_) | scalar::Type::SInt(_)) + && ty.bit_width() <= 128 + }) + .ok_or_else(|| { + invalid( + &print::Plan::for_root( + &cx, + &Diag::err([ + "unsupported `OpSwitch` scrutinee type `".into(), + scrutinee_type.into(), + "`".into(), + ]) + .message, + ) + .pretty_print() + .to_string(), + ) + })?; + + // FIXME(eddyb) move some of this to `spv::canonical`. + let imm_words_per_case = + usize::try_from(scrutinee_type.bit_width().div_ceil(32)).unwrap(); + + // NOTE(eddyb) these sanity-checks are redundant with `spv::read`. 
+ assert_eq!(imms.len() % imm_words_per_case, 0); + assert_eq!(targets.len(), 1 + imms.len() / imm_words_per_case); + + let case_consts = imms + .chunks(imm_words_per_case) + .map(|case_imms| { + scalar::Const::try_decode_from_spv_imms(scrutinee_type, case_imms) + .ok_or_else(|| { + invalid(&format!( + "invalid {}-bit `OpSwitch` case constant", + scrutinee_type.bit_width() + )) + }) + }) + .collect::>()?; + + // HACK(eddyb) move the default case from first to last. + let default_target = targets.remove(0); + targets.push(default_target); + + cfg::ControlInstKind::SelectBranch(SelectionKind::Switch { case_consts }) } else { return Err(invalid("unsupported control-flow instruction")); }; @@ -1256,7 +1340,7 @@ impl Module { let loop_merge_target = match lookup_global_or_local_id_for_data_or_control_inst_input(ids[0])? { - LocalIdDef::Value(_) => return Err(invalid("expected label ID")), + LocalIdDef::Value(..) => return Err(invalid("expected label ID")), LocalIdDef::BlockLabel(target) => target, }; @@ -1274,7 +1358,13 @@ impl Module { // some "structured regions" replacement for the CFG. } else { let mut ids = &ids[..]; - let kind = if opcode == wk.OpFunctionCall { + let kind = if let Some(kind) = raw_inst.without_ids.as_canonical_data_inst_kind( + &cx, + result_type.map(|ty| [ty]).as_ref().map_or(&[][..], |tys| &tys[..]), + ) { + // FIXME(eddyb) sanity-check the number/types of inputs. + kind + } else if opcode == wk.OpFunctionCall { assert!(imms.is_empty()); let callee_id = ids[0]; let maybe_callee = id_defs @@ -1349,7 +1439,7 @@ impl Module { .map(|&id| { match lookup_global_or_local_id_for_data_or_control_inst_input(id)? { - LocalIdDef::Value(v) => Ok(v), + LocalIdDef::Value(_, v) => Ok(v), LocalIdDef::BlockLabel { .. 
} => Err(invalid( "unsupported use of block label as a value, \ in non-terminator instruction", @@ -1360,7 +1450,7 @@ impl Module { }; let inst = match result_id { Some(id) => match local_id_defs[&id] { - LocalIdDef::Value(Value::DataInstOutput(inst)) => { + LocalIdDef::Value(_, Value::DataInstOutput(inst)) => { // A dummy was defined earlier, to be able to // have an entry in `local_id_defs`. func_def_body.data_insts[inst] = data_inst_def.into(); diff --git a/src/spv/mod.rs b/src/spv/mod.rs index eb5a2e7d..09728c1a 100644 --- a/src/spv/mod.rs +++ b/src/spv/mod.rs @@ -2,6 +2,7 @@ // NOTE(eddyb) all the modules are declared here, but they're documented "inside" // (i.e. using inner doc comments). +pub mod canonical; pub mod lift; pub mod lower; pub mod print; diff --git a/src/spv/print.rs b/src/spv/print.rs index 1fc99dd6..714b064d 100644 --- a/src/spv/print.rs +++ b/src/spv/print.rs @@ -77,6 +77,9 @@ impl TokensForOperand { // FIXME(eddyb) keep a `&'static spec::Spec` if that can even speed up anything. struct OperandPrinter, ID, IDS: Iterator> { + // FIXME(eddyb) use a field like this to interpret `Opcode`/`OperandKind`, too. + wk: &'static spv::spec::WellKnown, + /// Input immediate operands to print from (may be grouped e.g. into literals). imms: iter::Peekable, @@ -123,7 +126,41 @@ impl, ID, IDS: Iterator> OperandPrint let def = kind.def(); assert!(matches!(def, spec::OperandKindDef::Literal { .. 
})); - let literal_token = if kind == spec::Spec::get().well_known.LiteralString { + let literal_token = if kind == self.wk.LiteralSpecConstantOpInteger { + assert_eq!(words.len(), 1); + let (_, inner_name, inner_def) = match u16::try_from(first_word) + .ok() + .and_then(spec::Opcode::try_from_u16_with_name_and_def) + { + Some(opcode_name_and_def) => opcode_name_and_def, + None => { + self.out.tokens.push(Token::Error(format!( + "/* {first_word} not a valid `OpSpecConstantOp` opcode */" + ))); + return; + } + }; + + // FIXME(eddyb) deduplicate this with `enumerant_params`. + self.out.tokens.push(Token::EnumerandName(inner_name)); + + let mut first = true; + for (inner_mode, inner_name_and_kind) in inner_def.all_operands_with_names() { + if inner_mode == spec::OperandMode::Optional && self.is_exhausted() { + break; + } + + self.out.tokens.push(Token::Punctuation(if first { "(" } else { ", " })); + first = false; + + let (inner_name, inner_kind) = inner_name_and_kind.name_and_kind(); + self.operand(inner_name, inner_kind); + } + if !first { + self.out.tokens.push(Token::Punctuation(")")); + } + return; + } else if kind == self.wk.LiteralString { // FIXME(eddyb) deduplicate with `spv::extract_literal_string`. let bytes: SmallVec<[u8; 64]> = words .into_iter() @@ -260,6 +297,7 @@ impl, ID, IDS: Iterator> OperandPrint /// an enumerand with parameters (which consumes more immediates). 
pub fn operand_from_imms(imms: impl IntoIterator) -> TokensForOperand { let mut printer = OperandPrinter { + wk: &spec::Spec::get().well_known, imms: imms.into_iter().peekable(), ids: iter::empty().peekable(), out: TokensForOperand::default(), @@ -282,6 +320,7 @@ pub fn inst_operands( ids: impl IntoIterator, ) -> impl Iterator> { OperandPrinter { + wk: &spec::Spec::get().well_known, imms: imms.into_iter().peekable(), ids: ids.into_iter().peekable(), out: TokensForOperand::default(), diff --git a/src/spv/read.rs b/src/spv/read.rs index 58178055..cfa57a1d 100644 --- a/src/spv/read.rs +++ b/src/spv/read.rs @@ -12,15 +12,14 @@ use std::{fs, io, iter, slice}; /// /// Used currently only to help parsing `LiteralContextDependentNumber`. enum KnownIdDef { - TypeInt(NonZeroU32), - TypeFloat(NonZeroU32), + TypeIntOrFloat(NonZeroU32), Uncategorized { opcode: spec::Opcode, result_type_id: Option }, } impl KnownIdDef { fn result_type_id(&self) -> Option { match *self { - Self::TypeInt(_) | Self::TypeFloat(_) => None, + Self::TypeIntOrFloat(_) => None, Self::Uncategorized { result_type_id, .. } => result_type_id, } } @@ -28,6 +27,9 @@ impl KnownIdDef { // FIXME(eddyb) keep a `&'static spec::Spec` if that can even speed up anything. struct InstParser<'a> { + // FIXME(eddyb) use a field like this to interpret `Opcode`/`OperandKind`, too. + wk: &'static spv::spec::WellKnown, + /// IDs defined so far in the module. known_ids: &'a FxHashMap, @@ -60,6 +62,9 @@ enum InstParseError { /// The type of a `LiteralContextDependentNumber` was not a supported type /// (one of either `OpTypeInt` or `OpTypeFloat`). UnsupportedContextSensitiveLiteralType { type_opcode: spec::Opcode }, + + /// Unsupported `OpSpecConstantOp` (`LiteralSpecConstantOpInteger`) opcode. 
+ UnsupportedSpecConstantOpOpcode(u32), } impl InstParseError { @@ -94,6 +99,9 @@ impl InstParseError { Self::UnsupportedContextSensitiveLiteralType { type_opcode } => { format!("{} is not a supported literal type", type_opcode.name()).into() } + Self::UnsupportedSpecConstantOpOpcode(opcode) => { + format!("{opcode} is not a supported opcode (for `OpSpecConstantOp`)").into() + } } } } @@ -174,11 +182,8 @@ impl InstParser<'_> { .and_then(|id| self.known_ids.get(&id)) .ok_or(Error::MissingContextSensitiveLiteralType)?; - let extra_word_count = match *contextual_type { - KnownIdDef::TypeInt(width) | KnownIdDef::TypeFloat(width) => { - // HACK(eddyb) `(width + 31) / 32 - 1` but without overflow. - (width.get() - 1) / 32 - } + let word_count = match *contextual_type { + KnownIdDef::TypeIntOrFloat(width) => width.get().div_ceil(32), KnownIdDef::Uncategorized { opcode, .. } => { return Err(Error::UnsupportedContextSensitiveLiteralType { type_opcode: opcode, @@ -186,11 +191,11 @@ impl InstParser<'_> { } }; - if extra_word_count == 0 { + if word_count == 1 { self.inst.imms.push(spv::Imm::Short(kind, word)); } else { self.inst.imms.push(spv::Imm::LongStart(kind, word)); - for _ in 0..extra_word_count { + for _ in 1..word_count { let word = self.words.next().ok_or(Error::NotEnoughWords)?; self.inst.imms.push(spv::Imm::LongCont(kind, word)); } @@ -198,6 +203,22 @@ impl InstParser<'_> { } } + // HACK(eddyb) this isn't cleanly uniform because it's an odd special case. + if kind == self.wk.LiteralSpecConstantOpInteger { + // FIXME(eddyb) this partially duplicates the main instruction parsing. 
+ let (_, _, inner_def) = u16::try_from(word) + .ok() + .and_then(spec::Opcode::try_from_u16_with_name_and_def) + .ok_or(Error::UnsupportedSpecConstantOpOpcode(word))?; + + for (inner_mode, inner_kind) in inner_def.all_operands() { + if inner_mode == spec::OperandMode::Optional && self.is_exhausted() { + break; + } + self.operand(inner_kind)?; + } + } + Ok(()) } @@ -304,9 +325,6 @@ impl ModuleParser { impl Iterator for ModuleParser { type Item = io::Result; fn next(&mut self) -> Option { - let spv_spec = spec::Spec::get(); - let wk = &spv_spec.well_known; - let words = &bytemuck::cast_slice::(&self.word_bytes)[self.next_word..]; let &opcode = words.first()?; @@ -324,6 +342,7 @@ impl Iterator for ModuleParser { } let parser = InstParser { + wk: &spec::Spec::get().well_known, known_ids: &self.known_ids, words: words[1..inst_len].iter().copied(), inst: spv::InstWithIds { @@ -341,24 +360,11 @@ impl Iterator for ModuleParser { // HACK(eddyb) `Option::map` allows using `?` for `Result` in the closure. let maybe_known_id_result = inst.result_id.map(|id| { - let known_id_def = if opcode == wk.OpTypeInt { - KnownIdDef::TypeInt(match inst.imms[0] { - spv::Imm::Short(kind, n) => { - assert_eq!(kind, wk.LiteralInteger); - n.try_into().ok().ok_or_else(|| invalid("Width cannot be 0"))? - } - _ => unreachable!(), - }) - } else if opcode == wk.OpTypeFloat { - KnownIdDef::TypeFloat(match inst.imms[0] { - spv::Imm::Short(kind, n) => { - assert_eq!(kind, wk.LiteralInteger); - n.try_into().ok().ok_or_else(|| invalid("Width cannot be 0"))? 
- } - _ => unreachable!(), - }) - } else { - KnownIdDef::Uncategorized { opcode, result_type_id: inst.result_type_id } + let known_id_def = match inst.int_or_float_type_bit_width() { + Some(w) => KnownIdDef::TypeIntOrFloat( + w.try_into().ok().ok_or_else(|| invalid("Width cannot be 0"))?, + ), + None => KnownIdDef::Uncategorized { opcode, result_type_id: inst.result_type_id }, }; let old = self.known_ids.insert(id, known_id_def); diff --git a/src/spv/spec.rs b/src/spv/spec.rs index 2b00cb24..81ddb800 100644 --- a/src/spv/spec.rs +++ b/src/spv/spec.rs @@ -117,10 +117,6 @@ def_well_known! { OpNoLine, OpTypeVoid, - OpTypeBool, - OpTypeInt, - OpTypeFloat, - OpTypeVector, OpTypeMatrix, OpTypeArray, OpTypeRuntimeArray, @@ -133,10 +129,8 @@ def_well_known! { OpTypeSampledImage, OpTypeAccelerationStructureKHR, - OpConstantFalse, - OpConstantTrue, - OpConstant, - OpUndef, + // FIXME(eddyb) hide these from code, lowering should handle most cases. + OpConstantComposite, OpVariable, @@ -166,6 +160,11 @@ def_well_known! { OpPtrAccessChain, OpInBoundsPtrAccessChain, OpBitcast, + + // FIXME(eddyb) hide these from code, lowering should handle most cases. + OpCompositeInsert, + OpCompositeExtract, + OpCompositeConstruct, ], operand_kind: OperandKind = [ Capability, @@ -183,6 +182,7 @@ def_well_known! { LiteralExtInstInteger, LiteralString, LiteralContextDependentNumber, + LiteralSpecConstantOpInteger, ], // FIXME(eddyb) find a way to namespace these to avoid conflicts. addressing_model: u32 = [ diff --git a/src/spv/write.rs b/src/spv/write.rs index 0d0a9312..083dfea6 100644 --- a/src/spv/write.rs +++ b/src/spv/write.rs @@ -7,6 +7,9 @@ use std::{fs, io, iter, slice}; // FIXME(eddyb) keep a `&'static spec::Spec` if that can even speed up anything. struct OperandEmitter<'a> { + // FIXME(eddyb) use a field like this to interpret `Opcode`/`OperandKind`, too. + wk: &'static spv::spec::WellKnown, + /// Input immediate operands of an instruction. 
imms: iter::Copied>, @@ -32,6 +35,9 @@ enum OperandEmitError { /// Unsupported enumerand value. UnsupportedEnumerand(spec::OperandKind, u32), + + /// Unsupported `OpSpecConstantOp` (`LiteralSpecConstantOpInteger`) opcode. + UnsupportedSpecConstantOpOpcode(u32), } impl OperandEmitError { @@ -60,6 +66,9 @@ impl OperandEmitError { _ => unreachable!(), } } + Self::UnsupportedSpecConstantOpOpcode(opcode) => { + format!("{opcode} is not a supported opcode (for `OpSpecConstantOp`)").into() + } } } } @@ -140,6 +149,23 @@ impl OperandEmitter<'_> { } } + // HACK(eddyb) this isn't cleanly uniform because it's an odd special case. + if kind == self.wk.LiteralSpecConstantOpInteger { + // FIXME(eddyb) this partially duplicates the main instruction emission. + let &word = self.out.last().unwrap(); + let (_, _, inner_def) = u16::try_from(word) + .ok() + .and_then(spec::Opcode::try_from_u16_with_name_and_def) + .ok_or(Error::UnsupportedSpecConstantOpOpcode(word))?; + + for (inner_mode, inner_kind) in inner_def.all_operands() { + if inner_mode == spec::OperandMode::Optional && self.is_exhausted() { + break; + } + self.operand(inner_kind)?; + } + } + Ok(()) } @@ -221,6 +247,7 @@ impl ModuleEmitter { ); OperandEmitter { + wk: &spec::Spec::get().well_known, imms: inst.imms.iter().copied(), ids: inst.ids.iter().copied(), out: &mut self.words, diff --git a/src/transform.rs b/src/transform.rs index 6b97b8aa..053ecb60 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -424,7 +424,10 @@ impl InnerTransform for TypeDef { transform!({ attrs -> transformer.transform_attr_set_use(*attrs), kind -> match kind { - TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => Transformed::Unchanged, + TypeKind::Scalar(_) + | TypeKind::Vector(_) + | TypeKind::QPtr + | TypeKind::SpvStringLiteralForExtInst => Transformed::Unchanged, TypeKind::SpvInst { spv_inst, type_and_const_inputs } => Transformed::map_iter( type_and_const_inputs.iter(), @@ -457,6 +460,11 @@ impl InnerTransform for ConstDef { attrs 
-> transformer.transform_attr_set_use(*attrs), ty -> transformer.transform_type_use(*ty), kind -> match kind { + ConstKind::Undef + | ConstKind::Scalar(_) + | ConstKind::Vector(_) + | ConstKind::SpvStringLiteralForExtInst(_) => Transformed::Unchanged, + ConstKind::PtrToGlobalVar(gv) => transform!({ gv -> transformer.transform_global_var_use(*gv), } => ConstKind::PtrToGlobalVar(gv)), @@ -470,7 +478,6 @@ impl InnerTransform for ConstDef { spv_inst_and_const_inputs: Rc::new((spv_inst.clone(), new_iter.collect())), }) } - ConstKind::SpvStringLiteralForExtInst(_) => Transformed::Unchanged }, } => Self { attrs, @@ -635,7 +642,7 @@ impl InnerInPlaceTransform for FuncAtMut<'_, ControlNode> { } } ControlNodeKind::Select { - kind: SelectionKind::BoolCond | SelectionKind::SpvInst(_), + kind: SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, scrutinee, cases: _, } => { @@ -716,7 +723,10 @@ impl InnerTransform for DataInstFormDef { | QPtrOp::Load | QPtrOp::Store => Transformed::Unchanged, }, - DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => Transformed::Unchanged, + DataInstKind::Scalar(_) + | DataInstKind::Vector(_) + | DataInstKind::SpvInst(_) + | DataInstKind::SpvExtInst { .. } => Transformed::Unchanged, }, // FIXME(eddyb) this should be replaced with an impl of `InnerTransform` // for `Option` or some other helper, to avoid "manual transpose". @@ -740,7 +750,7 @@ impl InnerInPlaceTransform for cfg::ControlInst { | cfg::ControlInstKind::ExitInvocation(cfg::ExitInvocationKind::SpvInst(_)) | cfg::ControlInstKind::Branch | cfg::ControlInstKind::SelectBranch( - SelectionKind::BoolCond | SelectionKind::SpvInst(_), + SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, ) => {} } for v in inputs { diff --git a/src/vector.rs b/src/vector.rs new file mode 100644 index 00000000..0b1d42f5 --- /dev/null +++ b/src/vector.rs @@ -0,0 +1,180 @@ +//! Vector types (small arrays of [`scalar`](crate::scalar)s) and associated functionality. +//! 
+//! **Note**: these are similar to SIMD types in other IRs, but SPIR-V often uses +//! its `OpTypeVector` to represent geometrical vectors, colors, etc. without any +//! expectation of SIMD execution (which most GPU execution models use implicitly, +//! i.e. one non-uniform scalar becomes a hardware SIMD vector, while a high-level +//! "vector" of N "lanes", becomes N separate hardware SIMD vectors). + +use crate::scalar; +use smallvec::SmallVec; +use std::num::NonZeroU8; +use std::rc::Rc; + +// FIXME(eddyb) this entire module shorthands "element" as "elem", is that good? + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct Type { + pub elem: scalar::Type, + // FIXME(eddyb) maybe wrap this in a type that abstracts away the encoding? + pub elem_count: NonZeroU8, +} + +// FIXME(eddyb) document the 128-bit limitations inherited from `scalar::Const`. +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct Const(ConstRepr); + +// HACK(eddyb) `#[repr(packed)]` not allowed on `enum`s themselves. +#[repr(packed)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +struct Packed(T); + +// FIXME(eddyb) maybe build an abstraction for "N-dimensional" bit arrays? +#[derive(Clone, PartialEq, Eq, Hash)] +#[repr(u8)] +enum ConstRepr { + // HACK(eddyb) `(Type, u128)` would waste almost half its size on padding, and + // packing will only impact accessing the bits, while allowing e.g. being + // wrapped in an outer `enum`, before reaching the same size as `(u128, u128)`. + Inline(Type, Packed), + + // HACK(eddyb) this does raise the alignment, but the size and alignment are + // kept at one pointer (so likely half of `u128`) - `Packed>` is sadly + // not an option because `#[derive(...)]` + `#[repr(packed)]` often requires + // `Copy` in order to be able to safely take references (to a copy of a field). 
+ Boxed(Type, Rc>), +} + +impl Const { + pub const fn ty(&self) -> Type { + match self.0 { + ConstRepr::Inline(ty, _) | ConstRepr::Boxed(ty, _) => ty, + } + } + + pub fn from_elems(ty: Type, elems: impl IntoIterator) -> Const { + let elem_width = ty.elem.bit_width(); + assert!(elem_width <= 128); + + let expected_elem_count = u32::from(ty.elem_count.get()); + + let num_limbs = elem_width.checked_mul(expected_elem_count).unwrap().div_ceil(128); + assert_ne!(num_limbs, 0); + let mut limbs = SmallVec::<[u128; 1]>::from_elem(0, usize::try_from(num_limbs).unwrap()); + + let mut found_elem_count = 0; + for ct in elems { + let i: u32 = found_elem_count; + found_elem_count = found_elem_count.checked_add(1).unwrap(); + if i >= expected_elem_count { + continue; + } + + // FIXME(eddyb) get better names (perhaps from miri-like memory?). + let first_bit_idx = i.checked_mul(elem_width).unwrap(); + let limb_idx = first_bit_idx / 128; + let intra_limb_first_bit_idx = first_bit_idx % 128; + assert!(intra_limb_first_bit_idx + elem_width <= 128); + + limbs[usize::try_from(limb_idx).unwrap()] |= ct.bits() << intra_limb_first_bit_idx; + } + assert_eq!(found_elem_count, expected_elem_count); + + match limbs.into_inner() { + Ok([limb]) => Const(ConstRepr::Inline(ty, Packed(limb))), + Err(limbs) => Const(ConstRepr::Boxed(ty, Rc::new(limbs.into_vec()))), + } + } + + pub fn get_elem(&self, i: usize) -> Option { + let ty = self.ty(); + if i >= usize::from(ty.elem_count.get()) { + return None; + } + let i = u32::try_from(i).unwrap(); + let elem_width = ty.elem.bit_width(); + assert!(elem_width <= 128); + + // FIXME(eddyb) get better names (perhaps from miri-like memory?). 
+ let first_bit_idx = i.checked_mul(elem_width).unwrap(); + let limb_idx = first_bit_idx / 128; + let intra_limb_first_bit_idx = first_bit_idx % 128; + assert!(intra_limb_first_bit_idx + elem_width <= 128); + + let limb = match &self.0 { + ConstRepr::Inline(_, limb) => { + assert_eq!(limb_idx, 0); + limb.0 + } + ConstRepr::Boxed(_, limbs) => limbs[usize::try_from(limb_idx).unwrap()], + }; + + Some(scalar::Const::from_bits( + ty.elem, + (limb >> intra_limb_first_bit_idx) & (!0 >> (128 - elem_width)), + )) + } + + pub fn elems(&self) -> impl Iterator + '_ { + let ty = self.ty(); + // FIXME(eddyb) there should be a more efficient way to do this. + (0..usize::from(ty.elem_count.get())).map(|i| self.get_elem(i).unwrap()) + } +} + +/// Pure operations with vector inputs and/or outputs. +#[derive(Copy, Clone, PartialEq, Eq, Hash, derive_more::From)] +pub enum Op { + Distribute(scalar::Op), + Reduce(ReduceOp), + + // FIXME(eddyb) find a better name for this category of ops. + Whole(WholeOp), +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum ReduceOp { + // FIXME(eddyb) also support all the new integer dot product instructions. + Dot, + // FIXME(eddyb) model these using their respective `BoolBinOp`s? + Any, + All, +} + +impl ReduceOp { + pub fn name(self) -> &'static str { + match self { + ReduceOp::Dot => "vec.dot", + ReduceOp::Any => "vec.any", + ReduceOp::All => "vec.all", + } + } +} + +// FIXME(eddyb) find a better name for this category of ops. +// FIXME(eddyb) also support `OpVectorShuffle`. +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub enum WholeOp { + // FIXME(eddyb) better name for this (pack? make? "construct" is too long). + New, + Extract { elem_idx: u8 }, + Insert { elem_idx: u8 }, + DynExtract, + DynInsert, + + // FIXME(eddyb) may need a better name to indicate "scalar product". + Mul, +} + +impl WholeOp { + pub fn name(self) -> &'static str { + match self { + WholeOp::New => "vec.new", + WholeOp::Extract { .. 
} => "vec.extract", + WholeOp::Insert { .. } => "vec.insert", + WholeOp::DynExtract => "vec.dyn_extract", + WholeOp::DynInsert => "vec.dyn_insert", + WholeOp::Mul => "vec.mul", + } + } +} diff --git a/src/visit.rs b/src/visit.rs index 7bb837d5..5a54a74c 100644 --- a/src/visit.rs +++ b/src/visit.rs @@ -315,7 +315,10 @@ impl InnerVisit for TypeDef { visitor.visit_attr_set_use(*attrs); match kind { - TypeKind::QPtr | TypeKind::SpvStringLiteralForExtInst => {} + TypeKind::Scalar(_) + | TypeKind::Vector(_) + | TypeKind::QPtr + | TypeKind::SpvStringLiteralForExtInst => {} TypeKind::SpvInst { spv_inst: _, type_and_const_inputs } => { for &ty_or_ct in type_and_const_inputs { @@ -336,6 +339,11 @@ impl InnerVisit for ConstDef { visitor.visit_attr_set_use(*attrs); visitor.visit_type_use(*ty); match kind { + ConstKind::Undef + | ConstKind::Scalar(_) + | ConstKind::Vector(_) + | ConstKind::SpvStringLiteralForExtInst(_) => {} + &ConstKind::PtrToGlobalVar(gv) => visitor.visit_global_var_use(gv), ConstKind::SpvInst { spv_inst_and_const_inputs } => { let (_spv_inst, const_inputs) = &**spv_inst_and_const_inputs; @@ -343,7 +351,6 @@ impl InnerVisit for ConstDef { visitor.visit_const_use(ct); } } - ConstKind::SpvStringLiteralForExtInst(_) => {} } } } @@ -474,7 +481,7 @@ impl<'a> FuncAt<'a, ControlNode> { } } ControlNodeKind::Select { - kind: SelectionKind::BoolCond | SelectionKind::SpvInst(_), + kind: SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, scrutinee, cases, } => { @@ -534,7 +541,10 @@ impl InnerVisit for DataInstFormDef { | QPtrOp::Load | QPtrOp::Store => {} }, - DataInstKind::SpvInst(_) | DataInstKind::SpvExtInst { .. } => {} + DataInstKind::Scalar(_) + | DataInstKind::Vector(_) + | DataInstKind::SpvInst(_) + | DataInstKind::SpvExtInst { .. 
} => {} } if let Some(ty) = *output_type { visitor.visit_type_use(ty); @@ -553,7 +563,7 @@ impl InnerVisit for cfg::ControlInst { | cfg::ControlInstKind::ExitInvocation(cfg::ExitInvocationKind::SpvInst(_)) | cfg::ControlInstKind::Branch | cfg::ControlInstKind::SelectBranch( - SelectionKind::BoolCond | SelectionKind::SpvInst(_), + SelectionKind::BoolCond | SelectionKind::Switch { case_consts: _ }, ) => {} } for v in inputs {