fix: DataLoaderShard.set_epoch cannot penetrate SkipBatchSampler during checkpoint resume#9658
ppp2026 wants to merge 2 commits into
Code Review
This pull request updates the dataloader and trainer mixin to properly set the epoch on wrapped batch samplers, specifically when skipping batches. The review feedback suggests a more robust, loop-based traversal of nested batch samplers to support arbitrary levels of wrapping, and recommends using getattr with a fallback for safer attribute access when retrieving the current epoch.
    if self.batch_sampler is not None and hasattr(self.batch_sampler, 'set_epoch'):
        self.batch_sampler.set_epoch(epoch)
    elif self.batch_sampler is not None and hasattr(self.batch_sampler, 'batch_sampler') and hasattr(
            self.batch_sampler.batch_sampler, 'set_epoch'):
        self.batch_sampler.batch_sampler.set_epoch(epoch)
    elif self.sampler is not None and hasattr(self.sampler, 'set_epoch'):
        self.sampler.set_epoch(epoch)
Instead of hardcoding a single level of nested batch_sampler checks, we can walk the chain of wrapped samplers in a loop. This is more robust and supports arbitrary levels of wrapping (e.g., if multiple wrappers are applied to the batch sampler).
    set_epoch_called = False
    sampler = self.batch_sampler
    while sampler is not None:
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)
            set_epoch_called = True
            break
        sampler = getattr(sampler, 'batch_sampler', None)
    if not set_epoch_called and self.sampler is not None and hasattr(self.sampler, 'set_epoch'):
        self.sampler.set_epoch(epoch)

    epoch = self.state.epoch if hasattr(self, 'state') and self.state.epoch is not None else 0
    batch_sampler.set_epoch(int(epoch))
Using getattr with a fallback is safer than direct attribute access on self.state: it avoids a potential AttributeError or TypeError if self.state is not fully initialized or if epoch is None.
    - epoch = self.state.epoch if hasattr(self, 'state') and self.state.epoch is not None else 0
    - batch_sampler.set_epoch(int(epoch))
    + epoch = getattr(self.state, 'epoch', 0) or 0 if hasattr(self, 'state') else 0
    + batch_sampler.set_epoch(int(epoch))
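The fallback expression in this suggestion can be checked in isolation. The sketch below is a hypothetical stand-in, not the trainer code: `resolve_epoch` and the `SimpleNamespace` objects are made up for illustration, and only the expression itself comes from the suggested change. It assumes the epoch value may be fractional mid-epoch, which is why the result is passed through `int()`.

```python
from types import SimpleNamespace


def resolve_epoch(trainer):
    """Hypothetical helper wrapping the expression from the suggested change."""
    # Python parses this as: (getattr(...) or 0) if hasattr(trainer, 'state') else 0,
    # so trainer.state is only touched when the attribute exists.
    epoch = getattr(trainer.state, 'epoch', 0) or 0 if hasattr(trainer, 'state') else 0
    return int(epoch)


# Illustrative objects covering the cases the comment mentions.
no_state = SimpleNamespace()                                     # no .state at all
none_epoch = SimpleNamespace(state=SimpleNamespace(epoch=None))  # epoch is None
no_epoch_attr = SimpleNamespace(state=SimpleNamespace())         # state without .epoch
mid_epoch = SimpleNamespace(state=SimpleNamespace(epoch=1.37))   # fractional epoch

results = [resolve_epoch(t) for t in (no_state, none_epoch, no_epoch_attr, mid_epoch)]
print(results)  # [0, 0, 0, 1]
```

All three "missing" cases collapse to epoch 0, and a fractional epoch is truncated to the index of the epoch currently in progress.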
PR type
PR information
Problem
During checkpoint resume, SkipBatchSampler (from accelerate) wraps BatchSamplerShard but has no set_epoch method, causing DataLoaderShard.set_epoch() to silently skip the epoch update. This results in an incorrect shuffle order (the seed stays at base_seed instead of base_seed + epoch), leading to a loss curve mismatch in the resumed epoch.
Root Cause
shard.py: DataLoaderShard.set_epoch() checks hasattr(self.batch_sampler, 'set_epoch'), but SkipBatchSampler has no set_epoch method, so the epoch update is silently skipped. The inner BatchSamplerShard never receives the epoch.
mixin.py: In get_train_dataloader(), SkipBatchSampler wraps BatchSamplerShard without calling set_epoch first. In transformers 4.x, set_epoch is called before skip_first_batches, meaning the new DataLoader created by skip has no chance to receive set_epoch from the training loop.
Fix
swift/dataloader/shard.py: Add an elif branch in DataLoaderShard.set_epoch() to penetrate through SkipBatchSampler to the inner BatchSamplerShard.set_epoch().
swift/trainers/mixin.py: Call batch_sampler.set_epoch(epoch) before wrapping with SkipBatchSampler in get_train_dataloader().
Experiment results
Test Setup
Seed Verification
Added logging in BatchSamplerShard.__iter__ to print curr_seed:
Before fix:
Original training:
epoch 0: seed=42
epoch 1: seed=43
epoch 2: seed=44
Resume training (from step 7000, mid-epoch 1):
epoch 1: seed=42 ← Wrong! Should be 43
epoch 2: seed=44
After fix:
Resume training (from step 7000, mid-epoch 1):
epoch 1: seed=43
epoch 2: seed=44
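The seed behavior in these logs can be reproduced with a small toy model. This is only a sketch under stated assumptions: `ToyShardSampler` and `ToySkipWrapper` are hypothetical stand-ins, not the real BatchSamplerShard or SkipBatchSampler, and the two helper functions merely mimic a top-level-only check versus walking through `.batch_sampler` links. The base seed of 42 is taken from the logs above.

```python
import random


class ToyShardSampler:
    """Stand-in for an epoch-aware batch sampler (not the real BatchSamplerShard)."""

    def __init__(self, num_samples, batch_size, base_seed=42):
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.base_seed = base_seed
        self.curr_seed = base_seed

    def set_epoch(self, epoch):
        # Shuffle seed is base_seed + epoch, as described in the Problem section.
        self.curr_seed = self.base_seed + epoch

    def __iter__(self):
        indices = list(range(self.num_samples))
        random.Random(self.curr_seed).shuffle(indices)
        for start in range(0, self.num_samples, self.batch_size):
            yield indices[start:start + self.batch_size]


class ToySkipWrapper:
    """Stand-in for a skip wrapper: exposes .batch_sampler but defines no set_epoch."""

    def __init__(self, batch_sampler, skip_batches):
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        for i, batch in enumerate(self.batch_sampler):
            if i >= self.skip_batches:
                yield batch


def set_epoch_top_level_only(sampler, epoch):
    """Mimics the pre-fix check: only the outermost object is inspected."""
    if hasattr(sampler, 'set_epoch'):
        sampler.set_epoch(epoch)


def set_epoch_through_wrappers(sampler, epoch):
    """Follow .batch_sampler links until an object accepts set_epoch."""
    while sampler is not None:
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)
            return True
        sampler = getattr(sampler, 'batch_sampler', None)
    return False


# Resume mid-epoch 1 with one batch already consumed.
before = ToyShardSampler(num_samples=8, batch_size=2)
set_epoch_top_level_only(ToySkipWrapper(before, skip_batches=1), epoch=1)

after = ToyShardSampler(num_samples=8, batch_size=2)
wrapped = ToySkipWrapper(after, skip_batches=1)
set_epoch_through_wrappers(wrapped, epoch=1)

# Uninterrupted run of epoch 1, for comparison with the resumed batches.
reference = ToyShardSampler(num_samples=8, batch_size=2)
reference.set_epoch(1)

print(before.curr_seed, after.curr_seed)  # 42 43
```

In this toy setup, `list(wrapped)` equals `list(reference)[1:]`, i.e. the resumed run sees the same remaining batches as the uninterrupted epoch 1 once the epoch reaches the inner sampler.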
Loss Curve Comparison
Before fix (resume epoch 1 vs original epoch 1):
After fix (resume epoch 1 vs original epoch 1):