diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp
index 3de3cc3f5259c..eda9e84ad7f69 100644
--- a/aten/src/ATen/cuda/CUDAGraph.cpp
+++ b/aten/src/ATen/cuda/CUDAGraph.cpp
@@ -129,10 +129,23 @@ void CUDAGraph::capture_end() {
   TORCH_CHECK(stream.stream() == capture_stream_.stream(),
               "Capture must end on the same stream it began on.");
 
-  AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_));
+  // Capture is over once cudaStreamEndCapture returns (success or failure).
+  // Clear bookkeeping before propagating the return status so watchdog-side
+  // checks cannot observe stale "capture active" state on error paths.
+  cudaError_t endCaptureErr = cudaStreamEndCapture(capture_stream_, &graph_);
+  {
+    std::unique_lock lock(_currently_capturing_graphs_mutex);
+    TORCH_CHECK(
+        _currently_capturing_graphs.count(capture_id_),
+        "capture_end() called before capture_begin().");
+    _currently_capturing_graphs.erase(capture_id_);
+  }
 
   c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
   at::getHostAllocator(at::kCUDA)->end_allocate_to_pool(mempool_id_);
+  // Report the deferred cudaStreamEndCapture status only now, after the
+  // bookkeeping erase and both end-allocate-to-pool calls above have run.
+  AT_CUDA_CHECK(endCaptureErr);
 
   TORCH_CHECK(graph_ != nullptr, "Invalid capture.");
 