Flux2: support local safetensors paths and stream-load to fix RAM blowup #842
Open
deepglugs wants to merge 1 commit into
Loading FLUX.2-dev from a custom safetensors file (e.g. a pre-quantized fp8 checkpoint from SwarmUI/ComfyUI) previously failed because the loader assumed a HuggingFace repo id. Even when it worked, the full bf16 state_dict was held in RAM alongside the module through the entire quantize pass (assign=True kept the dict and module sharing tensor references). That pinned ~24-30GB of bf16 weights while quanto built fp8 QTensors on top, pushing peak host RAM well past 128GB.

Changes in extensions_built_in/diffusion_models/flux2/flux2_model.py:

- name_or_path may now point directly at a .safetensors file. The existing repo-id and directory-of-files paths are unchanged.
- FP8 detection: probe weight dtypes via safe_open and, if pre-quantized fp8 weights are present, auto-enable quantize=True so quanto re-packs into QTensor form and keeps the fp8 memory footprint at runtime.
- Stream-load the transformer one tensor at a time directly into the meta-device module via safe_open + setattr. The full state_dict never exists, so quanto's per-block freeze actually releases the bf16 weights as it goes. Peak RAM during load drops from >128GB to ~55GB on FLUX.2-dev.
- Added flux2_hf_repo_id class attr (default "black-forest-labs/FLUX.2-dev") used as the VAE fallback when name_or_path is a local file so we don't hand a filesystem path to hf_hub_download as a repo id.

Verified end-to-end against a local fp8 FLUX.2-dev checkpoint: detection fires, transformer + Mistral both load and quantize, VAE downloads from the fallback repo, pipe assembles, sample images generate, training proceeds.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
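For reviewers who want to see the path handling and fp8 detection in isolation, here is a minimal sketch. It is not the code in this PR: the PR probes dtypes via safe_open, whereas this sketch reads the safetensors JSON header directly with the standard library so that it runs without safetensors or torch installed. The helper names (read_safetensors_header, has_fp8_weights, plan_load), the returned dict layout, and the assumption that the directory case takes the VAE from name_or_path are all hypothetical and chosen only for illustration.

```python
import json
import os
import struct

# dtype tags that the safetensors header uses for fp8 tensors
FP8_DTYPES = {"F8_E4M3", "F8_E5M2"}
# default fallback repo named in the PR description
DEFAULT_REPO_ID = "black-forest-labs/FLUX.2-dev"


def read_safetensors_header(path):
    """Return the per-tensor header of a .safetensors file without reading tensor data.

    Layout: 8-byte little-endian header length, then that many bytes of UTF-8 JSON.
    """
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor
    return header


def has_fp8_weights(path):
    """True if any tensor in the file is stored in an fp8 dtype."""
    header = read_safetensors_header(path)
    return any(entry["dtype"] in FP8_DTYPES for entry in header.values())


def plan_load(name_or_path, quantize=False, hf_repo_id=DEFAULT_REPO_ID):
    """Hypothetical helper: classify name_or_path and decide quantize + VAE source."""
    if name_or_path.endswith(".safetensors") and os.path.isfile(name_or_path):
        return {
            "source": "file",
            # pre-quantized fp8 weights force quantize on, as the PR describes
            "quantize": quantize or has_fp8_weights(name_or_path),
            # never pass a filesystem path where a repo id is expected
            "vae_source": hf_repo_id,
        }
    # Assumption: the unchanged directory and repo-id cases keep the caller's
    # quantize setting and take the VAE from name_or_path itself.
    source = "directory" if os.path.isdir(name_or_path) else "repo_id"
    return {"source": source, "quantize": quantize, "vae_source": name_or_path}
```

Calling plan_load on a local fp8 checkpoint returns quantize=True and the fallback repo id as the VAE source, while a repo id string passes through with the caller's quantize setting. Reading only the header keeps the probe cheap regardless of checkpoint size, which is the same property the safe_open probe in the PR relies on.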