diff --git a/.agents/skills/contrib-op-shape-inference-memory-safety/SKILL.md b/.agents/skills/contrib-op-shape-inference-memory-safety/SKILL.md new file mode 100644 index 0000000000000..e709455b60696 --- /dev/null +++ b/.agents/skills/contrib-op-shape-inference-memory-safety/SKILL.md @@ -0,0 +1,237 @@ +--- +name: contrib-op-shape-inference-memory-safety +description: "Audit and fix out-of-range output writes in ONNX Runtime operator shape-inference functions. Use when reviewing or fixing a contrib (or standard) op TypeAndShapeInference where a getNumOutputs() guard precedes a write to a higher output index - optional trailing outputs make a smaller output count schema-valid, so getOutputType(index) can run one past the declared outputs at Graph::Resolve." +--- + +# Contrib-Op Shape-Inference Output-Index Safety + +Reusable method for finding and fixing the bug class where an operator's +`TypeAndShapeInference` function guards an output write with `getNumOutputs() > N` but then +writes an output index **greater than** `N`. For a node that declares fewer outputs, the +written index is past the end of the inference context's output vector. + +> **Scope**: schema-level shape inference in `onnxruntime/core/graph/contrib_ops/*.cc` and +> `shape_inference_functions.cc`. This runs once during `Graph::Resolve` (model-load time), +> **EP-agnostic** - there is no per-EP (CPU/CUDA/ROCm) kernel duplicate of this code to +> chase. Op *kernels* allocate outputs via the bounds-safe `OpKernelContext::Output(index)` +> and are a separate concern. + +## 1. The pattern + +```cpp +// onnxruntime/core/graph/contrib_ops/bert_defs.cc (before) +propagateElemTypeFromInputToOutput(ctx, 0, 0); +if (ctx.getNumOutputs() > 1) { // guard says "> 1" + propagateElemTypeFromInputToOutput(ctx, 0, 1); + propagateElemTypeFromInputToOutput(ctx, 0, 2); // but writes index 2 +} +``` + +The guard `getNumOutputs() > 1` admits a node with **exactly 2 outputs** (indices 0, 1), yet +the body writes index **2**. The implication "`> 1` ⇒ index 2 exists" is false: `> 1` only +guarantees indices 0 and 1. + +### Why a smaller output count is valid + +Trailing outputs declared `OpSchema::Optional` **lower `min_output`**. ONNX derives +`min_output` = number of required outputs, `max_output` = total declared. The model checker +(`checker::check_node`) only enforces `min_output <= N <= max_output`. + +| Op | Output decls | min / max | A 2-output node? | +|---|---|---|---| +| `DecoderAttention` | out (req), new_key_cache (Opt), new_value_cache (Opt) | 1 / 3 | passes checker | +| `MultiHeadAttention` | out (req), present_key (Opt), present_value (Opt), qk (Opt) | 1 / 4 | passes checker | +| `DecoderMaskedMultiHeadAttention` | out (req) + 3 Optional | 1 / 4 | passes checker | + +So a node with `output=['out','present_key']` is schema-valid, passes the checker, and then +reaches the index-2 write. **A passing checker is not a guarantee the index is in range.** + +## 2. The sink (why the write is not caught) + +```cpp +// onnxruntime/core/graph/graph.cc - InferenceContextImpl +const TypeProto* getInputType(size_t index) const override { + return node_.InputDefs().at(index)->TypeAsProto(); // .at() -> bounds-checked +} +TypeProto* getOutputType(size_t index) override { + return &node_output_types_[index]; // operator[] -> NOT bounds-checked +} +``` + +- `node_output_types_` is sized to `node.OutputDefs().size()` in the `InferenceContextImpl` + ctor, so for a 2-output node it has 2 elements; `getOutputType(2)` returns one past the end. +- `getInputType` uses `.at()` (would throw on a bad index); `getOutputType` uses raw + `operator[]` (no check) - the asymmetry is the root cause. +- The call runs at `Graph::Resolve` → `InferAndVerifyTypeMatch` → `RunInferencing`. The + surrounding `ORT_TRY/ORT_CATCH(const std::exception&)` only catches *thrown* + `fail_shape_inference`; a raw out-of-range `operator[]` does not throw, so the catch does + not help. +- Because this is schema-level inference, it is **EP-independent** - no CUDA/ROCm copy. + +## 3. Audit technique — always sweep siblings + +Do not stop at the reported function. Grep **every** shape-inference guard and compare its +threshold against the **highest output index written before the next guard**. + +```bash +git grep -n 'getNumOutputs' -- \ + onnxruntime/core/graph/contrib_ops/*.cc \ + onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc +``` + +For each `if (ctx.getNumOutputs() > N)` block, find the largest `index` passed to +`propagateElemTypeFromInputToOutput(ctx, _, index)` / `updateOutputShape(ctx, index, _)` / +`getOutputType(index)` inside it. **Rule: the guard must require strictly more outputs than +the highest index written** (write index `k` ⇒ guard must ensure `getNumOutputs() > k`). + +Correct exemplars already in the tree to copy: + +| Exemplar | Pattern | Why it is safe | +|---|---|---| +| `BaseGroupQueryAttention...` | `if (getNumOutputs() >= 3)` then writes idx 2 | guard covers highest index | +| `PagedAttention...` | nested `> 1` + inner `if (getNumOutputs() != 3) fail_shape_inference` | fails before any write | +| `EmbedLayerNormalizationShapeInference` | `> 2` then writes idx 2 | fixed by PR #28176 (precedent) | +| `SkipLayerNormalizationShapeInference` | each idx `k` guarded by `> k` | per-index guard | + +> **Gotcha — conditional writes can hide a vacuous audit.** A write may sit behind an inner +> condition (e.g. `hasInputShape(past_key_index)` before writing index 2). The site is still +> a bug, but you can only *observe* it when that inner condition is also satisfied. Keep this +> in mind both for the audit and for tests (§5). + +## 4. Fix patterns + +**Point fix (required): raise the guard to cover the highest index written.** + +```cpp +// before +if (ctx.getNumOutputs() > 1) { ... writes idx 2 ... } +// after +if (ctx.getNumOutputs() > 2) { // both present_key (idx 1) AND present_value (idx 2) + ... +} +``` + +Justify the threshold with the op's output semantics. For these attention ops the two trailing +outputs - `present_key` (idx 1) and `present_value` (idx 2) for `MultiHeadAttention`, +`new_key_cache` / `new_value_cache` for `DecoderAttention` (see the §1 table for each op's +exact output names) - are a **both-or-neither pair**: there is no valid configuration that +emits one without the other, so requiring all three outputs before populating indices 1 and 2 +is behavior-preserving. (`PagedAttention` encodes the same invariant via its nested `!= 3` +check.) + +**Defense-in-depth (recommended): bound the sink** so a future author cannot reintroduce the +class. + +```cpp +// onnxruntime/core/graph/graph.cc - InferenceContextImpl::getOutputType +TypeProto* getOutputType(size_t index) override { + if (index >= node_output_types_.size()) { + fail_type_inference("output index ", index, " is out of range; node has ", + node_output_types_.size(), " outputs"); + } + return &node_output_types_[index]; +} +``` + +This mirrors `getInputType`'s `.at()` and the existing bounds checks in the sibling +`DataPropagationContextImpl`. Placing it at the base layer transitively protects the NHWC and +quantization wrapper contexts. After the point fix this branch is unreachable through a normal +model (the guard already prevents the out-of-range index), so it is pure defense-in-depth. Its +failure mode is build-dependent: with exceptions enabled, `fail_type_inference` raises +`InferenceError` (a `std::exception`), caught by the existing `ORT_CATCH(const std::exception&)` +around `RunInferencing` and surfaced as a clean load-time error; under `ORT_NO_EXCEPTIONS` it is +**not** compiled out - ONNX's no-exceptions path prints the message to `std::cerr` and calls +`abort()`, a deterministic fail-fast (consistent with `getInputType`'s `.at()`, which likewise +terminates under no-exceptions). Either way the result is a controlled failure rather than an +out-of-range write. + +## 5. Test recipe + +Tests live in `onnxruntime/test/contrib_ops/*.cc` and are **auto-globbed** into the +`onnxruntime_provider_test` target by `cmake/onnxruntime_unittests.cmake` +(`test/contrib_ops/*.cc` pattern) - **no cmake edit needed** for a new file. See the +`ort-test` skill for the executable taxonomy (`onnxruntime_provider_test` vs +`onnxruntime_test_all`). + +Rules that make the regression test actually guard the fix: + +1. **Drive through `Model` + `Graph::Resolve`**, not ONNX's standalone `TestShapeInference`. + Only the full resolve path constructs the real `InferenceContextImpl` and hits the + `getOutputType` sink described in §2. A standalone ONNX shape-inference helper uses a + different context and **bypasses** the sink, so it cannot reproduce the bug. +2. **Negative tests must be NON-VACUOUS** - they must actually enter the write branch on + pre-fix source. If a write is gated by an inner condition (§3 gotcha), satisfy it: e.g. for + `MultiHeadAttention`/`DecoderMaskedMultiHeadAttention`, supply a **shaped `past_key`** + (and `past_sequence_length` / `past_present_share_buffer` as the op requires) so the + index-2 block runs. A negative test that only supplies `query` skips the block and passes + even on pre-fix source - regression-proof in name only. +3. **Add positive (all-outputs) cases**: a node with every output present must still infer the + trailing output types - proves the tightened guard did not over-restrict. +4. **Keep tests throw-free post-fix** so they are valid under `ORT_NO_EXCEPTIONS`. Any case + that is *expected* to `fail_shape_inference` (throws) must be excluded with + `#ifndef ORT_NO_EXCEPTIONS`. The "2 outputs must not go out of range" case is throw-free + after the point fix and is safe in all builds. + +**Verify the negative test is non-vacuous (sanitizer A/B)** - the most reliable way to prove a +negative test enters the previously-out-of-range branch: build the test at the **pre-fix** +commit with **AddressSanitizer** and confirm it flags the out-of-range output access; then +confirm it is clean after the fix. + +```bash +# Functional run (any Debug build): +cmake --build build/Linux/Debug --target onnxruntime_provider_test -j"$(nproc)" +./build/Linux/Debug/onnxruntime_provider_test \ + --gtest_filter='AttentionOptionalOutputsShapeInferenceTest.*' + +# A/B proof (isolated worktree at the pre-fix commit, CPU-only Debug + sanitizer): +git worktree add --detach ../ort-prefix-check ~1 +# copy the new test file in, then: +python3 tools/ci_build/build.py --build_dir build/asan --config Debug --parallel \ + --skip_tests --enable_address_sanitizer --skip_submodule_sync \ + --cmake_generator Ninja --target onnxruntime_provider_test +# Pre-fix: the negative tests fail (the sanitizer flags the out-of-range output access). +# Post-fix (cherry-pick the guard fix): all tests pass, no sanitizer report. +``` + +## 6. Process / wording conventions + +- Run **`lintrunner -a`** before pushing so the `CLANGFORMAT` / Python-format gate passes. See + the `ort-lint` skill. +- Use **correctness/robustness framing** in code, comments, commit messages, and the PR body + - describe the change as fixing an optional-output guard, not as a security fix. This + matches repo convention (compare `python-kwargs-setattr-security`) and keeps the PR neutral. + +## 7. Audit checklist (per-operator review) + +When reviewing or hardening any operator implementation or its shape inference: + +- [ ] Read the op's spec - ONNX standard op page, or for a contrib op its `OpSchema` + registration (`.Input/.Output/.Attr`, and `Optional`/`Variadic` markers). A local ONNX + checkout has the standard-op spec pages; contrib ops are defined only in ORT. +- [ ] Enumerate **all** inputs, attributes, and outputs, noting which are optional and the + resulting `min/max` input and output counts. +- [ ] Validate every input/attribute before indexing into it, to avoid out-of-range reads + (which can cascade into worse failures). Match each output-index write to a guard that + guarantees the index is in range (§3 rule). +- [ ] Prefer `ORT_RETURN_IF` / `ORT_RETURN_IF_NOT` for validation; use `ORT_ENFORCE` in + constructors. In shape inference use `fail_shape_inference` / `fail_type_inference`. +- [ ] Use `SafeInt<>` / `narrow<>()` for index and size arithmetic and casts to avoid overflow + or truncation that yields a wrong index. See `core/common/safeint.h` and + `docs/Coding_Conventions_and_Standards.md`. +- [ ] Ensure tests build and pass under **no-exceptions** builds; `#ifndef ORT_NO_EXCEPTIONS` + around any case expected to throw. +- [ ] Exclude EPs known not to support the op, with a comment explaining why. +- [ ] Check whether **other EPs (notably CUDA/ROCm)** implement the same op and whether the + same issue exists there. (For *shape inference* specifically, the logic is EP-agnostic + and single-source - confirm there is no kernel-side analogue.) + +## References + +- **PR #28176** - "Fix ... in EmbedLayerNormalizationShapeInference": the precedent that fixed + the identical `> 1` → `> 2` primitive in one site; the sibling attention sites were missed, + motivating the sweep in §3. +- **PR #29268** - this fix: guards corrected in `DecoderAttention` / `MultiHeadAttention` / + `DecoderMaskedMultiHeadAttention` shape inference, plus the `getOutputType` bounds check and + non-vacuous regression tests. +- Sibling skill: **`ort-test`** (test executables, `--gtest_filter`, contrib-op test layout); + **`ort-lint`** (`lintrunner -a`); **`ort-build`** (build flags, ASan). diff --git a/docs/Coding_Conventions_and_Standards.md b/docs/Coding_Conventions_and_Standards.md index 02af7ddaa49be..6dcf7e6dd43bd 100644 --- a/docs/Coding_Conventions_and_Standards.md +++ b/docs/Coding_Conventions_and_Standards.md @@ -109,6 +109,7 @@ void foo(gsl::span names) { * Use [SafeInt](https://github.com/dcleblanc/SafeInt) when calculating the size of memory to allocate to protect against overflow errors * `#include "core/common/safeint.h"` * search for `SafeInt` in the code for examples +* In operator shape inference, validate every output index against `getNumOutputs()` before writing it. Optional trailing outputs lower an op's `min_output`, so a node may legally declare fewer outputs than the schema's maximum; guard each optional output by the exact index it populates so a `getNumOutputs() > N` guard never writes an index greater than `N`. * The following C++ warnings should never be disabled in onnxruntime VC++ projects(Required by [Binskim](https://github.com/microsoft/binskim/blob/d9afb65c89a621411efded74c27999281d87867e/src/BinSkim.Rules/PERules/BA2007.EnableCriticalCompilerWarnings.cs)). 1. [4018](https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-level-3-c4018) 'token' : signed/unsigned mismatch 2. [4146](https://docs.microsoft.com/en-us/cpp/error-messages/compiler-warnings/compiler-warning-level-2-c4146?view=msvc-160) unary minus operator applied to unsigned type, result still unsigned diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 896774fb5c8d8..9fdcebf57081d 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -29,7 +29,7 @@ namespace contrib { void DecoderAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); - if (ctx.getNumOutputs() > 1) { + if (ctx.getNumOutputs() > 2) { // has new_key_cache and new_value_cache outputs; a pair, so present only when > 2 ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1); ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2); } @@ -38,7 +38,7 @@ void DecoderAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx auto& query_shape = getInputShape(ctx, 0); updateOutputShape(ctx, 0, query_shape); } - if (ctx.getNumOutputs() > 1) { + if (ctx.getNumOutputs() > 2) { // has new_key_cache and new_value_cache outputs; a pair, so present only when > 2 if (hasInputShape(ctx, 6) && hasInputShape(ctx, 7)) { auto& cache_shape = getInputShape(ctx, 6); auto& cache_dims = cache_shape.dim(); @@ -199,7 +199,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c } } - if (ctx.getNumOutputs() > 1) { // has present output + if (ctx.getNumOutputs() > 2) { // has present_key and present_value outputs; a pair, so present only when > 2 if (hasInputShape(ctx, past_key_index)) { auto& past_shape = getInputShape(ctx, past_key_index); auto& past_dims = past_shape.dim(); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index fe2df6a87d124..ff60f6cb5f741 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2739,6 +2739,10 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { } TypeProto* getOutputType(size_t index) override { + if (index >= node_output_types_.size()) { + fail_type_inference("output index ", index, " is out of range; node has ", + node_output_types_.size(), " outputs"); + } return &node_output_types_[index]; } diff --git a/onnxruntime/test/contrib_ops/attention_optional_outputs_shape_inference_test.cc b/onnxruntime/test/contrib_ops/attention_optional_outputs_shape_inference_test.cc new file mode 100644 index 0000000000000..d2d3a421a6560 --- /dev/null +++ b/onnxruntime/test/contrib_ops/attention_optional_outputs_shape_inference_test.cc @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Shape-inference correctness coverage for attention contrib ops with optional "present" outputs. +// DecoderAttention, MultiHeadAttention and DecoderMaskedMultiHeadAttention each expose an optional +// present_key (output 1) and present_value (output 2). These outputs are produced as a pair, so a +// node may declare either one output, or all three; declaring exactly two (present_key kept, +// present_value omitted) is also valid per the schemas. +// +// The "...Omitted" tests build each op with exactly two outputs and assert that Graph::Resolve() +// completes cleanly without referencing the absent third output. The "...AllPresentOutputs" tests +// build each op with all three outputs and assert that the present_key / present_value branch still +// runs and propagates their element types. Together they pin the guard to exactly "> 2": fewer +// outputs must not touch the missing one, and three outputs must still be inferred. +// +// These tests exercise only graph-load shape inference, which is execution-provider independent, so +// they run on the default CPU build with no provider-specific handling. The resolve path for these +// models is throw-free, so the tests are valid in builds compiled without exceptions +// (ORT_NO_EXCEPTIONS) and need no exception-specific guarding. + +#include "gtest/gtest.h" + +#include "core/graph/constants.h" +#include "core/graph/model.h" +#include "test/test_environment.h" +#include "test/unittest_util/graph_transform_test_builder.h" +#include "test/util/include/asserts.h" + +namespace onnxruntime { +namespace test { + +namespace { + +constexpr int kOnnxOpsetVersion = 17; + +// Builds a single-node model via add_node, resolves it (running the node's type/shape inference), +// asserts success, and runs the optional verifier against the resolved graph. +void BuildResolveAndVerify(const std::function& add_node, + const std::function& verify = nullptr) { + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = kOnnxOpsetVersion; + domain_to_version[kMSDomain] = 1; + + Model model("attention_optional_outputs", /*is_onnx_domain_only=*/false, ModelMetaData(), + PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + + ModelTestBuilder builder(model.MainGraph()); + add_node(builder); + builder.SetGraphOutputs(); + + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + if (verify) { + verify(model.MainGraph()); + } +} + +// Asserts that the given output NodeArg received a tensor element type from shape inference. +void ExpectInferredElemType(const NodeArg* output) { + ASSERT_NE(output, nullptr); + const ONNX_NAMESPACE::TypeProto* type = output->TypeAsProto(); + ASSERT_NE(type, nullptr); + EXPECT_TRUE(type->has_tensor_type()); + EXPECT_TRUE(type->tensor_type().has_elem_type()); +} + +} // namespace + +// MultiHeadAttention with present_key kept and present_value omitted (exactly two outputs). +// past_key (input 6), past_value (input 7) and past_sequence_length (input 8) are supplied with +// shapes so the present-output branch is active; with only two outputs declared, inference must not +// reference the absent present_value (output 2). Inputs 1-5 are optional and left empty. +TEST(AttentionOptionalOutputsShapeInferenceTest, MultiHeadAttentionPresentValueOmitted) { + BuildResolveAndVerify([](ModelTestBuilder& builder) { + NodeArg& empty = builder.graph_.GetOrCreateNodeArg("", nullptr); + NodeArg* query = builder.MakeInput(std::vector{2, 1, 4}); + NodeArg* past_key = builder.MakeInput(std::vector{2, 2, 3, 2}); + NodeArg* past_value = builder.MakeInput(std::vector{2, 2, 3, 2}); + NodeArg* past_sequence_length = builder.MakeInput(std::vector{1}); + NodeArg* output = builder.MakeOutput(std::nullopt); + NodeArg* present_key = builder.MakeOutput(std::nullopt); + std::vector inputs = {query, &empty, &empty, &empty, &empty, &empty, + past_key, past_value, past_sequence_length}; + Node& node = builder.AddNode("MultiHeadAttention", inputs, {output, present_key}, kMSDomain); + node.AddAttribute("num_heads", static_cast(2)); + }); +} + +// DecoderMaskedMultiHeadAttention with present_key kept and present_value omitted. +// past_key (input 5) and past_value (input 6) are supplied with shapes and past buffer sharing is +// enabled so the present-output branch is active; with only two outputs declared, inference must not +// reference the absent present_value (output 2). Inputs 1-4 are optional and left empty. +TEST(AttentionOptionalOutputsShapeInferenceTest, DecoderMaskedMultiHeadAttentionPresentValueOmitted) { + BuildResolveAndVerify([](ModelTestBuilder& builder) { + NodeArg& empty = builder.graph_.GetOrCreateNodeArg("", nullptr); + NodeArg* query = builder.MakeInput(std::vector{2, 1, 4}); + NodeArg* past_key = builder.MakeInput(std::vector{2, 2, 3, 2}); + NodeArg* past_value = builder.MakeInput(std::vector{2, 2, 3, 2}); + NodeArg* output = builder.MakeOutput(std::nullopt); + NodeArg* present_key = builder.MakeOutput(std::nullopt); + std::vector inputs = {query, &empty, &empty, &empty, &empty, past_key, past_value}; + Node& node = builder.AddNode("DecoderMaskedMultiHeadAttention", inputs, {output, present_key}, + kMSDomain); + node.AddAttribute("num_heads", static_cast(2)); + node.AddAttribute("past_present_share_buffer", static_cast(1)); + }); +} + +// DecoderAttention with new_key_cache kept and new_value_cache omitted (exactly two outputs). +TEST(AttentionOptionalOutputsShapeInferenceTest, DecoderAttentionNewValueCacheOmitted) { + BuildResolveAndVerify([](ModelTestBuilder& builder) { + // DecoderAttention requires inputs 0-4 and 8-11; inputs 5-7 are optional and left empty here. + NodeArg& empty = builder.graph_.GetOrCreateNodeArg("", nullptr); + NodeArg* query = builder.MakeInput(std::vector{2, 1, 4}); + NodeArg* key = builder.MakeInput(std::vector{2, 1, 4}); + NodeArg* q_weight = builder.MakeInput(std::vector{4, 4}); + NodeArg* kv_weight = builder.MakeInput(std::vector{4, 8}); + NodeArg* bias = builder.MakeInput(std::vector{12}); + NodeArg* static_kv = builder.MakeInput(std::vector{1}); + NodeArg* use_past = builder.MakeInput(std::vector{1}); + NodeArg* has_layer_state = builder.MakeInput(std::vector{1}); + NodeArg* has_key_padding_mask = builder.MakeInput(std::vector{1}); + + NodeArg* output = builder.MakeOutput(std::nullopt); + NodeArg* new_key_cache = builder.MakeOutput(std::nullopt); + + std::vector inputs = {query, key, q_weight, kv_weight, bias, &empty, &empty, &empty, + static_kv, use_past, has_layer_state, has_key_padding_mask}; + Node& node = builder.AddNode("DecoderAttention", inputs, {output, new_key_cache}, kMSDomain); + node.AddAttribute("num_heads", static_cast(2)); + }); +} + +// MultiHeadAttention with all three outputs: the present_key / present_value branch must still run. +TEST(AttentionOptionalOutputsShapeInferenceTest, MultiHeadAttentionAllPresentOutputs) { + NodeArg* present_key = nullptr; + NodeArg* present_value = nullptr; + BuildResolveAndVerify( + [&](ModelTestBuilder& builder) { + // present_key/present_value types are propagated from past_key (input 6) and past_value + // (input 7) when past buffer sharing is active, which the op detects from shaped past_key + // (input 6) and past_sequence_length (input 8). Leave inputs 1-5 empty to reach them. + NodeArg& empty = builder.graph_.GetOrCreateNodeArg("", nullptr); + NodeArg* query = builder.MakeInput(std::vector{2, 1, 4}); + NodeArg* past_key = builder.MakeInput(std::vector{2, 2, 3, 2}); + NodeArg* past_value = builder.MakeInput(std::vector{2, 2, 3, 2}); + NodeArg* past_sequence_length = builder.MakeInput(std::vector{1}); + + NodeArg* output = builder.MakeOutput(std::nullopt); + present_key = builder.MakeOutput(); + present_value = builder.MakeOutput(); + + std::vector inputs = {query, &empty, &empty, &empty, &empty, &empty, + past_key, past_value, past_sequence_length}; + Node& node = builder.AddNode("MultiHeadAttention", inputs, + {output, present_key, present_value}, kMSDomain); + node.AddAttribute("num_heads", static_cast(2)); + }, + [&](const Graph&) { + ExpectInferredElemType(present_key); + ExpectInferredElemType(present_value); + }); +} + +// DecoderMaskedMultiHeadAttention with all three outputs: shape inference must still populate them. +TEST(AttentionOptionalOutputsShapeInferenceTest, DecoderMaskedMultiHeadAttentionAllPresentOutputs) { + NodeArg* present_key = nullptr; + NodeArg* present_value = nullptr; + BuildResolveAndVerify( + [&](ModelTestBuilder& builder) { + // For this op past_key/past_value are inputs 5 and 6; leave inputs 1-4 empty to reach them. + NodeArg& empty = builder.graph_.GetOrCreateNodeArg("", nullptr); + NodeArg* query = builder.MakeInput(std::vector{2, 1, 4}); + NodeArg* past_key = builder.MakeInput(std::vector{2, 2, 3, 2}); + NodeArg* past_value = builder.MakeInput(std::vector{2, 2, 3, 2}); + + NodeArg* output = builder.MakeOutput(std::nullopt); + present_key = builder.MakeOutput(); + present_value = builder.MakeOutput(); + + std::vector inputs = {query, &empty, &empty, &empty, &empty, past_key, past_value}; + Node& node = builder.AddNode("DecoderMaskedMultiHeadAttention", inputs, + {output, present_key, present_value}, kMSDomain); + node.AddAttribute("num_heads", static_cast(2)); + node.AddAttribute("past_present_share_buffer", static_cast(1)); + }, + [&](const Graph&) { + ExpectInferredElemType(present_key); + ExpectInferredElemType(present_value); + }); +} + +// DecoderAttention with all three outputs: new_key_cache / new_value_cache types must be inferred. +TEST(AttentionOptionalOutputsShapeInferenceTest, DecoderAttentionAllCacheOutputs) { + NodeArg* new_key_cache = nullptr; + NodeArg* new_value_cache = nullptr; + BuildResolveAndVerify( + [&](ModelTestBuilder& builder) { + NodeArg& empty = builder.graph_.GetOrCreateNodeArg("", nullptr); + NodeArg* query = builder.MakeInput(std::vector{2, 1, 4}); + NodeArg* key = builder.MakeInput(std::vector{2, 1, 4}); + NodeArg* q_weight = builder.MakeInput(std::vector{4, 4}); + NodeArg* kv_weight = builder.MakeInput(std::vector{4, 8}); + NodeArg* bias = builder.MakeInput(std::vector{12}); + NodeArg* key_cache = builder.MakeInput(std::vector{2, 2, 3, 2}); + NodeArg* value_cache = builder.MakeInput(std::vector{2, 2, 3, 2}); + NodeArg* static_kv = builder.MakeInput(std::vector{1}); + NodeArg* use_past = builder.MakeInput(std::vector{1}); + NodeArg* has_layer_state = builder.MakeInput(std::vector{1}); + NodeArg* has_key_padding_mask = builder.MakeInput(std::vector{1}); + + NodeArg* output = builder.MakeOutput(std::nullopt); + new_key_cache = builder.MakeOutput(); + new_value_cache = builder.MakeOutput(); + + std::vector inputs = {query, key, q_weight, kv_weight, bias, &empty, key_cache, + value_cache, static_kv, use_past, has_layer_state, + has_key_padding_mask}; + Node& node = builder.AddNode("DecoderAttention", inputs, + {output, new_key_cache, new_value_cache}, kMSDomain); + node.AddAttribute("num_heads", static_cast(2)); + }, + [&](const Graph&) { + ExpectInferredElemType(new_key_cache); + ExpectInferredElemType(new_value_cache); + }); +} + +} // namespace test +} // namespace onnxruntime