Skip to content

Flux2: support local safetensors paths and stream-load to fix RAM blowup#842

Open
deepglugs wants to merge 1 commit into
ostris:mainfrom
deepglugs:flux2-fp8-streaming-load
Open

Flux2: support local safetensors paths and stream-load to fix RAM blowup#842
deepglugs wants to merge 1 commit into
ostris:mainfrom
deepglugs:flux2-fp8-streaming-load

Conversation

@deepglugs

Copy link
Copy Markdown

Loading FLUX.2-dev from a custom safetensors file (e.g. a pre-quantized fp8 checkpoint from SwarmUI/ComfyUI) previously failed because the loader assumed a HuggingFace repo id, and even when it worked the full bf16 state_dict was held in RAM alongside the module through the entire quantize pass (assign=True kept the dict and module sharing tensor references), pinning ~24-30GB of bf16 weights while quanto built fp8 QTensors on top — pushing peak host RAM well past 128GB.

Changes in extensions_built_in/diffusion_models/flux2/flux2_model.py:

  • name_or_path may now point directly at a .safetensors file. The existing repo-id and directory-of-files paths are unchanged.
  • FP8 detection: probe weight dtypes via safe_open and, if pre-quantized fp8 weights are present, auto-enable quantize=True so quanto re-packs into QTensor form and keeps the fp8 memory footprint at runtime.
  • Stream-load the transformer one tensor at a time directly into the meta-device module via safe_open + setattr. The full state_dict never exists, so quanto's per-block freeze actually releases the bf16 weights as it goes. Peak RAM during load drops from >128GB to ~55GB on FLUX.2-dev.
  • Added flux2_hf_repo_id class attr (default "black-forest-labs/FLUX.2-dev") used as the VAE fallback when name_or_path is a local file so we don't hand a filesystem path to hf_hub_download as a repo id.

Verified end-to-end against a local fp8 FLUX.2-dev checkpoint: detection fires, transformer + Mistral both load and quantize, VAE downloads from the fallback repo, pipe assembles, sample images generate, training proceeds.

Loading FLUX.2-dev from a custom safetensors file (e.g. a pre-quantized
fp8 checkpoint from SwarmUI/ComfyUI) previously failed because the loader
assumed a HuggingFace repo id, and even when it worked the full bf16
state_dict was held in RAM alongside the module through the entire
quantize pass (assign=True kept the dict and module sharing tensor
references), pinning ~24-30GB of bf16 weights while quanto built fp8
QTensors on top — pushing peak host RAM well past 128GB.

Changes in extensions_built_in/diffusion_models/flux2/flux2_model.py:

- name_or_path may now point directly at a .safetensors file. The
  existing repo-id and directory-of-files paths are unchanged.
- FP8 detection: probe weight dtypes via safe_open and, if pre-quantized
  fp8 weights are present, auto-enable quantize=True so quanto re-packs
  into QTensor form and keeps the fp8 memory footprint at runtime.
- Stream-load the transformer one tensor at a time directly into the
  meta-device module via safe_open + setattr. The full state_dict
  never exists, so quanto's per-block freeze actually releases the bf16
  weights as it goes. Peak RAM during load drops from >128GB to ~55GB
  on FLUX.2-dev.
- Added flux2_hf_repo_id class attr (default "black-forest-labs/FLUX.2-dev")
  used as the VAE fallback when name_or_path is a local file so we don't
  hand a filesystem path to hf_hub_download as a repo id.

Verified end-to-end against a local fp8 FLUX.2-dev checkpoint: detection
fires, transformer + Mistral both load and quantize, VAE downloads from
the fallback repo, pipe assembles, sample images generate, training
proceeds.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant