fix: warm latent cache loading does an unnecessary model swap #840

Open
Mandrakia wants to merge 1 commit into ostris:main from Mandrakia:dataloader-warmstart-fix

@Mandrakia

Problem

LatentCachingMixin unconditionally wraps the caching loop with

set_device_state_preset('cache_latents')
… 
restore_device_state()

On a warm start where every latent is already cached, the loop body is just os.path.exists checks — nothing is encoded — yet the device-state preset still moves models (notably the 9B transformer) off the GPU and back on, once per dataset × resolution pass. That move, not the encoding, is what makes warm starts take 10s+.

Fix

Change device state only the first time we actually hit a cache miss (gated behind a did_move flag), and restore only if we moved. This mirrors the pattern cache_text_embeddings already uses for cache_text_encoder.

Cold starts are unchanged; warm starts skip the unnecessary GPU shuffle.
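
A minimal sketch of the gating pattern described above, not the actual patch. `DeviceStateStub`, `cache_latents`, and `encode_and_save` are hypothetical names introduced for illustration; only the method names `set_device_state_preset` and `restore_device_state` come from the description, and their signatures here are assumed.

```python
import os


class DeviceStateStub:
    """Hypothetical stand-in for the object that owns the models."""

    def __init__(self):
        self.moves = 0  # counts simulated device-state changes

    def set_device_state_preset(self, preset):
        # Assumed signature: the real method would move models between devices.
        self.moves += 1

    def restore_device_state(self):
        self.moves += 1


def cache_latents(sd, latent_paths, encode_and_save):
    """Cache missing latents, changing device state only on the first miss."""
    did_move = False
    for path in latent_paths:
        if os.path.exists(path):
            continue  # already cached: nothing to encode, no models needed
        if not did_move:
            # First cache miss: only now pay for the model shuffle.
            sd.set_device_state_preset('cache_latents')
            did_move = True
        encode_and_save(path)
    if did_move:
        # Restore only if the state was actually changed.
        sd.restore_device_state()
    return did_move
```

With every path already present, the loop never calls `set_device_state_preset`, so a warm start performs no device moves. A single miss triggers exactly one move and one restore, however many latents follow.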
