Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions invokeai/backend/model_manager/configs/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
Main_Checkpoint_Anima_Config,
Main_Checkpoint_Flux2_Config,
Main_Checkpoint_FLUX_Config,
Main_Checkpoint_QwenImage_Config,
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Expand Down Expand Up @@ -183,6 +184,7 @@
Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()],
Annotated[Main_Checkpoint_Flux2_Config, Main_Checkpoint_Flux2_Config.get_tag()],
Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()],
Annotated[Main_Checkpoint_QwenImage_Config, Main_Checkpoint_QwenImage_Config.get_tag()],
Annotated[Main_Checkpoint_ZImage_Config, Main_Checkpoint_ZImage_Config.get_tag()],
Annotated[Main_Checkpoint_Anima_Config, Main_Checkpoint_Anima_Config.get_tag()],
# Main (Pipeline) - quantized formats
Expand Down
62 changes: 47 additions & 15 deletions invokeai/backend/model_manager/configs/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,52 @@ def _has_qwen_image_keys(state_dict: dict[str | int, Any]) -> bool:
return has_txt_in and has_txt_norm and has_img_in and not has_context_embedder


def _infer_qwen_image_variant(sd: dict[str | int, Any], path) -> QwenImageVariantType:
    """Decide whether a Qwen Image checkpoint is the Edit or Generate variant.

    Two signals are checked, in order:

    1. The state dict contains an `__index_timestep_zero__` entry, the marker
       tensor used by the `zero_cond_t` dual-modulation path of edit models.
    2. The lowercased file stem of `path` contains the substring "edit". This
       covers converters that do not emit the marker tensor.

    Args:
        sd: The checkpoint state dict; only key membership is inspected.
        path: Path-like object exposing a `.stem` attribute (file name without
            suffix). It is only consulted when the marker is absent.

    Returns:
        `QwenImageVariantType.Edit` if either signal matches, otherwise
        `QwenImageVariantType.Generate`.
    """
    has_edit_marker = "__index_timestep_zero__" in sd
    # Short-circuit: the filename is only examined when the marker is missing.
    if has_edit_marker or "edit" in path.stem.lower():
        return QwenImageVariantType.Edit
    return QwenImageVariantType.Generate


class Main_Checkpoint_QwenImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
    """Config for Qwen Image main models stored as a single-file checkpoint.

    Intended for both plain bf16/fp16 checkpoints and ComfyUI-style fp8_scaled
    checkpoints. The corresponding loader is expected to dequantize fp8 weights
    back to bf16 when loading; the `default_settings.fp8_storage` toggle can then
    optionally re-cast to fp8 for VRAM savings.
    """

    base: Literal[BaseModelType.QwenImage] = Field(default=BaseModelType.QwenImage)
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
    variant: QwenImageVariantType | None = Field(default=None)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config for `mod` if it looks like a non-GGUF Qwen Image checkpoint.

        Args:
            mod: The on-disk model being probed; must be a single file.
            override_fields: Caller-supplied field overrides. A "variant" entry,
                if present, is removed from this dict and takes precedence over
                inference when it is truthy.

        Returns:
            A config instance built from `override_fields` plus the resolved variant.

        Raises:
            NotAMatchError: If the state dict lacks the Qwen Image keys, or if it
                contains GGML (GGUF-quantized) tensors.
        """
        raise_if_not_file(mod)
        raise_for_override_fields(cls, override_fields)

        state_dict = mod.load_state_dict()

        if not _has_qwen_image_keys(state_dict):
            raise NotAMatchError("state dict does not look like a Qwen Image model")

        # GGUF-quantized files are matched by the dedicated GGUF config class instead.
        if _has_ggml_tensors(state_dict):
            raise NotAMatchError("state dict looks like GGUF quantized")

        # An explicit (truthy) override wins; otherwise infer from marker/filename.
        variant = override_fields.pop("variant", None)
        if not variant:
            variant = _infer_qwen_image_variant(state_dict, mod.path)

        return cls(**override_fields, variant=variant)


class Main_GGUF_QwenImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for GGUF-quantized Qwen Image transformer models."""

Expand All @@ -1364,21 +1410,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -
if not _has_ggml_tensors(sd):
raise NotAMatchError("state dict does not look like GGUF quantized")

# Infer variant from the state dict if not explicitly provided.
# The Edit variant includes an extra tensor `__index_timestep_zero__` (used by the
# `zero_cond_t` dual-modulation path in diffusers' QwenImageTransformer2DModel).
# If the marker tensor is missing, fall back to the filename heuristic since older
# or alternate GGUF converters may not emit it.
explicit_variant = override_fields.pop("variant", None)
if explicit_variant is None:
if "__index_timestep_zero__" in sd:
explicit_variant = QwenImageVariantType.Edit
else:
filename = mod.path.stem.lower()
if "edit" in filename:
explicit_variant = QwenImageVariantType.Edit
else:
explicit_variant = QwenImageVariantType.Generate
explicit_variant = override_fields.pop("variant", None) or _infer_qwen_image_variant(sd, mod.path)

return cls(**override_fields, variant=explicit_variant)

Expand Down
258 changes: 195 additions & 63 deletions invokeai/backend/model_manager/load/model_loaders/qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.configs.main import Main_GGUF_QwenImage_Config
from invokeai.backend.model_manager.configs.main import (
Main_Checkpoint_QwenImage_Config,
Main_GGUF_QwenImage_Config,
)
from invokeai.backend.model_manager.configs.qwen_vl_encoder import (
QwenVLEncoder_Checkpoint_Config,
QwenVLEncoder_Diffusers_Config,
Expand All @@ -27,6 +30,132 @@
from invokeai.backend.util.devices import TorchDevice


def _strip_comfyui_prefix(sd: dict) -> dict:
    """Remove a ComfyUI-style key prefix from a state dict.

    The prefixes `model.diffusion_model.` and `diffusion_model.` are tried in that
    order; the first one that at least one string key starts with is stripped
    from every string key carrying it. Keys without that prefix (including
    non-string keys) are kept unchanged.

    Args:
        sd: State dict to inspect. It is never mutated.

    Returns:
        `sd` itself when no key carries either prefix, otherwise a new dict with
        the detected prefix removed.
    """
    candidates = ("model.diffusion_model.", "diffusion_model.")
    string_keys = [key for key in sd if isinstance(key, str)]

    # First candidate (in priority order) that occurs on any string key.
    detected = next(
        (prefix for prefix in candidates if any(key.startswith(prefix) for key in string_keys)),
        None,
    )
    if detected is None:
        return sd

    cut = len(detected)
    return {
        (key[cut:] if isinstance(key, str) and key.startswith(detected) else key): value
        for key, value in sd.items()
    }


def _dequantize_comfyui_fp8(sd: dict) -> int:
    """Dequantize ComfyUI-style fp8_scaled weights in-place.

    Two key naming schemes are in the wild:
      - `<path>.weight` + `<path>.weight_scale` (FLUX, Z-Image style)
      - `<path>.weight` + `<path>.scale_weight` (Qwen2.5-VL fp8_scaled style, also
        emits `<path>.scale_input` for activation scaling that we discard).

    For every scale key whose matching `<path>.weight` exists, the weight is
    replaced by `weight.float() * scale.float()` (float32). Scale keys themselves
    are left in the dict; scale keys without a matching weight are ignored.

    Scale layouts handled:
      - scalar / single-element scale: plain broadcast.
      - block-wise scale: each scale dim that is smaller than the matching weight
        dim is expanded with `repeat_interleave` by the integer block size.
      - scale with fewer dims than the weight (e.g. per-output-channel `(out,)`
        against `(out, in)`): treated as aligned with the weight's *leading*
        dims, consistent with the block expansion, and padded with trailing
        singleton dims so it broadcasts over the remaining axes.

    Args:
        sd: State dict mapping keys to tensors; modified in place.

    Returns:
        Number of weight tensors that were dequantized.
    """
    scale_suffixes = (".weight_scale", ".scale_weight")
    # Snapshot the scale keys up front; values in `sd` are reassigned below.
    scale_keys = [k for k in sd if isinstance(k, str) and k.endswith(scale_suffixes)]

    count = 0
    for scale_key in scale_keys:
        suffix = next(s for s in scale_suffixes if scale_key.endswith(s))
        weight_key = scale_key[: -len(suffix)] + ".weight"
        if weight_key not in sd:
            continue

        weight_float = sd[weight_key].float()
        scale_float = sd[scale_key].float()

        if scale_float.shape != weight_float.shape and scale_float.numel() > 1:
            # Block-wise scales: expand each shared leading dim up to the weight's size.
            for dim in range(min(scale_float.dim(), weight_float.dim())):
                if scale_float.shape[dim] != weight_float.shape[dim]:
                    block_size = weight_float.shape[dim] // scale_float.shape[dim]
                    if block_size > 1:
                        scale_float = scale_float.repeat_interleave(block_size, dim=dim)

            # Default broadcasting aligns *trailing* dims, which would apply a
            # lower-rank scale to the wrong axis (or fail outright). Append
            # singleton dims so the scale stays aligned with the leading dims.
            missing_dims = weight_float.dim() - scale_float.dim()
            if missing_dims > 0:
                scale_float = scale_float.reshape(*scale_float.shape, *([1] * missing_dims))

        sd[weight_key] = weight_float * scale_float
        count += 1
    return count


def _strip_quantization_metadata(sd: dict) -> None:
    """Delete ComfyUI fp8 quantization bookkeeping entries from `sd` in place.

    A string key is removed when it ends with `.weight_scale`, `.scale_weight`
    or `.scale_input`, contains `comfy_quant`, or is exactly `scaled_fp8`.
    Non-string keys are never touched.

    Args:
        sd: State dict to clean; mutated in place.
    """
    metadata_suffixes = (".weight_scale", ".scale_weight", ".scale_input")

    def _is_metadata(key) -> bool:
        if not isinstance(key, str):
            return False
        return key.endswith(metadata_suffixes) or "comfy_quant" in key or key == "scaled_fp8"

    # Collect first, then delete, so the dict is not resized while being iterated.
    for key in [k for k in sd if _is_metadata(k)]:
        del sd[key]


def _build_qwen_image_transformer_config(sd: dict, is_edit: bool) -> dict:
    """Derive `QwenImageTransformer2DModel` constructor kwargs from a state dict.

    Accepts both GGUF state dicts (values are `GGMLTensor`, whose logical shape
    is read from `tensor_shape`) and plain tensor state dicts (shape read from
    `.shape`). The state dict is only read, never modified.

    Detection rules:
      - `num_layers`: highest `transformer_blocks.<idx>.` index + 1, or 60 when
        no such key is found.
      - `in_channels` / `num_attention_heads`: from `img_in.weight`
        (`shape[1]` and `shape[0] // 128`), defaulting to 64 and 24.
      - `joint_attention_dim`: `txt_in.weight` `shape[1]`, defaulting to 3584.

    Args:
        sd: Transformer state dict with prefixes already stripped.
        is_edit: Whether the model is an edit variant.

    Returns:
        A dict of keyword arguments for `QwenImageTransformer2DModel`.
    """
    import inspect

    from diffusers import QwenImageTransformer2DModel

    def _dims(tensor):
        if isinstance(tensor, GGMLTensor):
            return tensor.tensor_shape
        return tensor.shape

    # Each `transformer_blocks.<idx>.*` key contributes idx + 1 as a layer count.
    layer_counts = [0]
    for key in sd:
        if not (isinstance(key, str) and key.startswith("transformer_blocks.")):
            continue
        segments = key.split(".")
        if len(segments) < 2:
            continue
        try:
            layer_counts.append(int(segments[1]) + 1)
        except ValueError:
            # Non-numeric block segment: not a layer index, ignore it.
            continue
    detected_layers = max(layer_counts)

    head_dim = 128
    head_count = 24
    input_channels = 64
    if "img_in.weight" in sd:
        img_in_shape = _dims(sd["img_in.weight"])
        input_channels = img_in_shape[1]
        head_count = img_in_shape[0] // head_dim

    text_dim = _dims(sd["txt_in.weight"])[1] if "txt_in.weight" in sd else 3584

    model_config: dict = {
        "patch_size": 2,
        "in_channels": input_channels,
        "out_channels": 16,
        "num_layers": detected_layers if detected_layers > 0 else 60,
        "attention_head_dim": head_dim,
        "num_attention_heads": head_count,
        "joint_attention_dim": text_dim,
        "guidance_embeds": False,
        "axes_dims_rope": (16, 56, 56),
    }

    # zero_cond_t turns on dual modulation (noisy vs reference patches) for edit
    # models; enabling it on txt2img models produces garbage. Older diffusers
    # releases (before 0.37) lack the parameter, so probe the signature first.
    if is_edit:
        init_params = inspect.signature(QwenImageTransformer2DModel.__init__).parameters
        if "zero_cond_t" in init_params:
            model_config["zero_cond_t"] = True

    return model_config


@ModelLoaderRegistry.register(base=BaseModelType.QwenImage, type=ModelType.Main, format=ModelFormat.Diffusers)
class QwenImageDiffusersModel(GenericDiffusersLoader):
"""Class to load Qwen Image Edit main models."""
Expand Down Expand Up @@ -73,6 +202,7 @@ def _load_model(
else:
raise e

result = self._apply_fp8_layerwise_casting(result, config, submodel_type)
return result


Expand Down Expand Up @@ -107,76 +237,78 @@ def _load_from_singlefile(self, config: AnyModelConfig) -> AnyModel:
compute_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)

sd = gguf_sd_loader(model_path, compute_dtype=compute_dtype)
sd = _strip_comfyui_prefix(sd)

# Strip ComfyUI-style prefixes if present
prefix_to_strip = None
for prefix in ["model.diffusion_model.", "diffusion_model."]:
if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
prefix_to_strip = prefix
break
is_edit = getattr(config, "variant", None) == QwenImageVariantType.Edit
model_config = _build_qwen_image_transformer_config(sd, is_edit=is_edit)

if prefix_to_strip:
stripped_sd = {}
for key, value in sd.items():
if isinstance(key, str) and key.startswith(prefix_to_strip):
stripped_sd[key[len(prefix_to_strip) :]] = value
else:
stripped_sd[key] = value
sd = stripped_sd

# Auto-detect architecture from state dict
num_layers = 0
for key in sd.keys():
if isinstance(key, str) and key.startswith("transformer_blocks."):
parts = key.split(".")
if len(parts) >= 2:
try:
layer_idx = int(parts[1])
num_layers = max(num_layers, layer_idx + 1)
except ValueError:
pass

# Detect dimensions from weights
num_attention_heads = 24 # default
attention_head_dim = 128 # default

if "img_in.weight" in sd:
w = sd["img_in.weight"]
shape = w.tensor_shape if isinstance(w, GGMLTensor) else w.shape
hidden_dim = shape[0]
in_channels = shape[1]
num_attention_heads = hidden_dim // attention_head_dim

joint_attention_dim = 3584 # default
if "txt_in.weight" in sd:
w = sd["txt_in.weight"]
shape = w.tensor_shape if isinstance(w, GGMLTensor) else w.shape
joint_attention_dim = shape[1]

model_config: dict = {
"patch_size": 2,
"in_channels": in_channels if "img_in.weight" in sd else 64,
"out_channels": 16,
"num_layers": num_layers if num_layers > 0 else 60,
"attention_head_dim": attention_head_dim,
"num_attention_heads": num_attention_heads,
"joint_attention_dim": joint_attention_dim,
"guidance_embeds": False,
"axes_dims_rope": (16, 56, 56),
}

# zero_cond_t is only used by edit-variant models. It enables dual modulation
# for noisy vs reference patches. Setting it on txt2img models produces garbage.
# Also requires diffusers 0.37+ (the parameter doesn't exist in older versions).
import inspect
with accelerate.init_empty_weights():
model = QwenImageTransformer2DModel(**model_config)

model.load_state_dict(sd, strict=False, assign=True)
return model


@ModelLoaderRegistry.register(base=BaseModelType.QwenImage, type=ModelType.Main, format=ModelFormat.Checkpoint)
class QwenImageCheckpointModel(ModelLoader):
"""Loads Qwen Image transformer models from single-file safetensors checkpoints
(e.g. ComfyUI fp8_scaled, plain bf16/fp16). Dequantizes ComfyUI fp8 scaling to
bf16 at load time; the `default_settings.fp8_storage` toggle then optionally
re-casts to fp8 for VRAM savings."""

def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not isinstance(config, Checkpoint_Config_Base):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")

match submodel_type:
case SubModelType.Transformer:
model = self._load_from_singlefile(config)
return self._apply_fp8_layerwise_casting(model, config, submodel_type)

raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
)

def _load_from_singlefile(self, config: AnyModelConfig) -> AnyModel:
from diffusers import QwenImageTransformer2DModel
from safetensors.torch import load_file

from invokeai.backend.util.logging import InvokeAILogger

logger = InvokeAILogger.get_logger(self.__class__.__name__)

if not isinstance(config, Main_Checkpoint_QwenImage_Config):
raise TypeError(f"Expected Main_Checkpoint_QwenImage_Config, got {type(config).__name__}.")
model_path = Path(config.path)

target_device = TorchDevice.choose_torch_device()
model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)

sd = load_file(str(model_path))
sd = _strip_comfyui_prefix(sd)

dequantized = _dequantize_comfyui_fp8(sd)
if dequantized > 0:
logger.info(f"Dequantized {dequantized} ComfyUI-quantized weights")
_strip_quantization_metadata(sd)

is_edit = getattr(config, "variant", None) == QwenImageVariantType.Edit
if is_edit and "zero_cond_t" in inspect.signature(QwenImageTransformer2DModel.__init__).parameters:
model_config["zero_cond_t"] = True
model_config = _build_qwen_image_transformer_config(sd, is_edit=is_edit)

with accelerate.init_empty_weights():
model = QwenImageTransformer2DModel(**model_config)

new_sd_size = sum(t.nelement() * model_dtype.itemsize for t in sd.values())
self._ram_cache.make_room(new_sd_size)

for k in list(sd.keys()):
if sd[k].is_floating_point():
sd[k] = sd[k].to(model_dtype)

model.load_state_dict(sd, strict=False, assign=True)
return model

Expand Down
Loading
Loading