Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion swift/template/templates/minicpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch import nn
from typing import Any, Dict, List, Literal, Optional

from swift.utils import get_env_args, get_packed_seq_params
from swift.utils import get_env_args, get_packed_seq_params, is_deepspeed_enabled
from ..base import Template
from ..constant import LLMTemplateType, MLLMTemplateType
from ..register import TemplateMeta, register_template
Expand Down Expand Up @@ -674,6 +674,26 @@ def _get_new_tokens(i):
encoded['loss_scale'] = loss_scale
return encoded

def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Substitute text-only training inputs with embeddings that also touch the vision path.

    The inputs are returned untouched unless all of the following hold: the
    template is in training mode, neither ``pixel_values`` nor
    ``pixel_values_videos`` is present in ``inputs``, and DeepSpeed is enabled.
    In that case a small all-zero image is pushed through the model's image
    feature extractor and its mean, multiplied by zero, is added to the token
    embeddings, so the returned values are numerically the plain token
    embeddings.

    NOTE(review): presumably this keeps the vision parameters in the autograd
    graph on text-only batches so DeepSpeed ranks stay in step -- confirm.

    Args:
        model: The model being trained; unwrapped through ``self.get_base_model``.
        inputs: Batch inputs. ``input_ids`` is required on the dummy-vision path.

    Returns:
        ``inputs`` itself on the pass-through paths, otherwise a new dict holding
        only ``inputs_embeds``.
    """
    if not self.is_training:
        return inputs

    has_visual_input = any(inputs.get(key) is not None for key in ('pixel_values', 'pixel_values_videos'))
    # Ordering matters: DeepSpeed is only queried when the batch carries no visual input.
    if has_visual_input or not is_deepspeed_enabled():
        return inputs

    token_ids = inputs['input_ids']
    backbone = self.get_base_model(model)
    text_embeds = backbone.get_input_embeddings()(token_ids)

    # Dummy image is 4 x 4 patches per side; the [[4, 4]] tensor below presumably
    # describes that same grid to the feature extractor -- TODO confirm.
    side = 4 * backbone.config.vision_config.patch_size
    target_device = text_embeds.device
    fake_pixels = torch.zeros(1, 3, side, side, device=target_device, dtype=backbone.vision_tower.dtype)
    fake_sizes = torch.tensor([[4, 4]], device=target_device, dtype=torch.int32)

    features = backbone.get_image_features(fake_pixels, fake_sizes, downsample_mode=self.downsample_mode)
    pooled = torch.cat(features.pooler_output, dim=0)

    # Multiplying by zero leaves the values unchanged while still referencing the vision output.
    return {'inputs_embeds': text_embeds + pooled.mean() * 0.}

def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
res = {}
pixel_values = [b['pixel_values'] for b in batch if b.get('pixel_values') is not None]
Expand Down
Loading