From 1ad773dbfd44a46bb6ab69717805cef638e76bb6 Mon Sep 17 00:00:00 2001
From: Hui Kang
Date: Mon, 29 Jun 2026 12:05:52 +0800
Subject: [PATCH] Fix: MiniCPM-V 4.6 training hangs on text-only samples with DeepSpeed

---
Reviewer notes. Everything between the separator line above and the first
"diff --git" line is commentary and is ignored when the patch is applied.

What the patch changes
  * Imports is_deepspeed_enabled from swift.utils (first hunk).
  * Adds a _post_encode(model, inputs) method (second hunk).

Behaviour of _post_encode as written
  * Returns `inputs` unchanged when self.is_training is false.
  * Returns `inputs` unchanged when either 'pixel_values' or
    'pixel_values_videos' is present (not None) in `inputs`.
  * Returns `inputs` unchanged when is_deepspeed_enabled() is false.
  * Otherwise (training, no image/video tensors, DeepSpeed enabled):
      1. Looks up inputs['input_ids'] and embeds it with the input embedding
         layer of self.get_base_model(model).
      2. Builds an all-zero tensor of shape
         (1, 3, 4 * patch_size, 4 * patch_size), where patch_size comes from
         base_model.config.vision_config.patch_size. It lives on the
         embeddings' device and uses base_model.vision_tower.dtype.
      3. Builds an int32 tensor [[4, 4]] on the same device.
      4. Passes both to base_model.get_image_features(...) together with
         downsample_mode=self.downsample_mode.
      5. Concatenates vision_output.pooler_output along dim 0 and adds
         image_embeds.mean() * 0. to the text embeddings.
      6. Returns a dict that contains only the key 'inputs_embeds'.

Assumptions to confirm
  * NOTE(review): the zero-weighted term presumably exists so that the vision
    modules take part in forward/backward on text-only batches, which would
    match the hang named in the subject line. Confirm against the DeepSpeed
    configuration that reproduces the hang.
  * NOTE(review): [[4, 4]] is presumably the patch-grid size of the dummy
    image, matching the 4 * patch_size height/width. The literal 4 appears
    three times and must be kept in sync by hand. Confirm against
    get_image_features.
  * NOTE(review): the returned dict drops every other key of `inputs`.
    Presumably the caller restores attention_mask/labels etc. Verify against
    the Template base class.
  * NOTE(review): this copy of the patch was rebuilt from a
    whitespace-collapsed original, so leading indentation and blank context
    lines in the hunks are reconstructed. Verify them against the target file
    before applying.

 swift/template/templates/minicpm.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/swift/template/templates/minicpm.py b/swift/template/templates/minicpm.py
index e363a3cd8a..9383fdbd8b 100644
--- a/swift/template/templates/minicpm.py
+++ b/swift/template/templates/minicpm.py
@@ -9,7 +9,7 @@
 from torch import nn
 from typing import Any, Dict, List, Literal, Optional
 
-from swift.utils import get_env_args, get_packed_seq_params
+from swift.utils import get_env_args, get_packed_seq_params, is_deepspeed_enabled
 from ..base import Template
 from ..constant import LLMTemplateType, MLLMTemplateType
 from ..register import TemplateMeta, register_template
@@ -674,6 +674,26 @@ def _get_new_tokens(i):
         encoded['loss_scale'] = loss_scale
         return encoded
 
+    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        if not self.is_training:
+            return inputs
+
+        pixel_values = inputs.get('pixel_values')
+        pixel_values_videos = inputs.get('pixel_values_videos')
+        if pixel_values is None and pixel_values_videos is None and is_deepspeed_enabled():
+            input_ids = inputs['input_ids']
+            base_model = self.get_base_model(model)
+            inputs_embeds = base_model.get_input_embeddings()(input_ids)
+            patch_size = base_model.config.vision_config.patch_size
+            dummy_pv = torch.zeros(
+                1, 3, 4 * patch_size, 4 * patch_size, device=inputs_embeds.device, dtype=base_model.vision_tower.dtype)
+            dummy_ts = torch.tensor([[4, 4]], device=inputs_embeds.device, dtype=torch.int32)
+            vision_output = base_model.get_image_features(dummy_pv, dummy_ts, downsample_mode=self.downsample_mode)
+            image_embeds = torch.cat(vision_output.pooler_output, dim=0)
+            inputs_embeds = inputs_embeds + image_embeds.mean() * 0.
+            return {'inputs_embeds': inputs_embeds}
+        return inputs
+
     def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
         res = {}
         pixel_values = [b['pixel_values'] for b in batch if b.get('pixel_values') is not None]