Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion toolkit/stable_diffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,12 @@ def load_model(self):
self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
self.vae.eval()
self.vae.requires_grad_(False)
# Flux VAE loaded from ComfyUI single-file checkpoints may not carry these config values
if self.is_flux:
if getattr(self.vae.config, 'scaling_factor', None) is None:
self.vae.config.scaling_factor = 0.3611
if getattr(self.vae.config, 'shift_factor', None) is None:
self.vae.config.shift_factor = 0.1159
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
self.vae_scale_factor = VAE_SCALE_FACTOR
self.unet.to(self.device_torch, dtype=dtype)
Expand Down Expand Up @@ -2449,7 +2455,7 @@ def encode_prompt(
prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux(
self.tokenizer, # list
self.text_encoder, # list
prompt,
[prompt if prompt is not None else ''] if isinstance(prompt, str) or prompt is None else prompt,
truncate=not long_prompts,
max_length=512,
dropout_prob=dropout_prob,
Expand Down
7 changes: 7 additions & 0 deletions toolkit/train_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,13 @@ def encode_prompts_flux(
device = text_encoder[0].device
dtype = text_encoder[0].dtype

# Normalize prompts: ensure every element is a non-None string so the
# CLIP/T5 tokenizers never receive None or False as input
if isinstance(prompts, list):
prompts = [str(p) if p is not None and p is not False else '' for p in prompts]
else:
prompts = [str(prompts) if prompts is not None and prompts is not False else '']

batch_size = len(prompts)

# clip
Expand Down