diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 082d988a92c88..e4d9404bc029c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -21,6 +22,12 @@ #include #include +#ifndef _WIN32 +#include +#include +#include +#endif + #include "core/providers/shared_library/provider_api.h" #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" @@ -149,6 +156,87 @@ static std::vector parse_compile_batches(const std::string& spec); // but a few fallback paths still use synchronous hipMalloc. static std::mutex g_hip_alloc_mutex; +// HIP's current device is a per-thread setting. ORT/Triton may invoke Compile() +// and the compute function on threads that never ran the EP constructor's +// hipSetDevice, so they default to device 0. On a non-zero-device instance that +// loads code objects / launches kernels on the wrong device, producing +// "no kernel image is available for execution on the device" and +// "invalid resource handle". This guard pins the calling thread to the EP's +// device for the duration of a scope and restores the previous device on exit. +struct HipDeviceGuard { + int prev_{0}; + explicit HipDeviceGuard(int dev) { + HIP_CALL_THROW(hipGetDevice(&prev_)); + if (dev != prev_) HIP_CALL_THROW(hipSetDevice(dev)); + } + ~HipDeviceGuard() { + (void)hipSetDevice(prev_); // best-effort restore; never throw from a dtor + } +}; + +// --------------------------------------------------------------------------- +// Compile / cache concurrency primitives +// --------------------------------------------------------------------------- +// +// MIGraphX's ONNX parse + codegen path (parse_onnx_buffer / compile_program) is +// NOT thread-safe. Multiple EP instances (one per GPU) sharing a process will +// otherwise compile concurrently and crash or corrupt internal state. This +// mutex serializes every compile in the process to exactly one at a time. It +// is the load-bearing piece for compile thread-safety; the per-key lock below +// only governs cache herd control, not compiler reentrancy. +static std::mutex g_migraphx_compile_mutex; + +// One mutex per cache-file key. Ensures a given .mxr is compiled exactly once +// in-process: the first thread to miss compiles + publishes, every other thread +// waiting on the same key then re-checks the cache and loads. The registry +// mutex guards only the short map lookup, never the compile itself. +static std::mutex& mutex_for_cache_key(const std::string& key) { + static std::mutex registry_mu; + static std::unordered_map> registry; + std::lock_guard g(registry_mu); + auto& slot = registry[key]; + if (!slot) { + slot = std::make_unique(); + } + return *slot; +} + +// Cross-process advisory lock on ".lock". When several containers +// share a cache volume this serializes compile/publish across processes so they +// don't trample each other's .mxr writes. No-op when no cache file is set or on +// platforms without flock; in-process safety still comes from the mutexes above. +struct CacheFileLock { +#ifndef _WIN32 + int fd_{-1}; +#endif + explicit CacheFileLock(const std::filesystem::path& cache_file) { +#ifndef _WIN32 + if (cache_file.empty()) { + return; + } + auto lock_path = cache_file; + lock_path += ".lock"; + fd_ = ::open(lock_path.c_str(), O_CREAT | O_RDWR, 0644); + if (fd_ >= 0 && ::flock(fd_, LOCK_EX) != 0) { + ::close(fd_); + fd_ = -1; + } +#else + (void)cache_file; +#endif + } + ~CacheFileLock() { +#ifndef _WIN32 + if (fd_ >= 0) { + (void)::flock(fd_, LOCK_UN); + ::close(fd_); + } +#endif + } + CacheFileLock(const CacheFileLock&) = delete; + CacheFileLock& operator=(const CacheFileLock&) = delete; +}; + MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) : IExecutionProvider{kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, info.device_id)}, device_id_{info.device_id}, @@ -215,6 +303,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv GET_ENV_BOOL(migraphx_env_vars::kExhaustiveTune, exhaustive_tune_); GET_ENV_STRING(migraphx_env_vars::kCompileBatches, compile_batches_); GET_ENV_BOOL(migraphx_env_vars::kHipGraphEnable, hip_graph_enable_); + GET_ENV_BOOL(migraphx_env_vars::kCoalesceIO, coalesce_io_enable_); // hipGraph requires single-stream MIGraphX execution (MIGRAPHX_NSTREAMS=1). if (hip_graph_enable_) { @@ -244,6 +333,16 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv } } + // Coalesced I/O runs in the pinned-copy path, which is only allocated/used + // when hipGraph is active (or padding is required). Without hipGraph the flag + // has no effect, so warn rather than silently no-op. + if (coalesce_io_enable_ && !hip_graph_enable_) { + LOGS_DEFAULT(WARNING) + << "[MIGraphX EP] ORT_MIGRAPHX_COALESCE_IO is set but hipGraph is disabled; " + << "input coalescing runs in the pinned-copy path and will be inactive. " + << "Enable ORT_MIGRAPHX_HIP_GRAPH_ENABLE to use it."; + } + // If compile_batches is set, auto-derive max_dynamic_batch from the spec's max value if (!compile_batches_.empty()) { auto explicit_sizes = parse_compile_batches(compile_batches_); @@ -1316,17 +1415,42 @@ bool load_precompiled_model(migraphx::program& prog, const std::filesystem::path } void save_compiled_model(const migraphx::program& prog, const std::filesystem::path& path) { - if (!path.empty()) { - LOGS_DEFAULT(INFO) << "[save_compiled_model] Saving compiled model to disk: " << path.string(); - migraphx::file_options fo; - fo.set_file_format("msgpack"); - save(prog, path.string().c_str(), fo); - if (std::filesystem::exists(path)) { - auto file_sz = std::filesystem::file_size(path); - LOGS_DEFAULT(INFO) << "[save_compiled_model] Saved: " << path.string() - << " (file size: " << file_sz << " bytes, " - << (file_sz / (1024.0 * 1024.0)) << " MB)"; - } + if (path.empty()) { + return; + } + LOGS_DEFAULT(INFO) << "[save_compiled_model] Saving compiled model to disk: " << path.string(); + migraphx::file_options fo; + fo.set_file_format("msgpack"); + + // Atomic publish: serialize to a unique temp file in the same directory, then + // rename over the target. rename(2) is atomic on a single filesystem, so a + // concurrent reader (or a crash mid-write) never observes a torn .mxr, which + // otherwise surfaces as "no kernel image is available" on load. + auto tmp_path = path; + tmp_path += "." + std::to_string(static_cast( +#ifndef _WIN32 + ::getpid() +#else + 0 +#endif + )) + ".tmp"; + + save(prog, tmp_path.string().c_str(), fo); + + std::error_code ec; + std::filesystem::rename(tmp_path, path, ec); + if (ec) { + std::filesystem::remove(tmp_path, ec); + LOGS_DEFAULT(WARNING) << "[save_compiled_model] Atomic publish failed for " + << path.string() << ": " << ec.message(); + return; + } + + if (std::filesystem::exists(path)) { + auto file_sz = std::filesystem::file_size(path); + LOGS_DEFAULT(INFO) << "[save_compiled_model] Saved: " << path.string() + << " (file size: " << file_sz << " bytes, " + << (file_sz / (1024.0 * 1024.0)) << " MB)"; } } @@ -1490,21 +1614,61 @@ static void allocate_pinned_io( const auto& map_input_name_index = mgx_state->input_name_indexes; + // Round each arena slot up to this boundary so every sub-view stays aligned + // for the device kernels that consume it. + constexpr std::size_t kArenaAlign = 256; + auto align_up = [](std::size_t v, std::size_t a) { return (v + a - 1) / a * a; }; + pio.inputs.clear(); pio.input_name_to_idx.clear(); - for (const auto& name : param_shapes.names()) { - if (map_input_name_index.find(name) == map_input_name_index.end()) continue; - const auto& base_shape = param_shapes[name]; - auto lens = base_shape.lengths(); - if (!lens.empty()) lens[0] = max_batch_size; - auto max_shape = migraphx::shape(base_shape.type(), lens); - std::size_t bytes = max_shape.bytes(); + pio.input_offsets.clear(); + pio.coalesced = mgx_state->coalesce_io; + + if (pio.coalesced) { + // ── Single-arena layout: one device alloc + one pinned host staging buffer. + // First pass assigns aligned offsets and records per-input metadata; the + // device pointers are patched in once the arena is allocated. + std::size_t off = 0; + for (const auto& name : param_shapes.names()) { + if (map_input_name_index.find(name) == map_input_name_index.end()) continue; + const auto& base_shape = param_shapes[name]; + auto lens = base_shape.lengths(); + if (!lens.empty()) lens[0] = max_batch_size; + auto max_shape = migraphx::shape(base_shape.type(), lens); + std::size_t bytes = max_shape.bytes(); + + pio.input_name_to_idx[name] = pio.inputs.size(); + pio.input_offsets.push_back(off); + pio.inputs.push_back({nullptr, bytes, max_shape}); // .data patched below + off += align_up(bytes, kArenaAlign); + } + pio.in_arena_bytes = off; + + if (pio.in_arena_bytes > 0) { + HIP_CALL_THROW(hipMallocAsync(&pio.in_arena_dev, pio.in_arena_bytes, stream)); + HIP_CALL_THROW(hipMemsetAsync(pio.in_arena_dev, 0, pio.in_arena_bytes, stream)); + // Pinned host staging: page-locked so the single H2D is truly async. + HIP_CALL_THROW(hipHostMalloc(&pio.in_staging_host, pio.in_arena_bytes, hipHostMallocDefault)); + std::memset(pio.in_staging_host, 0, pio.in_arena_bytes); + for (std::size_t i = 0; i < pio.inputs.size(); ++i) { + pio.inputs[i].data = static_cast(pio.in_arena_dev) + pio.input_offsets[i]; + } + } + } else { + for (const auto& name : param_shapes.names()) { + if (map_input_name_index.find(name) == map_input_name_index.end()) continue; + const auto& base_shape = param_shapes[name]; + auto lens = base_shape.lengths(); + if (!lens.empty()) lens[0] = max_batch_size; + auto max_shape = migraphx::shape(base_shape.type(), lens); + std::size_t bytes = max_shape.bytes(); - pio.input_name_to_idx[name] = pio.inputs.size(); - void* ptr = nullptr; - HIP_CALL_THROW(hipMallocAsync(&ptr, bytes, stream)); - HIP_CALL_THROW(hipMemsetAsync(ptr, 0, bytes, stream)); - pio.inputs.push_back({ptr, bytes, max_shape}); + pio.input_name_to_idx[name] = pio.inputs.size(); + void* ptr = nullptr; + HIP_CALL_THROW(hipMallocAsync(&ptr, bytes, stream)); + HIP_CALL_THROW(hipMemsetAsync(ptr, 0, bytes, stream)); + pio.inputs.push_back({ptr, bytes, max_shape}); + } } pio.outputs.clear(); @@ -1539,8 +1703,18 @@ static void allocate_pinned_io( static void free_pinned_io(MIGraphXFuncState* mgx_state, hipStream_t stream) { auto& pio = mgx_state->pinned_io; - for (auto& buf : pio.inputs) { - if (buf.data) { (void)hipFreeAsync(buf.data, stream); buf.data = nullptr; } + if (pio.coalesced) { + // Inputs are sub-views of a single arena: free the arena once, never the + // individual .data pointers (they are not separate allocations). + if (pio.in_arena_dev) { (void)hipFreeAsync(pio.in_arena_dev, stream); pio.in_arena_dev = nullptr; } + if (pio.in_staging_host) { (void)hipHostFree(pio.in_staging_host); pio.in_staging_host = nullptr; } + pio.in_arena_bytes = 0; + pio.input_offsets.clear(); + pio.coalesced = false; + } else { + for (auto& buf : pio.inputs) { + if (buf.data) { (void)hipFreeAsync(buf.data, stream); buf.data = nullptr; } + } } for (auto& buf : pio.outputs) { if (buf.data) { (void)hipFreeAsync(buf.data, stream); buf.data = nullptr; } @@ -1670,6 +1844,57 @@ static void copy_inputs_to_pinned( auto& pio = mgx_state->pinned_io; const auto& map_input_name_index = mgx_state->input_name_indexes; + // ── Coalesced fast path ─────────────────────────────────────────────────── + // When the input arena is active, there is no padding, and every input is + // host-resident, gather all inputs into the pinned staging buffer and issue a + // single H2D for the whole arena. This collapses the ~N per-input + // hipMemcpyAsync launches (which dominate batch-1 many-input models) into one. + // Any other case (padding, or a device-resident input) falls through to the + // per-input loop below, which is still correct because each pin.data points + // into the arena. + if (pio.coalesced && pio.in_staging_host != nullptr && actual_batch == compiled_batch) { + bool all_host = true; + for (const auto& name : param_shapes.names()) { + auto it = map_input_name_index.find(name); + if (it == map_input_name_index.end()) continue; + auto mem = ctx.GetInput(it->second).GetTensorMemoryInfo(); + if (mem.GetDeviceType() != OrtMemoryInfoDeviceType_CPU) { all_host = false; break; } + } + + if (all_host) { + char* host_base = static_cast(pio.in_staging_host); + for (const auto& name : param_shapes.names()) { + auto it = map_input_name_index.find(name); + if (it == map_input_name_index.end()) continue; + auto pin_it = pio.input_name_to_idx.find(name); + if (pin_it == pio.input_name_to_idx.end()) continue; + const auto idx = pin_it->second; + const auto& pin = pio.inputs[idx]; + + const auto& input_tensor = ctx.GetInput(it->second); + const void* src = input_tensor.GetTensorRawData(); + const auto& base_shape = param_shapes[name]; + auto lens = base_shape.lengths(); + std::size_t elements_per_batch = std::accumulate( + lens.begin() + 1, lens.end(), std::size_t{1}, std::multiplies<>{}); + std::size_t total_elems = 1; + for (auto l : lens) total_elems *= l; + std::size_t byte_per_elem = (total_elems > 0) ? base_shape.bytes() / total_elems : 0; + std::size_t copy_bytes = actual_batch * elements_per_batch * byte_per_elem; + if (copy_bytes > pin.size_bytes) copy_bytes = pin.size_bytes; + if (copy_bytes > 0) { + std::memcpy(host_base + pio.input_offsets[idx], src, copy_bytes); + } + } + // One transfer for every input. Copying the whole arena (including the + // harmless aligned gaps / unused slot tails) keeps it a single contiguous + // DMA; the program only ever reads the bound [compiled_batch] rows. + HIP_CALL_THROW(hipMemcpyAsync(pio.in_arena_dev, pio.in_staging_host, + pio.in_arena_bytes, hipMemcpyHostToDevice, stream)); + return; + } + } + for (const auto& name : param_shapes.names()) { auto it = map_input_name_index.find(name); if (it == map_input_name_index.end()) continue; @@ -1799,6 +2024,33 @@ static void copy_pinned_outputs_to_ort( } } +// Defense-in-depth: make sure every ORT graph output has been materialized. +// +// An output OrtValue only becomes "produced" once some path calls +// ctx.GetOutput() for its index. The hipGraph paths split that responsibility +// between copy_pinned_outputs_to_ort (pre-allocated #output_N params) and +// materialize_extra_outputs ("extra" run_async results). If any output index +// falls through both (e.g. a partition mismatch), ORT later fails the fetch +// with "Unsupported OrtValue type". GetOutput is idempotent -- for an output +// already produced it returns the existing tensor and ignores the shape -- so +// calling it here for every program output guarantees none is left unset. +// Program output position i maps 1:1 to ORT output index i, matching how +// run_migraphx_program / materialize_extra_outputs index their outputs. +static void ensure_all_outputs_allocated( + Ort::KernelContext& ctx, + const migraphx::shapes& output_shapes, + std::size_t actual_batch) +{ + for (std::size_t i = 0; i < output_shapes.size(); ++i) { + auto lens = output_shapes[i].lengths(); + std::vector ort_shape(lens.begin(), lens.end()); + if (!ort_shape.empty() && actual_batch > 0) { + ort_shape[0] = static_cast(actual_batch); + } + (void)ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); + } +} + // Helper: Run the MIGraphX program and handle outputs // This function executes the compiled MIGraphX program and copies outputs that @@ -2064,6 +2316,7 @@ static bool warmup_and_capture_hip_graph( HIP_CALL_THROW(hipGraphInstantiate(&entry.exec, entry.graph, nullptr, nullptr, 0)); entry.captured = true; + entry.direct_bind = false; // Record the scratch pointer that was baked into the captured kernels so // we can detect re-allocation across replays (e.g. after pool reuse). auto scratch_it = mgx_state->scratch_bufs.find(shape_hash); @@ -2240,7 +2493,7 @@ static bool warmup_and_capture_hip_graph_direct( auto& entry = mgx_state->hip_graph_cache[shape_hash]; try { - HIP_CALL_THROW(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal)); + HIP_CALL_THROW(hipStreamBeginCapture(stream, hipStreamCaptureModeThreadLocal)); { std::lock_guard lock(*mgx_state->mgx_mu_ptr); prog.run_async(m, stream); @@ -2255,6 +2508,7 @@ static bool warmup_and_capture_hip_graph_direct( HIP_CALL_THROW(hipGraphInstantiate(&entry.exec, entry.graph, nullptr, nullptr, 0)); entry.captured = true; + entry.direct_bind = true; entry.captured_input_ptrs = input_ptrs; entry.captured_output_ptrs = output_ptrs; { @@ -2361,7 +2615,7 @@ static void run_program_or_hip_graph_direct( std::size_t padded_batch_size = 0) { auto it = mgx_state->hip_graph_cache.find(shape_hash); - if (it != mgx_state->hip_graph_cache.end() && it->second.captured) { + if (it != mgx_state->hip_graph_cache.end() && it->second.captured && it->second.direct_bind) { void* current_scratch = nullptr; { auto sit = mgx_state->scratch_bufs.find(shape_hash); @@ -2382,6 +2636,13 @@ static void run_program_or_hip_graph_direct( if (it->second.graph) { (void)hipGraphDestroy(it->second.graph); it->second.graph = nullptr; } it->second.captured = false; } else { + // Pointers matched: any earlier drift was transient (e.g. a one-off + // allocator shuffle), not a sustained rebinding pattern. Reset the + // counter so only *consecutive* mismatches can trip the permanent eager + // fallback above. Without this the count is monotonic and rare, + // recoverable drift accumulates over a long-lived session until it + // needlessly disables the fast path for good. + mgx_state->direct_recapture_count = 0; // Same rationale as in replay_hip_graph: zero EP-owned scratch before // every direct-bind replay so the captured kernel sequence isn't // contaminated by the prior replay's scratch residue. @@ -2474,7 +2735,7 @@ static void run_program_or_hip_graph( } auto it = mgx_state->hip_graph_cache.find(shape_hash); - if (it != mgx_state->hip_graph_cache.end() && it->second.captured) { + if (it != mgx_state->hip_graph_cache.end() && it->second.captured && !it->second.direct_bind) { replay_hip_graph(mgx_state, stream, shape_hash); if (!it->second.extra_outputs.empty()) { @@ -2482,6 +2743,17 @@ static void run_program_or_hip_graph( original_batch_size, padded_batch_size); } } else { + // A captured-but-direct_bind entry here means we just transitioned out of + // direct-bind mode (e.g. after pointer-drift fallback). That graph baked in + // ORT tensor addresses and uses a different output partition than the pinned + // path, so it must be torn down rather than replayed -- otherwise some ORT + // outputs are left unwritten and the fetch fails with "Unsupported OrtValue + // type". Destroy it and re-capture cleanly in pinned-copy mode. + if (it != mgx_state->hip_graph_cache.end() && it->second.captured && it->second.direct_bind) { + if (it->second.exec) { (void)hipGraphExecDestroy(it->second.exec); it->second.exec = nullptr; } + if (it->second.graph) { (void)hipGraphDestroy(it->second.graph); it->second.graph = nullptr; } + it->second.captured = false; + } if (!warmup_and_capture_hip_graph(mgx_state, stream, prog, m, prog_output_indices, shape_hash)) { run_migraphx_program(mgx_state->mgx_mu_ptr, stream, ctx, prog, m, @@ -2596,6 +2868,13 @@ migraphx::program CompileProgramWithBatch( const std::vector>& all_input_base_shapes = {}, size_t batch_size = 0) { + // MIGraphX parse + codegen is not thread-safe. This is the single chokepoint + // every compile path flows through, so serializing here guarantees at most one + // compile in the process at a time across all EP instances / GPUs. Held for + // the whole parse+quantize+compile span; per-key locks (if any) are always + // acquired *before* this one, so the ordering is one-way and deadlock-free. + std::lock_guard compile_guard(g_migraphx_compile_mutex); + LOGS_DEFAULT(VERBOSE) << "[CompileBatch] Starting compilation"; // Set input shapes with the specified batch size for ALL inputs (if provided) @@ -2694,9 +2973,10 @@ static migraphx::program load_or_compile_model( { migraphx::program prog; - if (!load_precompiled_model(prog, cache_file)) { - - prog = CompileProgramWithBatch( + // No cache configured: just compile. CompileProgramWithBatch still serializes + // itself via the global compile mutex, so this remains thread-safe. + if (cache_file.empty()) { + return CompileProgramWithBatch( onnx_string, options, t, @@ -2713,9 +2993,41 @@ static migraphx::program load_or_compile_model( input_names, all_input_base_shapes, batch_size); + } - save_compiled_model(prog, cache_file); + // Per-key locks: only one thread/process compiles+publishes this cache file; + // everyone else waiting on the same key falls through to the double-checked + // load below. Ordering is always per-key -> (inside compile) global compile + // mutex, never the reverse, so no deadlock is possible. + std::lock_guard key_lock(mutex_for_cache_key(cache_file.string())); + CacheFileLock cross_proc_lock(cache_file); + + // Double-checked load: another thread/process may have produced the cache file + // while we were blocked on the lock above. + if (load_precompiled_model(prog, cache_file)) { + return prog; } + + // Cache miss and we hold the key: we are the single compiler for this file. + prog = CompileProgramWithBatch( + onnx_string, + options, + t, + fp16_enable, + bf16_enable, + int8_enable, + fp8_enable, + int8_calibration_cache_available, + dynamic_range_map, + exhaustive_tune, + model_path, + ctx, + map_input_name_index, + input_names, + all_input_base_shapes, + batch_size); + + save_compiled_model(prog, cache_file); return prog; } @@ -2807,6 +3119,8 @@ static void handle_input_shape_mismatch( mgx_state->cached_prog_output_indices.clear(); mgx_state->last_input_shapes_raw.clear(); mgx_state->last_input_shape_hash.clear(); + mgx_state->cached_binding_is_pinned = false; + mgx_state->cached_binding_actual_batch = 0; param_shapes = prog.get_parameter_shapes(); mgx_state->defer_compilation = false; @@ -3076,6 +3390,27 @@ static bool execute_ultra_fast_path( std::size_t compiled_batch = padded_batch_size > 0 ? padded_batch_size : actual_batch; bool needs_padding = (actual_batch < compiled_batch); + // Reuse the cached binding only when it was built for THIS request's actual + // batch size AND the same binding mode it now needs. The ultra-fast caches + // are typically populated by an exact-batch run (direct-bind: ORT pointers at + // the compiled batch). Reusing that binding for a padded request runs the + // model at the compiled batch and hands ORT the compiled output shape instead + // of the request's actual shape -> "OrtValue shape verification failed. + // Current shape:{compiled} Requested shape:{actual}", which is what surfaced + // under alternating batch sizes at higher concurrency. The inverse (reusing a + // pinned binding on the direct-bind/eager path) is equally unsafe. On any + // mismatch, bail to the fast path, which rebinds correctly: pinned staging + + // output slice-back for padded, direct-bind for exact. + const bool req_direct = mgx_state->use_direct_hip_graph && !needs_padding; + const bool req_pinned = !req_direct && + (needs_padding || mgx_state->hip_graph_enabled) && + mgx_state->pinned_io.allocated && + mgx_state->cached_mgx_param_shapes.has_value(); + if (mgx_state->cached_binding_actual_batch != actual_batch || + mgx_state->cached_binding_is_pinned != req_pinned) { + return false; + } + // Direct-bind hipGraph: no copies, bind ORT pointers and replay if (mgx_state->use_direct_hip_graph && !needs_padding) { auto& m = mgx_state->cached_prog_params.value(); @@ -3128,11 +3463,13 @@ static bool execute_ultra_fast_path( auto& m = mgx_state->cached_prog_params.value(); run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, m, mgx_state->cached_prog_output_indices, - mgx_state->last_input_shape_hash); + mgx_state->last_input_shape_hash, + actual_batch, compiled_batch); copy_pinned_outputs_to_ort(mgx_state, output_shapes, mgx_state->cached_prog_output_indices, mgx_state->cached_pinned_output_indices, ctx, actual_batch, rocm_stream); + ensure_all_outputs_allocated(ctx, output_shapes, actual_batch); return true; } @@ -3303,6 +3640,8 @@ static bool execute_fast_path( mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; + mgx_state->cached_binding_is_pinned = false; + mgx_state->cached_binding_actual_batch = actual_batch; run_program_or_hip_graph_direct(mgx_state, rocm_stream, ctx, prog, mgx_state->cached_prog_params.value(), @@ -3328,15 +3667,19 @@ static bool execute_fast_path( mgx_state, ctx, padded_batch_size); mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; + mgx_state->cached_binding_is_pinned = true; + mgx_state->cached_binding_actual_batch = actual_batch; run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, mgx_state->cached_prog_params.value(), mgx_state->cached_prog_output_indices, - effective_program_hash); + effective_program_hash, + actual_batch, compiled_batch); copy_pinned_outputs_to_ort(mgx_state, output_shapes, mgx_state->cached_prog_output_indices, mgx_state->cached_pinned_output_indices, ctx, actual_batch, rocm_stream); + ensure_all_outputs_allocated(ctx, output_shapes, actual_batch); return true; } @@ -3349,11 +3692,13 @@ static bool execute_fast_path( mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; + mgx_state->cached_binding_is_pinned = false; + mgx_state->cached_binding_actual_batch = actual_batch; run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, mgx_state->cached_prog_params.value(), mgx_state->cached_prog_output_indices); - return true; + return true; } // Result structure for handle_input_shape function @@ -3703,6 +4048,8 @@ static void execute_standard_path( mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); mgx_state->last_input_shape_hash = padded_hash; mgx_state->caches_valid = true; + mgx_state->cached_binding_is_pinned = false; + mgx_state->cached_binding_actual_batch = original_batch_size; run_program_or_hip_graph_direct(mgx_state, rocm_stream, ctx, prog, m, prog_output_indices, padded_hash, @@ -3753,13 +4100,17 @@ static void execute_standard_path( mgx_state, ctx, padded_batch_size); mgx_state->last_input_shape_hash = padded_hash; mgx_state->caches_valid = true; + mgx_state->cached_binding_is_pinned = true; + mgx_state->cached_binding_actual_batch = copy_actual; run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, bind_result.params, - bind_result.prog_output_indices, padded_hash); + bind_result.prog_output_indices, padded_hash, + copy_actual, padded_batch_size); copy_pinned_outputs_to_ort(mgx_state, output_shapes, bind_result.prog_output_indices, bind_result.pinned_output_indices, ctx, copy_actual, rocm_stream); + ensure_all_outputs_allocated(ctx, output_shapes, copy_actual); } else { auto [m, prog_output_indices] = handle_program_input_outputs( param_shapes, output_shapes, map_input_name_index, ctx, @@ -3770,6 +4121,8 @@ static void execute_standard_path( mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; + mgx_state->cached_binding_is_pinned = false; + mgx_state->cached_binding_actual_batch = original_batch_size; run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, prog_output_indices); } @@ -3851,6 +4204,8 @@ static void execute_standard_path( mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; + mgx_state->cached_binding_is_pinned = false; + mgx_state->cached_binding_actual_batch = 0; run_program_or_hip_graph_direct(mgx_state, rocm_stream, ctx, prog, m, prog_output_indices, current_hash, @@ -3874,13 +4229,17 @@ static void execute_standard_path( mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; + mgx_state->cached_binding_is_pinned = true; + mgx_state->cached_binding_actual_batch = actual_batch; run_program_or_hip_graph(mgx_state, rocm_stream, ctx, prog, bind_result.params, - bind_result.prog_output_indices, current_hash); + bind_result.prog_output_indices, current_hash, + actual_batch, actual_batch); copy_pinned_outputs_to_ort(mgx_state, output_shapes, bind_result.prog_output_indices, bind_result.pinned_output_indices, ctx, actual_batch, rocm_stream); + ensure_all_outputs_allocated(ctx, output_shapes, actual_batch); return; } @@ -3893,6 +4252,8 @@ static void execute_standard_path( mgx_state->last_input_shapes_raw = build_input_shapes_in_cached_order(mgx_state, ctx, 0); mgx_state->last_input_shape_hash = current_hash; mgx_state->caches_valid = true; + mgx_state->cached_binding_is_pinned = false; + mgx_state->cached_binding_actual_batch = 0; run_migraphx_program(mgx_state->mgx_mu_ptr, rocm_stream, ctx, prog, m, prog_output_indices); } @@ -4175,7 +4536,8 @@ static inline void precompile_all_dynamic_batch_models( const std::filesystem::path& model_path, const std::filesystem::path& model_cache_path, const std::string& mxr_filename_prefix, - std::unordered_map& cached_programs) + std::unordered_map& cached_programs, + int device_id) { LOGS_DEFAULT(INFO) << "[precompile_all_dynamic_batch_models] Processing " << compiled_batch_sizes.size() << " batch models..."; @@ -4240,7 +4602,12 @@ static inline void precompile_all_dynamic_batch_models( for (const auto& info : batch_infos) { load_futures.push_back(std::async(std::launch::async, - [&, info]() { + [&, info, device_id]() { + // HIP's current device is thread-local and NOT inherited by this async + // worker; pin it to the EP's device so the loaded code objects bind to + // the correct GPU (otherwise non-zero-device instances fail to launch + // with "invalid device ordinal"). + HipDeviceGuard dev_guard(device_id); LOGS_DEFAULT(VERBOSE) << "[precompile_all_dynamic_batch_models] Trying to load batch " << info.batch_size << " from disk..."; @@ -4289,8 +4656,12 @@ static inline void precompile_all_dynamic_batch_models( LOGS_DEFAULT(INFO) << "[precompile_all_dynamic_batch_models] Compiling batch size " << info.batch_size << "..."; - // Compile the model (this is the thread-unsafe part that must be serialized) - migraphx::program batch_prog = CompileProgramWithBatch( + // Route through load_or_compile_model so this path shares the same + // per-key cache lock, cross-process file lock, double-checked load and + // atomic save as every other compile path. Compilation itself remains + // serialized by the global compile mutex inside CompileProgramWithBatch. + migraphx::program batch_prog = load_or_compile_model( + info.cache_file, onnx_string, options, t, @@ -4308,16 +4679,9 @@ static inline void precompile_all_dynamic_batch_models( all_input_base_shapes, info.batch_size); - LOGS_DEFAULT(INFO) << "[precompile_all_dynamic_batch_models] ✓ Compiled batch size " + LOGS_DEFAULT(INFO) << "[precompile_all_dynamic_batch_models] ✓ Ready batch size " << info.batch_size; - // Save to disk cache - save_compiled_model(batch_prog, info.cache_file); - if (!info.cache_file.empty()) { - LOGS_DEFAULT(VERBOSE) << "[precompile_all_dynamic_batch_models] Saved to disk: " - << info.cache_file.string(); - } - // Store in memory cache cached_programs[info.cache_hash] = std::move(batch_prog); } @@ -4493,7 +4857,8 @@ static inline void precompile_static_model( static void preload_mxr_cache_from_disk( const std::filesystem::path& model_cache_path, const std::string& mxr_filename_prefix, - std::unordered_map& cached_programs) + std::unordered_map& cached_programs, + int device_id) { if (model_cache_path.empty() || !std::filesystem::exists(model_cache_path)) return; @@ -4521,7 +4886,12 @@ static void preload_mxr_cache_from_disk( std::mutex mu; std::vector> futs; for (const auto& [hash, path] : to_load) { - futs.push_back(std::async(std::launch::async, [&, hash, path]() { + futs.push_back(std::async(std::launch::async, [&, hash, path, device_id]() { + // HIP's current device is thread-local and NOT inherited by this async + // worker; pin it to the EP's device so the loaded code objects bind to + // the correct GPU (otherwise non-zero-device instances fail to launch + // with "invalid device ordinal"). + HipDeviceGuard dev_guard(device_id); migraphx::program prog; if (load_precompiled_model(prog, path)) { std::lock_guard lk(mu); @@ -4558,7 +4928,8 @@ static inline bool handle_precompilation_decision( const std::string& mxr_filename_prefix, std::unordered_map& cached_programs, std::size_t max_dynamic_batch, - const std::string& compile_batches_spec) + const std::string& compile_batches_spec, + int device_id) { // ═══════════════════════════════════════════════════════════════════════════ // PRECOMPILATION: Compile models during Compile() phase instead of compute_func() @@ -4635,7 +5006,8 @@ static inline bool handle_precompilation_decision( model_path, model_cache_path, mxr_filename_prefix, - cached_programs); + cached_programs, + device_id); // Precompilation complete - disable deferred compilation LOGS_DEFAULT(VERBOSE) << "[Compile][PRECOMPILE] ✓✓✓ Dynamic batch precompilation COMPLETE for node '" @@ -4736,6 +5108,10 @@ constexpr std::uint64_t MIGraphX_Version = Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { + // Compile()/load may run on a thread that never ran the constructor's + // hipSetDevice. Pin it to this EP's GPU so code objects are loaded/finalized + // on the correct device (otherwise: "no kernel image is available..."). + HipDeviceGuard dev_guard(device_id_); for (const auto& fused_node_graph : fused_nodes) { const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; const Node& fused_node = fused_node_graph.fused_node; @@ -4805,11 +5181,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& mxr_filename_prefix, cached_programs_[fused_node.Name()], max_dynamic_batch_, - compile_batches_); + compile_batches_, + device_id_); // Pre-load any .mxr files from disk that aren't already in memory. preload_mxr_cache_from_disk(model_cache_path_, mxr_filename_prefix, - cached_programs_[fused_node.Name()]); + cached_programs_[fused_node.Name()], + device_id_); // Create program object (may be empty if precompiled programs are in cache) migraphx::program prog; @@ -4833,6 +5211,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& p->input_name_indexes = map_input_index_[context->node_name]; p->mgx_mu_ptr = &mgx_mu_; p->stream = stream_; + p->device_id = device_id_; p->defer_compilation = map_defer_compilation_[context->node_name]; p->fp16_enable = fp16_enable_; p->bf16_enable = bf16_enable_; @@ -4870,6 +5249,12 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(VERBOSE) << "[Compile][CREATE_STATE] defer_compilation=" << p->defer_compilation; } + // Coalesced input H2D: when enabled, the pinned-copy path batches all + // host-resident inputs into a single transfer (see copy_inputs_to_pinned). + // Must be set BEFORE allocate_pinned_io below, which reads it to decide + // between the single-arena and per-input buffer layouts. + p->coalesce_io = coalesce_io_enable_; + // Allocate pinned I/O buffers from the cached programs. // create_state_func runs ONCE at session init (long before any Run()), // so there is no per-Run compute stream to query here — ComputeContext @@ -4926,7 +5311,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // hipGraph: set per-node enable flag and validate cached programs p->hip_graph_enabled = hip_graph_enable_; - p->use_direct_hip_graph = hip_graph_enable_; + // Coalescing operates in the pinned-copy path and consumes host-resident + // inputs, which direct-bind cannot bind into a device program. When + // coalescing is requested, force the pinned path (keep hipGraph + // capture/replay, just not the direct-bind variant). + p->use_direct_hip_graph = hip_graph_enable_ && !coalesce_io_enable_; if (p->hip_graph_enabled && p->cached_programs_ref.has_value()) { for (const auto& [hash, cached_prog] : p->cached_programs_ref.value().get()) { if (!check_hip_graph_compatibility(cached_prog, context->node_name)) { @@ -4954,8 +5343,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& }; compute_info.compute_func = [this, mxr_filename_prefix](FunctionState state, const OrtApi* /*api*/, OrtKernelContext* context) { - Ort::KernelContext ctx(context); MIGraphXFuncState* mgx_state = reinterpret_cast(state); + // Pin this worker thread to the EP's GPU before any HIP work (kernel + // launch, hipGraph capture/replay, deferred compile/load). HIP's current + // device is thread-local, so without this a non-zero-device instance runs + // on device 0 and fails with "invalid resource handle". + HipDeviceGuard dev_guard(mgx_state->device_id); + Ort::KernelContext ctx(context); // Run on whichever stream ORT elected for this device for THIS Run(). // - external_stream_=true -> ORT wrapper around the user-supplied stream diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 6ce431c127c8c..cf9188e04a2b6 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -38,6 +38,13 @@ constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH"sv; constexpr auto kModelMaxDynamicBatch = "ORT_MIGRAPHX_MAX_DYNAMIC_BATCH"sv; constexpr auto kCompileBatches = "ORT_MIGRAPHX_COMPILE_BATCHES"sv; constexpr auto kHipGraphEnable = "ORT_MIGRAPHX_HIP_GRAPH_ENABLE"sv; +// When enabled, the pinned-copy path gathers all (host-resident) input tensors +// into a single contiguous pinned staging buffer and issues ONE host->device +// transfer into a single device arena, instead of one copy per input. This +// collapses the per-input H2D launch overhead that dominates batch-1 models +// with many small inputs (e.g. feed-gen-rec, ~190 inputs). Inputs are then +// bound as sub-views into the arena. See copy_inputs_to_pinned. +constexpr auto kCoalesceIO = "ORT_MIGRAPHX_COALESCE_IO"sv; } // namespace migraphx_env_vars // Tracks which dimensions are symbolic for a given input @@ -58,6 +65,7 @@ struct MIGraphXFuncState { std::unordered_map input_name_indexes; std::mutex* mgx_mu_ptr = nullptr; hipStream_t stream = nullptr; + int device_id = 0; bool defer_compilation = false; bool fp16_enable = false; bool bf16_enable = false; @@ -91,6 +99,19 @@ struct MIGraphXFuncState { std::unordered_map output_name_to_idx; std::size_t max_batch_size = 0; bool allocated = false; + + // ── Coalesced-input arena (enabled by ORT_MIGRAPHX_COALESCE_IO) ────────── + // When `coalesced` is true, the per-input device buffers are NOT independent + // allocations: every inputs[i].data points to (in_arena_dev + input_offsets[i]) + // inside a single device allocation `in_arena_dev` of `in_arena_bytes`. + // `in_staging_host` is a pinned host buffer with the identical layout that + // copy_inputs_to_pinned gathers into before issuing one H2D for the whole + // arena. free_pinned_io must free the arena once (not per input buffer). + bool coalesced = false; + void* in_arena_dev = nullptr; // single device arena backing all inputs + void* in_staging_host = nullptr; // pinned host staging buffer (gather target) + std::size_t in_arena_bytes = 0; // total arena size (aligned slot sum) + std::vector input_offsets; // byte offset per input, parallel to `inputs` }; PinnedIOSet pinned_io; @@ -161,6 +182,16 @@ struct MIGraphXFuncState { // Flag indicating caches are valid bool caches_valid = false; + + // Describes how cached_prog_params was last bound, so the ultra-fast path + // never reuses a binding built for a different actual batch size or a + // different binding mode (direct-bind vs pinned-copy/slice). Mixing these + // across alternating batch sizes (e.g. an exact batch-4 direct-bind binding + // being reused to service a padded batch-3 request) leaked the compiled + // batch shape to ORT and triggered the {compiled}/{actual} output-shape + // verification failure observed at higher concurrency. + bool cached_binding_is_pinned = false; + std::size_t cached_binding_actual_batch = 0; // ═══════════════════════════════════════════════════════════════════════════ // OPTIMIZATION: Cached MIGraphX API results (avoid redundant API calls) @@ -193,6 +224,12 @@ struct MIGraphXFuncState { hipGraph_t graph = nullptr; hipGraphExec_t exec = nullptr; bool captured = false; + // Which capture mode produced this entry. The direct-bind and pinned-copy + // paths partition the program outputs differently (pre-allocated #output_N + // params vs. "extra" run_async results), so an entry captured in one mode + // must never be replayed by the other -- doing so leaves some ORT outputs + // unwritten (empty OrtValue) or bound to stale addresses. + bool direct_bind = false; std::vector extra_outputs; // Addresses captured in the graph for direct-bind mode. @@ -224,6 +261,9 @@ struct MIGraphXFuncState { }; bool hip_graph_enabled = false; + // When true, the pinned-copy path coalesces all host-resident inputs into a + // single H2D transfer into pinned_io.in_arena_dev (see kCoalesceIO). + bool coalesce_io = false; // When true, capture/replay binds ORT tensor pointers directly (no pinned copies). // Requires the pool allocator to provide stable addresses. bool use_direct_hip_graph = false; @@ -335,6 +375,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { size_t max_dynamic_batch_{0}; std::string compile_batches_{}; // Comma-separated list of batch sizes to compile, e.g. "1,4,8,16,32" bool hip_graph_enable_{false}; + bool coalesce_io_enable_{false}; // ORT_MIGRAPHX_COALESCE_IO: coalesce per-input H2D copies }; }; // namespace onnxruntime