Skip to content

fix: DataLoaderShard.set_epoch cannot penetrate SkipBatchSampler during checkpoint resume#9658

Open
ppp2026 wants to merge 2 commits into
modelscope:mainfrom
ppp2026:fix/dataloader-set-epoch-penetration
Open

fix: DataLoaderShard.set_epoch cannot penetrate SkipBatchSampler during checkpoint resume#9658
ppp2026 wants to merge 2 commits into
modelscope:mainfrom
ppp2026:fix/dataloader-set-epoch-penetration

Conversation

@ppp2026

@ppp2026 ppp2026 commented Jun 29, 2026

Copy link
Copy Markdown

…ng checkpoint resume

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

PR information

Problem

During checkpoint resume, SkipBatchSampler (from accelerate) wraps BatchSamplerShard but has no set_epoch method, causing DataLoaderShard.set_epoch() to silently skip the epoch update. This results in incorrect shuffle order (seed stays at base_seed instead of base_seed + epoch), leading to loss curve mismatch in the resumed epoch.

Root Cause

  1. shard.py: DataLoaderShard.set_epoch() checks hasattr(self.batch_sampler, 'set_epoch'), but SkipBatchSampler has no set_epoch method, so the epoch update is silently skipped. The inner BatchSamplerShard never receives the epoch.
  2. mixin.py: In get_train_dataloader(), SkipBatchSampler wraps BatchSamplerShard without calling set_epoch first. In transformers 4.x, set_epoch is called before skip_first_batches, meaning the new DataLoader created by skip has no chance to receive set_epoch from the training loop.

Fix

  1. swift/dataloader/shard.py: Add elif branch in DataLoaderShard.set_epoch() to penetrate through SkipBatchSampler to inner BatchSamplerShard.set_epoch().
  2. swift/trainers/mixin.py: Call batch_sampler.set_epoch(epoch) before wrapping with SkipBatchSampler in get_train_dataloader().

Experiment results

Test Setup

  • Model: Qwen3-VL-2B-Instruct
  • Fine-tuning: LoRA (rank=16)
  • Hardware: 2x Ascend 910B3
  • Dataset: alpaca-gpt4-data-zh + en
  • Checkpoint: resume from step 7000 (mid-epoch)

Seed Verification

Added logging in BatchSamplerShard.__iter__ to print curr_seed:
Before fix:
Original training:
epoch 0: seed=42
epoch 1: seed=43
epoch 2: seed=44
Resume training (from step 7000, mid-epoch 1):
epoch 1: seed=42 ← Wrong! Should be 43
epoch 2: seed=44
After fix:
Resume training (from step 7000, mid-epoch 1):
epoch 1: seed=43
epoch 2: seed=44

Loss Curve Comparison

Before fix (resume epoch 1 vs original epoch 1):

  • Loss average diff: 0.055
  • Loss max diff: 0.168
  • Loss correlation: -0.11 (almost uncorrelated)
    After fix (resume epoch 1 vs original epoch 1):
  • Loss average diff: 0.005
  • Loss max diff: 0.019
  • Loss correlation: 0.96+ (highly correlated)

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the dataloader and trainer mixin to properly set the epoch on wrapped batch samplers, specifically when skipping batches. The review feedback suggests a more robust, loop-based traversal of nested batch samplers to support arbitrary levels of wrapping, and recommends using getattr with a fallback for safer attribute access when retrieving the current epoch.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread swift/dataloader/shard.py Outdated
Comment on lines 87 to 93
if self.batch_sampler is not None and hasattr(self.batch_sampler, 'set_epoch'):
self.batch_sampler.set_epoch(epoch)
elif self.batch_sampler is not None and hasattr(self.batch_sampler, 'batch_sampler') and hasattr(
self.batch_sampler.batch_sampler, 'set_epoch'):
self.batch_sampler.batch_sampler.set_epoch(epoch)
elif self.sampler is not None and hasattr(self.sampler, 'set_epoch'):
self.sampler.set_epoch(epoch)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of hardcoding a single level of nested batch_sampler check, we can recursively traverse the wrapped samplers using a loop. This is more robust and supports arbitrary levels of wrapping (e.g., if multiple wrappers are applied to the batch sampler).

        set_epoch_called = False
        sampler = self.batch_sampler
        while sampler is not None:
            if hasattr(sampler, 'set_epoch'):
                sampler.set_epoch(epoch)
                set_epoch_called = True
                break
            sampler = getattr(sampler, 'batch_sampler', None)

        if not set_epoch_called and self.sampler is not None and hasattr(self.sampler, 'set_epoch'):
            self.sampler.set_epoch(epoch)

Comment thread swift/trainers/mixin.py Outdated
Comment on lines +1311 to +1312
epoch = self.state.epoch if hasattr(self, 'state') and self.state.epoch is not None else 0
batch_sampler.set_epoch(int(epoch))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using getattr with a fallback is safer and more robust than direct attribute access on self.state to avoid potential AttributeError or TypeError if self.state is not fully initialized or if epoch is None.

Suggested change
epoch = self.state.epoch if hasattr(self, 'state') and self.state.epoch is not None else 0
batch_sampler.set_epoch(int(epoch))
epoch = getattr(self.state, 'epoch', 0) or 0 if hasattr(self, 'state') else 0
batch_sampler.set_epoch(int(epoch))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant