diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index 95e21729430b9..38bead02d2943 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -189,10 +189,10 @@ void CUDAGraph::capture_end() { "capture_end() called before capture_begin()."); _currently_capturing_graphs.erase(capture_id_); } - AT_CUDA_CHECK(endCaptureErr); c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_); at::getHostAllocator(at::kCUDA)->end_allocate_to_pool(mempool_id_); + AT_CUDA_CHECK(endCaptureErr); TORCH_CHECK(graph_ != nullptr, "Invalid capture.");