GeoT optimization 2/4: Datapipes producer/consumer refactor + stream overlap #1742
base: main
Changes from all commits
64f26be
98df143
65cd32b
6de0ab5
bfe196f
2d1ca08
e85843b
8e8cb35
a92e79b
5693347
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -365,6 +365,15 @@ def _run_epoch( | |
| n_local = 0 | ||
| num_steps = len(dataloader) | ||
| epoch_t0 = time.perf_counter() | ||
| + ### Single pinned scalar buffer reused every step so the loss D2H | ||
| + ### transfer is async (non_blocking=True from device to pinned host | ||
| + ### memory). The copy is issued right after forward_pass and read | ||
| + ### just before the logger line; by then backward + optimizer.step | ||
| + ### have run, giving the GPU time to complete the copy without | ||
| + ### blocking the host. | ||
| + _loss_pinned = ( | ||
| + torch.zeros(1, pin_memory=True) if torch.cuda.is_available() else None | ||
| + ) | ||
|
|
||
| with grad_ctx: | ||
| step_t0 = time.perf_counter() | ||
|
|
@@ -381,6 +390,13 @@ def _run_epoch( | |
| target_config=target_config, | ||
| ) | ||
|
|
||
| + ### Kick off the async D2H copy of the scalar loss value into the | ||
| + ### pinned buffer. Backward + optimizer.step run while the copy is | ||
| + ### in flight, so by the time we call .item() below the transfer | ||
| + ### is already done and there is no host stall. | ||
| + if _loss_pinned is not None: | ||
| + _loss_pinned.copy_(loss.detach(), non_blocking=True) | ||
|
|
||
| if is_train: | ||
| optimizer.zero_grad() | ||
| if precision == "float16" and scaler is not None: | ||
|
|
@@ -407,9 +423,13 @@ def _run_epoch( | |
| total_metrics_td.add_(metrics) | ||
| n_local += 1 | ||
|
|
||
| - ### Per-step sync for the print line; lands after backward + | ||
| - ### optimizer.step so it overlaps with queued GPU work. | ||
| - this_loss = loss.detach().item() | ||
| + ### Read the loss scalar from the pinned buffer; the async copy | ||
| + ### was issued before backward so it has had the full backward + | ||
| + ### optimizer.step to complete without stalling the host. | ||
| + if _loss_pinned is not None: | ||
| + this_loss = _loss_pinned.item() | ||
| + else: | ||
| + this_loss = loss.detach().item() | ||
|
Comment on lines 393 to +432 (Contributor):
The fix is to record a CUDA event immediately after the copy and call |
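The review comment is cut off in this view, so the following is only a minimal sketch of one way an event-guarded read could look, not the change that was actually requested or merged. It assumes the intent is to wait on the recorded event before reading the pinned buffer, since calling `.item()` on a host tensor does not by itself wait for an in-flight `non_blocking` copy. The class name `AsyncLossReader` and its `start`/`read` methods are hypothetical and do not appear in the PR.

```python
import torch


class AsyncLossReader:
    """Hypothetical helper (not from the PR): async D2H copy of a scalar
    loss into pinned memory, guarded by a CUDA event."""

    def __init__(self) -> None:
        self._cuda = torch.cuda.is_available()
        # Pinned host buffer, reused every step, so the copy can be async.
        self._buf = torch.zeros(1, pin_memory=True) if self._cuda else None
        self._event = torch.cuda.Event() if self._cuda else None
        self._cpu_value = None  # used when the loss is not on a CUDA device

    def start(self, loss: torch.Tensor) -> None:
        """Issue the copy right after the forward pass."""
        if self._cuda and loss.is_cuda:
            self._cpu_value = None
            self._buf.copy_(loss.detach(), non_blocking=True)
            # Recorded on the current stream, so it completes only after
            # the copy queued just above has finished.
            self._event.record()
        else:
            self._cpu_value = loss.detach()

    def read(self) -> float:
        """Read the value just before logging."""
        if self._cpu_value is not None:
            return self._cpu_value.item()
        # Block the host only until the copy is done, not until all
        # queued GPU work (backward, optimizer step) has drained.
        self._event.synchronize()
        return self._buf.item()
```

In a loop shaped like `_run_epoch`, `start(loss)` would sit where `_loss_pinned.copy_(...)` is issued and `read()` would replace the `_loss_pinned.item()` call. When the copy has already finished, `synchronize()` returns almost immediately, so the overlap the PR is after should be preserved.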
||
| total_loss += this_loss | ||
|
|
||
| step_dt = time.perf_counter() - step_t0 | ||
|
|
||
This is very funky, but makes sense! Cool implementation