From c8cd27770077df719adda3c3755f0998ddd01d84 Mon Sep 17 00:00:00 2001 From: Artem Gorodetskii Date: Tue, 2 Jun 2026 12:54:38 +0400 Subject: [PATCH] [TTS][Magpietts] Added local transformer CFG distillation --- .../tts/models/magpietts_cfg_distillation.py | 481 ++++++++++++++---- ...tts_FrameStacking_OnlineCFGDistillation.sh | 2 + 2 files changed, 388 insertions(+), 95 deletions(-) diff --git a/nemo/collections/tts/models/magpietts_cfg_distillation.py b/nemo/collections/tts/models/magpietts_cfg_distillation.py index 3e8261ab0375..063dd5115e43 100644 --- a/nemo/collections/tts/models/magpietts_cfg_distillation.py +++ b/nemo/collections/tts/models/magpietts_cfg_distillation.py @@ -31,7 +31,12 @@ NRMSELogitsLoss, ) from nemo.collections.tts.models.magpietts import ContextTensorsOutput, MagpieTTSModel -from nemo.collections.tts.modules.magpietts_modules import EOSDetectionMethod, remove_embedded_eos_token +from nemo.collections.tts.modules.magpietts_modules import ( + EOSDetectionMethod, + LocalTransformerType, + clear_forbidden_logits, + remove_embedded_eos_token, +) from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths from nemo.lightning.callback_group import CallbackGroup from nemo.utils import logging @@ -69,6 +74,14 @@ class _DefaultParams: truncation_weight: Optional[float] = 0.1 # Weight coefficient for the MoE loss component. moe_loss_weight: float = 1.0 + # Whether to enable distillation of the local transformer head in addition to the main decoder logits. + distill_local_transformer: bool = False + # Target mixing weight for the local-transformer distillation loss in the final total loss. + lt_loss_weight: float = 0.1 + # Global training step at which local-transformer distillation becomes active. + lt_distillation_start_step: int = 3000 + # Number of steps used to linearly ramp the local-transformer loss weight from 0 to `lt_loss_weight`. + lt_distillation_ramp_len: int = 2000 _DEFAULT_PARAMS = _DefaultParams() @@ -119,6 +132,15 @@ def _validate_configuration(cfg: DictConfig) -> None: "It defines the relative weighting for truncated samples in the loss." ) + if hasattr(cfg, "lt_loss_weight") and not (0.0 <= cfg.get("lt_loss_weight") <= 1.0): + raise ValueError("`lt_loss_weight` must be in the range [0, 1].") + + if hasattr(cfg, "lt_distillation_start_step") and cfg.get("lt_distillation_start_step") < 0: + raise ValueError("`lt_distillation_start_step` must be non-negative.") + + if hasattr(cfg, "lt_distillation_ramp_len") and cfg.get("lt_distillation_ramp_len") < 0: + raise ValueError("`lt_distillation_ramp_len` must be non-negative.") + def _get_teacher_model(cfg: DictConfig) -> MagpieTTSModel: model_path = Path(cfg.teacher_model_path) @@ -204,17 +226,29 @@ def _process_moe_routing_info( } +@dataclass +class _StudentOutput: + """Outputs produced by the student forward pass for distillation.""" + + logits: Tensor + moe_routing_data: Optional[dict[str, Tensor]] + logits_lt: Optional[Tensor] + + def _process_batch_student( model: MagpieTTSModel, batch: dict[str, Tensor | list], -) -> tuple[Tensor, Optional[dict[str, Tensor]]]: + use_lt: bool, +) -> _StudentOutput: """Perform a teacher-forced forward decoding pass for a MagpieTTS student model. This method runs a standard forward pass without classifier-free guidance (CFG). It prepares decoder inputs from the provided audio code sequence, removes the - terminal EOS token, and computes logits for all positions after the decoder context - prefix. If the model uses a Mixture-of-Experts (MoE) decoder, aggregated routing - information is also processed and returned. + terminal EOS token, and computes decoder logits for all positions after the + decoder context prefix. If the model uses a Mixture-of-Experts (MoE) decoder, + aggregated routing information is also processed and returned. When local-transformer + distillation is enabled, the method additionally computes local-transformer logits + from the detached decoder outputs. Args: model (MagpieTTSModel): Student model instance used for the forward pass. @@ -222,21 +256,33 @@ def _process_batch_student( audio code lengths, and contextual tensors required for decoding. The audio code sequence is expected to already include the special tokens required by the decoder input convention. + use_lt (bool): Whether to compute local-transformer logits for distillation. Returns: - tuple[Tensor, Optional[dict[str, Tensor]]]: - - **logits (Tensor)**: Logits tensor of shape `(B, T', D)`, where `B` is - batch size, `T'` is the frame-stacked decoder sequence length after + _StudentOutput: Container with student outputs used for distillation: + - **logits (Tensor)**: Decoder logits of shape `(B, T', D)`, where `B` + is batch size, `T'` is the frame-stacked decoder sequence length after removing the decoder context prefix, and `D` is the concatenated logit dimension across codebooks and frame-stacking positions. - - **moe_routing_data (Optional[dict[str, Tensor]])**: Aggregated Mixture-of-Experts - routing data, or `None` if MoE is disabled or routing information is unavailable. + - **moe_routing_data (Optional[dict[str, Tensor]])**: Aggregated + Mixture-of-Experts routing data, or `None` if MoE is disabled or + routing information is unavailable. + - **logits_lt (Optional[Tensor])**: Local-transformer logits used for + distillation, or `None` when local-transformer distillation is disabled + for the current step. """ + if use_lt and model.local_transformer_type != LocalTransformerType.AR: + raise ValueError( + f"Only `LocalTransformerType.AR` is supported for local-transformer distillation, " + f"but got `{model.local_transformer_type}`." + ) + context_tensors = model.prepare_context_tensors(batch) audio_codes = batch["audio_codes"] audio_codes_lens = batch["audio_codes_lens"] dec_context_size = context_tensors.dec_context_size moe_routing_data = None + logits_lt = None audio_codes_embedded_all, audio_codes_lens_all = model.embed_audio_tokens( audio_tokens=audio_codes, @@ -248,7 +294,7 @@ def _process_batch_student( ) audio_codes_mask = get_mask_from_lengths(audio_codes_lens_) inputs = _prepare_forward_inputs(model, context_tensors, audio_codes_embedded, audio_codes_mask) - logits, _, _, moe_routing_info = model.forward(**inputs) + logits, _, dec_out, moe_routing_info = model.forward(**inputs) logits = logits[:, dec_context_size:, :] if model.use_moe and moe_routing_info is not None: @@ -256,29 +302,129 @@ def _process_batch_student( moe_routing_info=moe_routing_info, dec_input_mask=inputs["dec_input_mask"], ) - return logits, moe_routing_data + + if use_lt: + logits_lt = model._lt_helper.compute_logits( + dec_out=dec_out[:, dec_context_size:, :].detach(), + audio_codes_target=batch["audio_codes_lt"], + targets_offset_by_one=False, + ) + + return _StudentOutput( + logits=logits, + moe_routing_data=moe_routing_data, + logits_lt=logits_lt, + ) + + +def _lt_sample_autoregressive( + model: MagpieTTSModel, + dec_output: Tensor, + temperature: float = 0.7, + topk: int = 80, + cfg_scale: float = 1.0, + use_kv_cache: bool = True, + forbid_audio_eos: bool = False, + sanitize_logits: bool = False, +) -> tuple[Tensor, Tensor]: + model.local_transformer.reset_cache(use_cache=use_kv_cache) + dec_output = dec_output.unsqueeze(1) + local_transformer_input = model.local_transformer_in_projection(dec_output) + predicted_codes = [] + predicted_logits = [] + + for codebook_num in range(model.num_audio_codebooks * model.frame_stacking_factor): + size = (local_transformer_input.size(0), local_transformer_input.size(1)) + _mask = torch.ones(*size, device=local_transformer_input.device) + local_transformer_output = model.local_transformer(local_transformer_input, _mask)["output"] + + lt_out_for_proj = model.local_transformer_audio_out_projection(local_transformer_output[:, -1, :]) + codebook_logits = model.local_transformer_out_projections[codebook_num](lt_out_for_proj) + + bs = codebook_logits.size(0) // 2 + conditional_logits = codebook_logits[:bs] + unconditional_logits = codebook_logits[bs:] + cfg_logits = cfg_scale * conditional_logits + (1.0 - cfg_scale) * unconditional_logits + codebook_logits[:bs] = cfg_logits + predicted_logits.append(codebook_logits.clone()) + + if sanitize_logits: + codebook_logits = torch.nan_to_num(codebook_logits, nan=0.0, posinf=100.0, neginf=-100.0) + codebook_logits = codebook_logits.clamp(min=-100.0, max=100.0) + + codebook_logits = clear_forbidden_logits( + logits=codebook_logits.unsqueeze(1), + codebook_size=model.codebook_size, + forbid_audio_eos=forbid_audio_eos, + ) + codebook_logits = codebook_logits.squeeze(1) + + codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] + indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(-1) + codebook_logits_rescored = codebook_logits.clone() + codebook_logits_rescored[indices_to_remove] = float("-inf") + + if temperature <= 0.0: + codebook_preds = codebook_logits_rescored.argmax(dim=-1, keepdim=True) + else: + codebook_probs = torch.softmax(codebook_logits_rescored / temperature, dim=-1) + codebook_preds = torch.multinomial(codebook_probs, 1) + + codebook_preds[bs:] = codebook_preds[:bs] + predicted_codes.append(codebook_preds) + + next_local_transformer_input = model.audio_embeddings[codebook_num](codebook_preds.squeeze(-1)).unsqueeze(1) + next_local_transformer_input = model.audio_in_projection(next_local_transformer_input) + next_local_transformer_input = model.local_transformer_in_projection(next_local_transformer_input) + local_transformer_input = torch.cat([local_transformer_input, next_local_transformer_input], dim=1) + + predicted_codes = torch.cat(predicted_codes, dim=1) + predicted_logits = torch.cat(predicted_logits, dim=1) + dims = (-1, model.frame_stacking_factor, model.num_audio_codebooks) + predicted_codes = predicted_codes.reshape(*dims).permute(0, 2, 1) + + predicted_codes = predicted_codes[:bs] + predicted_logits = predicted_logits[:bs] + + return predicted_codes, predicted_logits + + +@dataclass +class _TeacherOutput: + """Outputs produced by the teacher rollout used as distillation targets.""" + + codes: Tensor + logits: Tensor + lens: Tensor + sample_weights: Optional[Tensor] + + codes_lt: Optional[Tensor] + logits_lt: Optional[Tensor] def _infer_batch_teacher( model: MagpieTTSModel, batch: dict[str, Tensor | list], + use_lt: bool, max_decoder_steps: int = 500, temperature: float = 0.7, topk: int = 80, - cfg_scale: int = 2.5, + cfg_scale: float = 2.5, truncation_threshold: Optional[float] = None, truncation_weight: Optional[float] = None, use_kv_cache: bool = False, eos_detection_method: str = "argmax_or_multinomial_any", min_generated_frames: int = 4, -) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor]]: +) -> _TeacherOutput: """Perform autoregressive batch inference for a MagpieTTS teacher model. This method generates audio token rollouts conditioned on text and context inputs. It performs autoregressive decoding with classifier-free guidance (CFG) by combining conditional and unconditional logits using the specified guidance scale. The function also supports optional truncation-based early stopping and per-sample weighting for - downstream distillation losses. + downstream distillation losses. When local-transformer distillation is enabled, it + additionally generates local-transformer predictions and logits aligned with the + decoder rollout. Generation is performed in units of stacked frames. When frame stacking is enabled, each decoder step predicts a full stacked block of audio codes. If EOS is detected @@ -290,6 +436,8 @@ def _infer_batch_teacher( model (MagpieTTSModel): The teacher model instance used for autoregressive generation. batch (dict[str, Tensor | list]): Input batch containing text tokens, audio code lengths, and additional contextual tensors required for decoding. + use_lt (bool): Whether to generate local-transformer rollout targets and logits for + distillation. max_decoder_steps (int, optional): Maximum number of generated audio frames in the unstacked time domain. Defaults to 500. temperature (float, optional): Sampling temperature controlling randomness during token @@ -311,15 +459,27 @@ def _infer_batch_teacher( is allowed. Prevents premature termination. Defaults to 4. Returns: - tuple[Tensor, Tensor, Tensor, Optional[Tensor]]: - - **predicted_codes (Tensor)**: Generated discrete audio codes of shape `(B, C, T)`, - where `T` is in the unstacked time domain. - - **predicted_codes_logits (Tensor)**: Decoder logits collected per stacked decoding step, - shape `(B, T', D)`, where `T'` is the frame-stacked sequence length. - - **predicted_codes_lens (Tensor)**: Predicted rollout lengths per batch item in the - unstacked time domain, shape `(B,)`. - - **sample_weights (Optional[Tensor])**: Optional per-sample weighting factors, shape `(B,)`. + _TeacherOutput: Container with teacher rollout outputs used for distillation: + - **codes (Tensor)**: Generated discrete audio codes of shape `(B, C, T)`, where + `T` is in the unstacked time domain. + - **logits (Tensor)**: Decoder logits collected per stacked decoding step, of shape + `(B, T', D)`, where `T'` is the frame-stacked sequence length. + - **lens (Tensor)**: Predicted rollout lengths per batch item in the unstacked + time domain, shape `(B,)`. + - **sample_weights (Optional[Tensor])**: Optional per-sample weighting factors, + shape `(B,)`. + - **codes_lt (Optional[Tensor])**: Generated local-transformer code targets aligned + with the rollout, or `None` when local-transformer distillation is disabled for + the current step. + - **logits_lt (Optional[Tensor])**: Local-transformer logits aligned with the rollout, + or `None` when local-transformer distillation is disabled for the current step. """ + if use_lt and model.local_transformer_type != LocalTransformerType.AR: + raise ValueError( + f"Only `LocalTransformerType.AR` is supported for local-transformer distillation, " + f"but got `{model.local_transformer_type}`." + ) + model.decoder.reset_cache(use_cache=use_kv_cache) eos_detection_method = EOSDetectionMethod(eos_detection_method) context_tensors = model.prepare_context_tensors(batch) @@ -331,6 +491,7 @@ def _infer_batch_teacher( audio_codes_input = torch.full(dims, model.audio_bos_id, device=device).long() audio_codes_lens = torch.full((bs,), fs_factor, device=device).long() truncation_count = 0 + codes_lt, logits_lt = None, None if truncation_weight is None: sample_weights = None @@ -349,8 +510,8 @@ def _infer_batch_teacher( context_tensors.additional_decoder_mask, ) ) - predicted_codes_logits = [] - predictions = [] + predicted_codes_logits, predicted_codes_logits_lt = [], [] + predictions, predictions_lt = [], [] # Stores the start of the final retained stacked block. end_indices = {} attn_prior = [None, None] if model.model_type == "multi_encoder_context_tts" else None @@ -381,7 +542,7 @@ def _infer_batch_teacher( cfg_audio_codes_embedded[bs:, :index] = dummy_additional_decoder_input cfg_audio_codes_mask[bs:, :index] = dummy_addition_dec_mask - combined_logits, _, _, _ = model.forward( + combined_logits, _, dec_out, _ = model.forward( dec_input_embedded=cfg_audio_codes_embedded, dec_input_mask=cfg_audio_codes_mask, cond=cfg_cond, @@ -391,23 +552,33 @@ def _infer_batch_teacher( ) cond_logits = combined_logits[:bs] uncond_logits = combined_logits[bs:] - all_code_logits = (1 - cfg_scale) * uncond_logits + cfg_scale * cond_logits + mixed_logits = (1 - cfg_scale) * uncond_logits + cfg_scale * cond_logits forbid_audio_eos = idx * fs_factor < min_generated_frames - all_code_logits_t = all_code_logits[:, -1, :] - predicted_codes_logits.append(all_code_logits_t.unsqueeze(1)) + logits_t = mixed_logits[:, -1, :] + predicted_codes_logits.append(logits_t.unsqueeze(1)) audio_codes_next = model.sample_codes_from_logits( - all_code_logits_t, + logits_t, temperature=temperature, topk=topk, forbid_audio_eos=forbid_audio_eos, ) all_codes_next_argmax = model.sample_codes_from_logits( - all_code_logits_t, + logits_t, temperature=0.01, topk=1, forbid_audio_eos=forbid_audio_eos, ) + if use_lt: + audio_codes_next_lt, logits_t_lt = _lt_sample_autoregressive( + model=model, + dec_output=dec_out[:, -1, :], + temperature=temperature, + topk=topk, + cfg_scale=cfg_scale, + use_kv_cache=use_kv_cache, + forbid_audio_eos=forbid_audio_eos, + ) for item_idx in range(bs): if item_idx in end_indices: @@ -439,6 +610,10 @@ def _infer_batch_teacher( audio_codes_input = torch.cat([audio_codes_input, audio_codes_next], dim=-1) audio_codes_lens = audio_codes_lens + fs_factor + if use_lt: + predictions_lt.append(audio_codes_next_lt) + predicted_codes_logits_lt.append(logits_t_lt.unsqueeze(1)) + if len(end_indices) == bs and len(predictions) >= 4: msg = "All ends reached" if truncation_threshold is not None: @@ -446,22 +621,35 @@ def _infer_batch_teacher( print(msg) break - predicted_codes = torch.cat(predictions, dim=-1) - predicted_codes_logits = torch.cat(predicted_codes_logits, dim=1) + codes = torch.cat(predictions, dim=-1) + logits = torch.cat(predicted_codes_logits, dim=1) - max_step = predicted_codes.size(-1) + max_step = codes.size(-1) predicted_lens = [end_indices[idx] + fs_factor if idx in end_indices else max_step for idx in range(bs)] predicted_codes_lens = torch.tensor(predicted_lens, device=text.device).long() max_len = predicted_codes_lens.max() max_stacked_len = max_len // fs_factor - predicted_codes = predicted_codes[:, :, :max_len] - predicted_codes_logits = predicted_codes_logits[:, :max_stacked_len, :] + codes = codes[:, :, :max_len] + logits = logits[:, :max_stacked_len, :] + + if use_lt: + codes_lt = torch.cat(predictions_lt, dim=-1) + codes_lt = codes_lt[:, :, :max_len] + logits_lt = torch.cat(predicted_codes_logits_lt, dim=1) + logits_lt = logits_lt[:, :max_stacked_len, :] model.decoder.reset_cache(use_cache=False) torch.cuda.empty_cache() - return predicted_codes, predicted_codes_logits, predicted_codes_lens, sample_weights + return _TeacherOutput( + codes=codes, + logits=logits, + lens=predicted_codes_lens, + sample_weights=sample_weights, + codes_lt=codes_lt, + logits_lt=logits_lt, + ) def _collect_validation_outputs( @@ -488,6 +676,26 @@ def _add(items: list[dict[str, Tensor]]) -> None: return torch.stack(values).mean() +def _get_loss_key( + key: str, + lt_mode: bool = False, +) -> str: + if not lt_mode: + return key + return f"{key}_lt" + + +_MONITORED_LOSS_KEYS: list[str] = [ + _get_loss_key("kl_loss"), + _get_loss_key("ce_loss"), + _get_loss_key("nrmse_loss"), + "moe_loss", + _get_loss_key("kl_loss", lt_mode=True), + _get_loss_key("ce_loss", lt_mode=True), + _get_loss_key("nrmse_loss", lt_mode=True), +] + + class OnlineCFGDistillation(MagpieTTSModel): """Implements online classifier-free guidance (CFG) distillation for MagpieTTS.""" @@ -632,26 +840,52 @@ def _add_batch_audio_codes( def _update_batch( self, batch: dict[str, Tensor | list], - rollout_codes: Tensor, - rollout_lens: Tensor, + teacher_output: _TeacherOutput, + use_lt: bool, ) -> dict[str, Tensor | list]: rollout_codes = torch.nn.functional.pad( - input=rollout_codes, + input=teacher_output.codes, pad=(self.frame_stacking_factor, 0), value=self.audio_bos_id, ) batch["audio_codes"] = rollout_codes - batch["audio_codes_lens"] = rollout_lens + self.frame_stacking_factor + batch["audio_codes_lens"] = teacher_output.lens + self.frame_stacking_factor + + if use_lt: + if teacher_output.codes_lt is None: + raise ValueError( + "Local-transformer distillation is enabled for this step, but `teacher_output.codes_lt` is None." + ) + batch["audio_codes_lt"] = teacher_output.codes_lt + return batch - def _compute_loss( + def _get_local_transformer_status(self) -> bool: + if self.local_transformer_type == LocalTransformerType.NO_LT or not self.distill_local_transformer: + return False + + return self.global_step >= self.lt_distillation_start_step + + def _get_local_transformer_loss_weight(self) -> float: + if self.global_step < self.lt_distillation_start_step: + return 0.0 + + elif self.global_step >= self.lt_distillation_start_step + self.lt_distillation_ramp_len: + return self.lt_loss_weight + + weight = self.lt_loss_weight + weight *= (self.global_step - self.lt_distillation_start_step) / self.lt_distillation_ramp_len + + return weight + + def _compute_loss_helper( self, - teacher_codes: Tensor, teacher_logits: Tensor, + teacher_codes: Tensor, student_logits: Tensor, mask: Tensor, sample_weights: Optional[Tensor], - moe_routing_data: Optional[dict[str, Tensor]], + lt_mode: bool, ) -> dict[str, Tensor]: output: dict[str, Tensor] = {} @@ -665,7 +899,7 @@ def _compute_loss( if self.distillation_temperature != 1.0: kl_loss = kl_loss * (self.distillation_temperature**2) - output["kl_loss"] = kl_loss + output[_get_loss_key("kl_loss", lt_mode)] = kl_loss if self.alpha != 0.0: ce_loss = self._ce_criterion( @@ -674,9 +908,11 @@ def _compute_loss( mask=mask, sample_weights=sample_weights, ) - output["ce_loss"] = ce_loss + output[_get_loss_key("ce_loss", lt_mode)] = ce_loss - loss = (1 - self.alpha) * output.get("kl_loss", 0.0) + self.alpha * output.get("ce_loss", 0.0) + kl_term = output.get(_get_loss_key("kl_loss", lt_mode), 0.0) + ce_term = output.get(_get_loss_key("ce_loss", lt_mode), 0.0) + loss = (1 - self.alpha) * kl_term + self.alpha * ce_term if self.beta > 0.0: nrmse_loss = self._nrmse_criterion( @@ -685,18 +921,71 @@ def _compute_loss( mask=mask, sample_weights=sample_weights, ) - output["nrmse_loss"] = nrmse_loss + output[_get_loss_key("nrmse_loss", lt_mode)] = nrmse_loss loss = loss + self.beta * nrmse_loss - if moe_routing_data is not None: - _, _, moe_loss = self.moe_auxiliary_loss(**moe_routing_data) - output["moe_loss"] = moe_loss - loss = loss + self.moe_loss_weight * moe_loss + output[_get_loss_key("loss", lt_mode)] = loss - output["loss"] = loss + return output + + def _compute_loss( + self, + teacher_output: _TeacherOutput, + student_output: _StudentOutput, + mask: Tensor, + use_lt: bool, + ) -> dict[str, Tensor]: + output = self._compute_loss_helper( + teacher_logits=teacher_output.logits, + teacher_codes=teacher_output.codes, + student_logits=student_output.logits, + mask=mask, + sample_weights=teacher_output.sample_weights, + lt_mode=False, + ) + backbone_loss_key = _get_loss_key("loss") + + if backbone_loss_key != "loss": + output["loss"] = output[backbone_loss_key] + + if use_lt: + lt_output = self._compute_loss_helper( + teacher_logits=teacher_output.logits_lt, + teacher_codes=teacher_output.codes_lt, + student_logits=student_output.logits_lt, + mask=mask, + sample_weights=teacher_output.sample_weights, + lt_mode=True, + ) + output.update(lt_output) + lt_weight = self._get_local_transformer_loss_weight() + lt_loss_key = _get_loss_key("loss", lt_mode=True) + output["loss"] = (1 - lt_weight) * output["loss"] + lt_weight * output[lt_loss_key] + del output[lt_loss_key] + + if student_output.moe_routing_data is not None: + _, _, moe_loss = self.moe_auxiliary_loss(**student_output.moe_routing_data) + output["moe_loss"] = moe_loss + output["loss"] = output["loss"] + self.moe_loss_weight * moe_loss return output + def _rescale_logits( + self, + teacher_output: _TeacherOutput, + student_output: _StudentOutput, + use_lt: bool, + ) -> tuple[_TeacherOutput, _StudentOutput]: + if self.distillation_temperature != 1.0: + student_output.logits = student_output.logits / self.distillation_temperature + teacher_output.logits = teacher_output.logits / self.distillation_temperature + + if use_lt: + student_output.logits_lt = student_output.logits_lt / self.distillation_temperature + teacher_output.logits_lt = teacher_output.logits_lt / self.distillation_temperature + + return teacher_output, student_output + def _process_batch_distillation( self, batch: dict[str, Tensor | list], @@ -704,13 +993,22 @@ def _process_batch_distillation( ) -> dict[str, Tensor]: """Perform a knowledge distillation step between teacher and student models. - This method orchestrates the end-to-end distillation process: - 1. The teacher model generates rollouts. - 2. The student model performs a teacher-forced forward pass on those rollouts. - 3. Multiple loss components are computed, including KL divergence, cross-entropy, - normalized RMSE (NRMSE), and optionally Mixture-of-Experts (MoE) auxiliary loss. - 4. The total loss is formed from the active loss terms, controlled by the - coefficients `alpha`, `beta`, and `moe_loss_weight`. + This method orchestrates the end-to-end online distillation process: + 1. Audio codes are added to the batch if they are not already present. + 2. The teacher model generates autoregressive rollout targets and logits. + 3. The batch is updated with teacher-generated rollout codes. + 4. The student model performs a teacher-forced forward pass on the updated batch. + 5. Decoder logits are optionally rescaled by the distillation temperature. + 6. Distillation losses are computed for the main decoder, and optionally for the + local transformer if local-transformer distillation is enabled for the current step. + + Depending on configuration, the final loss may include: + - KL divergence between student and teacher logits + - Cross-entropy on teacher-generated discrete codes + - Normalized RMSE between student and teacher logits + - Optional Mixture-of-Experts (MoE) auxiliary loss + - Optional local-transformer distillation loss mixed with the base loss using + the configured local-transformer loss schedule Args: batch (dict[str, Tensor | list]): Input batch containing text tokens, conditioning @@ -721,20 +1019,24 @@ def _process_batch_distillation( efficiency. Defaults to `"train"`. Returns: - dict[str, Tensor]: Dictionary containing computed loss components and auxiliary values: - - **loss (Tensor)**: Total weighted distillation loss combining all active components. - - **kl_loss (Tensor, optional)**: KL divergence between student and teacher logits. - Included when `alpha != 1.0`. - - **ce_loss (Tensor, optional)**: Cross-entropy loss between student predictions and - teacher-generated discrete audio codes. Included when `alpha != 0.0`. - - **nrmse_loss (Tensor, optional)**: Normalized RMSE loss between student and teacher - logits. Included when `beta > 0.0`. - - **moe_loss (Tensor, optional)**: Auxiliary Mixture-of-Experts routing loss. - Included when MoE routing data is available. + dict[str, Tensor]: Dictionary containing the total loss and any active auxiliary + loss components. Always includes: + - **loss (Tensor)**: Final distillation loss used for optimization. + + Depending on the active configuration, may additionally include: + - **kl_loss (Tensor)**: KL divergence between student and teacher decoder logits. + - **ce_loss (Tensor)**: Cross-entropy loss on teacher-generated decoder codes. + - **nrmse_loss (Tensor)**: Normalized RMSE between student and teacher decoder logits. + - **moe_loss (Tensor)**: Auxiliary Mixture-of-Experts routing loss. + - **kl_loss_lt (Tensor)**: KL divergence between student and teacher local-transformer logits. + - **ce_loss_lt (Tensor)**: Cross-entropy loss on teacher-generated local-transformer codes. + - **nrmse_loss_lt (Tensor)**: Normalized RMSE between student and teacher + local-transformer logits. """ batch = self._add_batch_audio_codes(batch) + use_lt = self._get_local_transformer_status() - rollout_codes, rollout_logits, rollout_lens, sample_weights = _infer_batch_teacher( + teacher_output = _infer_batch_teacher( model=self._teacher_model, batch=batch, max_decoder_steps=self.max_decoder_steps, @@ -744,24 +1046,14 @@ def _process_batch_distillation( truncation_threshold=self.truncation_threshold if mode == "train" else None, truncation_weight=self.truncation_weight if mode == "train" else None, use_kv_cache=self.use_kv_cache_during_rollout, + use_lt=use_lt, ) + batch = self._update_batch(batch, teacher_output, use_lt) + student_output = _process_batch_student(self, batch, use_lt) + mask = get_mask_from_lengths(teacher_output.lens) + teacher_output, student_output = self._rescale_logits(teacher_output, student_output, use_lt) + output = self._compute_loss(teacher_output, student_output, mask, use_lt) - batch = self._update_batch(batch, rollout_codes, rollout_lens) - student_logits, moe_routing_data = _process_batch_student(model=self, batch=batch) - mask = get_mask_from_lengths(rollout_lens) - - if self.distillation_temperature != 1.0: - student_logits = student_logits / self.distillation_temperature - rollout_logits = rollout_logits / self.distillation_temperature - - output = self._compute_loss( - teacher_codes=rollout_codes, - teacher_logits=rollout_logits, - student_logits=student_logits, - mask=mask, - sample_weights=sample_weights, - moe_routing_data=moe_routing_data, - ) return output def training_step( @@ -791,7 +1083,7 @@ def training_step( on_step=True, on_epoch=True, ) - for key in ["kl_loss", "ce_loss", "nrmse_loss", "moe_loss"]: + for key in _MONITORED_LOSS_KEYS: if key in outputs: self.log( name=f"train/{key}", @@ -812,15 +1104,14 @@ def validation_step( ) -> dict[str, Tensor]: """Execute a single validation step for the model. - Args: - batch (dict): Validation batch containing required model inputs. - batch_idx (int): Index of the current validation batch. - dataloader_idx (int): Index of the dataloader (0 for single dataloader). + Args: + batch (dict): Validation batch containing required model inputs. + batch_idx (int): Index of the current validation batch. + dataloader_idx (int): Index of the dataloader (0 for single dataloader). Returns: dict[str, Tensor]: Dictionary containing validation loss values returned by - `_process_batch_distillation()`. Always includes `"loss"` and may also - include `"kl_loss"`, `"ce_loss"`, `"nrmse_loss"`, and `"moe_loss"`. + `_process_batch_distillation()`. """ val_output = self._process_batch_distillation(batch, mode="validation") self.validation_step_outputs[dataloader_idx].append(val_output) @@ -853,7 +1144,7 @@ def _on_validation_epoch_end_logging( enable_graph=False, ) - for key in ["kl_loss", "ce_loss", "nrmse_loss", "moe_loss"]: + for key in _MONITORED_LOSS_KEYS: value = _collect_validation_outputs(val_outputs, key) if value is not None: self.log( diff --git a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_FrameStacking_OnlineCFGDistillation.sh b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_FrameStacking_OnlineCFGDistillation.sh index ce20b919a7aa..aadfc775d01c 100644 --- a/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_FrameStacking_OnlineCFGDistillation.sh +++ b/tests/functional_tests/L2_TTS_Fast_dev_runs_Magpietts_FrameStacking_OnlineCFGDistillation.sh @@ -19,6 +19,8 @@ HF_HUB_OFFLINE=1 TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file= +model.teacher_model_path="/home/TestData/tts/2602_FrameStacking4x/frame-stacking-4x-english-nanocodec.ckpt" \ model.codecmodel_path="/home/TestData/tts/21fps_causal_codecmodel.nemo" \ +model.frame_stacking_factor=4 \ + +model.distill_local_transformer=true \ + +model.lt_distillation_start_step=0 \ model.alignment_loss_scale=0.0 \ model.use_text_conditioning_encoder=false \ model.prior_scaling_factor=null \