From eaed9abaf945e53da21dea4065c913ac7b54a734 Mon Sep 17 00:00:00 2001 From: Rasmus Larsen Date: Tue, 2 Sep 2025 21:34:33 +0200 Subject: [PATCH 1/7] changes for context parallel gemma, todos for flex attn --- maester/config.py | 5 +- maester/models/gemma/__init__.py | 10 ++-- maester/models/gemma/model.py | 3 +- maester/parallelisms/parallel_dims.py | 66 +++++++++++++++++------ maester/parallelisms/parallelize_gemma.py | 4 +- 5 files changed, 60 insertions(+), 28 deletions(-) diff --git a/maester/config.py b/maester/config.py index 5524b80..8ed527c 100644 --- a/maester/config.py +++ b/maester/config.py @@ -83,13 +83,14 @@ class Config(BaseSettings): data_parallel_shard_degree: int = 8 data_parallel_replicate_degree: int = 1 tensor_parallel_degree: int = 1 + context_parallel_degree: int = 1 train_batch_size: int = 2 # per device; 2 * 8 gpus * 32 nodes * 8192 seqlen = ~4M tokens per batch gradient_accumulation_steps: int = 1 gradient_accumulation_sync_each_step: bool = False train_num_steps: int = 1000 compile: bool = True enable_loss_parallel: bool = True - enable_cut_cross_entropy: bool = True + enable_cut_cross_entropy: bool = False init_timeout_seconds: int = 300 train_timeout_seconds: int = 100 @@ -143,7 +144,7 @@ class Config(BaseSettings): # lr schedule scheduler: str = "linear_warmup_cosine" warmup_steps: int = 50 - cooldown_steps: int = 100 # used for some schedules + cooldown_steps: int = 50 # fsdp mixed_precision_param: str = 'bfloat16' diff --git a/maester/models/gemma/__init__.py b/maester/models/gemma/__init__.py index eabc420..650c3e0 100644 --- a/maester/models/gemma/__init__.py +++ b/maester/models/gemma/__init__.py @@ -3,14 +3,14 @@ __all__ = ["GemmaTextModel", "ModelArgs"] gemma3_configs = { - "270M": ModelArgs( - vocab_size=262_144, - dim=640, - n_layers=18, + "debug": ModelArgs( + vocab_size=262_144, # Actual size from google/gemma-3-1b-pt tokenizer + dim=1152, + n_layers=4, n_heads=4, num_key_value_heads=1, head_dim=256, - intermediate_size=2048, + intermediate_size=6912, attn_types=["local_sliding", "local_sliding", "local_sliding", "local_sliding", "local_sliding", "global"], use_post_ffw_norm=True, use_pre_ffw_norm=True, diff --git a/maester/models/gemma/model.py b/maester/models/gemma/model.py index b296e99..446edff 100644 --- a/maester/models/gemma/model.py +++ b/maester/models/gemma/model.py @@ -221,7 +221,8 @@ class GemmaAttention(nn.Module): def __init__( self, config: ModelArgs, - attn_type: str + attn_type: str, + device_mesh = None ): super().__init__() diff --git a/maester/parallelisms/parallel_dims.py b/maester/parallelisms/parallel_dims.py index 058108a..52b180d 100644 --- a/maester/parallelisms/parallel_dims.py +++ b/maester/parallelisms/parallel_dims.py @@ -11,6 +11,7 @@ class ParallelDims: dp_replicate: int dp_shard: int tp: int + cp: int world_size: int enable_loss_parallel: bool @@ -18,49 +19,68 @@ def __post_init__(self): self._validate() def _validate(self): - dp_replicate, dp_shard, tp = self.dp_replicate, self.dp_shard, self.tp - for d in (dp_replicate, tp): + dp_replicate, dp_shard, tp, cp = self.dp_replicate, self.dp_shard, self.tp, self.cp + for d in (dp_replicate, tp, cp): assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." dp = dp_replicate * dp_shard - if dp < 0: - dp = self.world_size // (tp) - self.dp_shard = dp_shard = dp // dp_replicate + if dp_shard < 0: + self.dp_shard = dp_shard = self.world_size // (dp_replicate * tp * cp) + assert dp_shard >= 1 assert dp_replicate >= 1 assert dp_shard >= 1 assert tp >= 1, tp - assert dp_replicate * dp_shard * tp == self.world_size, ( + assert cp >= 1, cp + assert dp_replicate * dp_shard * tp * cp == self.world_size, ( f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " - f"tp({tp}) != WORLD_SIZE({self.world_size})" + f"tp({tp}) * cp({cp}) != WORLD_SIZE({self.world_size})" ) def build_mesh(self, device_type): dims = [] names = [] for d, name in zip( - [self.dp_replicate, self.dp_shard, self.tp], - ["dp_replicate", "dp_shard", "tp"], + [self.dp_replicate, self.dp_shard, self.tp, self.cp], + ["dp_replicate", "dp_shard", "tp", "cp"], ): if d > 1: dims.append(d) - if (name == "dp_replicate" and self.dp_shard == 1) or ( - name == "dp_shard" and self.dp_replicate == 1 - ): - names.append("dp") - else: - names.append(name) + names.append(name) if dims == []: # edge case for non-distributed mesh w/ 1 GPU dims = [1] names = ("dp",) logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") names = tuple(names) mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + # Create all the submesh here to ensure all required process groups are # initialized - if self.dp_replicate > 1 and self.dp_shard > 1: - mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp") + dp_mesh_dim_names = [] # for data loading (no comms) + dp_shard_cp_mesh_dim_names = [] # for param sharding + dp_cp_mesh_dim_names = [] # for loss all-reduce + + if self.dp_replicate_enabled: + dp_mesh_dim_names.append("dp_replicate") + dp_cp_mesh_dim_names.append("dp_replicate") + if self.dp_shard_enabled: + dp_mesh_dim_names.append("dp_shard") + dp_shard_cp_mesh_dim_names.append("dp_shard") + dp_cp_mesh_dim_names.append("dp_shard") + if self.cp_enabled: + dp_shard_cp_mesh_dim_names.append("cp") + dp_cp_mesh_dim_names.append("cp") + + if dp_mesh_dim_names != []: + mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + if dp_shard_cp_mesh_dim_names != []: + mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( + mesh_dim_name="dp_shard_cp" + ) + if dp_cp_mesh_dim_names != []: + mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") + return mesh @property @@ -74,6 +94,18 @@ def dp_replicate_enabled(self): @property def dp_shard_enabled(self): return self.dp_shard > 1 + + @property + def dp_cp_enabled(self): + return self.dp_enabled or self.cp_enabled + + @property + def cp_enabled(self): + return self.cp > 1 + + @property + def fsdp_enabled(self): + return self.dp_shard_enabled or self.cp_enabled @property def tp_enabled(self): diff --git a/maester/parallelisms/parallelize_gemma.py b/maester/parallelisms/parallelize_gemma.py index 66bc0e2..debab6c 100644 --- a/maester/parallelisms/parallelize_gemma.py +++ b/maester/parallelisms/parallelize_gemma.py @@ -66,10 +66,9 @@ def parallelize_gemma( apply_fsdp( model, - dp_mesh, + world_mesh[tuple(dp_mesh_dim_names)], param_dtype=TORCH_DTYPE_MAP[config.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[config.mixed_precision_reduce], - tp_enabled=parallel_dims.tp_enabled, #pp_enabled=parallel_dims.pp_enabled, ) @@ -270,7 +269,6 @@ def apply_fsdp( dp_mesh: DeviceMesh, param_dtype: torch.dtype, reduce_dtype: torch.dtype, - tp_enabled: bool, pp_enabled: bool = False, ): """Apply FSDP to Gemma model.""" From 3051ee2150bff7ba73e622f8d8e9c157c4fd113e Mon Sep 17 00:00:00 2001 From: Rasmus Larsen Date: Sun, 7 Sep 2025 13:34:27 +0200 Subject: [PATCH 2/7] changes for supporting cp for gemma w/ flex attn --- maester/models/gemma/model.py | 38 +++++++++++++++------ maester/parallelisms/parallel_dims.py | 6 +++- maester/parallelisms/parallelize_gemma.py | 6 ++-- train.py | 41 +++++++++++++++++++++-- 4 files changed, 73 insertions(+), 18 deletions(-) diff --git a/maester/models/gemma/model.py b/maester/models/gemma/model.py index 446edff..dfbdcba 100644 --- a/maester/models/gemma/model.py +++ b/maester/models/gemma/model.py @@ -40,6 +40,7 @@ class ModelArgs: vision_config: dict | None = None # For multimodal models tied_embeddings: bool = True # For training compatibility init_std: float = 0.02 # For weight initialization + attention_backend: str = "flex" # "eager", "flex", or "sdpa", but "flex" is recommended as the others might be incorrect def precompute_freqs_cis(dim: int, end: int, @@ -216,6 +217,18 @@ def _ensure_long(val): return wrapped_mask_fn +@torch._dynamo.disable +def _no_compile_sdpa(q, k, v, scale: float, is_causal: bool = True, attn_mask: torch.Tensor | None = None): + # q,k,v: [B, H, S, D]; CP sharding on S (dim=2) + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + return F.scaled_dot_product_attention( + q, k, v, + dropout_p=0.0, + attn_mask=attn_mask, + scale=scale, + is_causal=is_causal, + ) + class GemmaAttention(nn.Module): def __init__( @@ -348,13 +361,15 @@ class Gemma2DecoderLayer(nn.Module): def __init__( self, config: ModelArgs, - attn_type: str + attn_type: str, + device_mesh: DeviceMesh | None ): super().__init__() self.attn_type = attn_type self.self_attn = GemmaAttention( config=config, - attn_type=attn_type + attn_type=attn_type, + device_mesh=device_mesh ) self.mlp = GemmaMLP( hidden_size=config.dim, @@ -419,7 +434,7 @@ def init_weights(self, init_std: float): self.mlp.init_weights(init_std) class GemmaModel(nn.Module): - def __init__(self, config: ModelArgs): + def __init__(self, config: ModelArgs, device_mesh: DeviceMesh | None): super().__init__() self.config = config self.vocab_size = config.vocab_size @@ -431,7 +446,7 @@ def __init__(self, config: ModelArgs): if config.attn_types is not None else "global" ) - self.layers.append(Gemma2DecoderLayer(config, attn_type)) + self.layers.append(Gemma2DecoderLayer(config, attn_type, device_mesh=device_mesh)) self.norm = RMSNorm(config.dim, eps=config.rms_norm_eps) def forward( @@ -445,7 +460,7 @@ def forward( layer: Gemma2DecoderLayer = self.layers[i] # type: ignore hidden_states = layer( hidden_states=hidden_states, - freqs_cis=freqs_cis.get(layer.attn_type), + freqs_cis=freqs_cis[layer.attn_type], mask=mask, local_mask=local_mask, ) @@ -461,13 +476,14 @@ def init_weights(self, init_std: float): class GemmaTextModel(nn.Module): """Text-only Gemma model compatible with training setup.""" - def __init__(self, config: ModelArgs): + def __init__(self, config: ModelArgs, device_mesh: DeviceMesh | None = None): super().__init__() self.config = config self.model_args = config # For compatibility with training code self.vocab_size = config.vocab_size self.n_layers = config.n_layers - + self.device_mesh = device_mesh + # Text embeddings self.tok_embeddings = Embedding( num_embeddings=config.vocab_size, @@ -475,8 +491,8 @@ def __init__(self, config: ModelArgs): ) # Core transformer model - self.model = GemmaModel(config) - + self.model = GemmaModel(config, device_mesh=device_mesh) + # Precompute RoPE frequencies following multimodal pattern head_dim = config.head_dim max_seq_len = config.max_seq_len @@ -652,9 +668,9 @@ def forward( return output @classmethod - def from_model_args(cls, model_args: ModelArgs) -> "GemmaTextModel": + def from_model_args(cls, model_args: ModelArgs, device_mesh: DeviceMesh | None = None) -> "GemmaTextModel": """Initialize from model args (compatible with training loop).""" - return cls(model_args) + return cls(model_args, device_mesh=device_mesh) class Gemma3MultiModalModel(nn.Module): diff --git a/maester/parallelisms/parallel_dims.py b/maester/parallelisms/parallel_dims.py index 52b180d..1778273 100644 --- a/maester/parallelisms/parallel_dims.py +++ b/maester/parallelisms/parallel_dims.py @@ -117,4 +117,8 @@ def loss_parallel_enabled(self): @cached_property def model_parallel_size(self): - return self.tp \ No newline at end of file + return self.tp + + @cached_property + def non_data_parallel_size(self): + return self.cp * self.tp \ No newline at end of file diff --git a/maester/parallelisms/parallelize_gemma.py b/maester/parallelisms/parallelize_gemma.py index debab6c..1a75717 100644 --- a/maester/parallelisms/parallelize_gemma.py +++ b/maester/parallelisms/parallelize_gemma.py @@ -53,7 +53,7 @@ def parallelize_gemma( # Compile each layer individually if config.compile: - apply_compile(model) + apply_compile(model, fullgraph=not parallel_dims.cp_enabled) # TODO: fullgraph for CP? # Apply FSDP use_fsdp = parallel_dims.dp_enabled or ( @@ -256,10 +256,10 @@ def apply_ac(model: nn.Module, config: Config): logger.info("Applied activation checkpointing to the model") -def apply_compile(model: nn.Module): +def apply_compile(model: nn.Module, fullgraph: bool = False): """Compile each transformer layer individually.""" for layer_id, layer in enumerate(model.model.layers): - compiled_layer = torch.compile(layer, fullgraph=True) + compiled_layer = torch.compile(layer, fullgraph=fullgraph) model.model.layers[layer_id] = compiled_layer logger.info("Compiled each transformer layer with torch.compile") diff --git a/train.py b/train.py index bc641c0..a80baea 100644 --- a/train.py +++ b/train.py @@ -19,6 +19,8 @@ from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.tensor.parallel import loss_parallel +from torch.distributed.tensor.experimental import context_parallel +from torch.distributed.tensor.experimental._attention import set_rotate_method from transformers import AutoTokenizer, PreTrainedTokenizerFast @@ -97,6 +99,7 @@ def main(): dp_shard=cfg.data_parallel_shard_degree, dp_replicate=cfg.data_parallel_replicate_degree, tp=cfg.tensor_parallel_degree, + cp=cfg.context_parallel_degree, world_size=world_size, enable_loss_parallel=cfg.enable_loss_parallel, ) @@ -121,7 +124,17 @@ def main(): else: dp_degree, dp_rank = 1, 0 logger.info(f"world mesh: {world_mesh}") - #logger.info(f"dp mesh: {dp_mesh}") + # logger.info(f"dp mesh: {dp_mesh}") + + if parallel_dims.cp_enabled: # the following is necessary for CP w/ flex attention + from torch.distributed.tensor.experimental._attention import _set_cp_global_var, _DispatchMode, _cp_options + + # set_rotate_method("alltoall") # alltoall or allgather (only allgather for flex) + _set_cp_global_var("cp_shard_dim", 2) + # _cp_options.enable_load_balance = True # no load balancing for flex + torch.distributed.tensor.experimental._attention._dispatch_mode = ( + _DispatchMode.TORCH_FUNCTION + ) # Get tokenizer to determine vocab size if os.path.isfile(cfg.tokenizer_name): @@ -155,7 +168,7 @@ def main(): logger.info( f"Building {cfg.model_name} {cfg.flavor} with {model_config}" ) - model = model_cls.from_model_args(model_config) + model = model_cls.from_model_args(model_config, world_mesh["cp"]) # log model size model_param_count = get_num_params(model) @@ -361,6 +374,28 @@ def loss_fn(pred, labels): # Get document_ids if available (for flex attention document masking in packed data) document_ids = batch.get("document_ids", None) + + buffers = [input_ids, labels] + buffer_seq_dims = [1, 1] # shard on seq dim + if hasattr(model, 'freqs_cis'): + buffers.extend([model.freqs_cis]) + buffer_seq_dims.extend([0]) + elif hasattr(model, 'local_freqs_cis') and hasattr(model, 'global_freqs_cis'): + buffers.extend([model.local_freqs_cis, model.global_freqs_cis]) + buffer_seq_dims.extend([0, 0]) + context_parallel_ctx = context_parallel( + world_mesh["cp"], + buffers=buffers, + buffer_seq_dims=buffer_seq_dims, + no_restore_buffers={input_ids, labels}, # don't restore + ) if parallel_dims.cp_enabled else contextlib.nullcontext() + + # non-pp loss parallel, pp is not implemented + with loss_parallel_ctx(), context_parallel_ctx: + if cfg.enable_cut_cross_entropy: + loss = model(input_ids, labels) # using cut cross-entropy fused kernel + else: + pred = model(input_ids) # Collect padding stats if available (SFT mode) if "stats" in batch and "actual_lengths" in batch["stats"]: @@ -456,7 +491,7 @@ def loss_fn(pred, labels): time_delta = timer() - time_last_log total_tokens += ntokens_since_last_log - tps = ntokens_since_last_log / (time_delta * parallel_dims.model_parallel_size) + tps = ntokens_since_last_log / (time_delta * parallel_dims.non_data_parallel_size) mfu = 100 * num_flop_per_token * tps / gpu_peak_flops time_end_to_end = time_delta / cfg.log_freq From ba14de0756bd78a126515cad112fd5806736ddac Mon Sep 17 00:00:00 2001 From: Rasmus Larsen Date: Sun, 7 Sep 2025 14:14:17 +0200 Subject: [PATCH 3/7] fix non-cp parallelism --- maester/models/gemma/model.py | 20 ++++++++++---------- train.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/maester/models/gemma/model.py b/maester/models/gemma/model.py index dfbdcba..c577b60 100644 --- a/maester/models/gemma/model.py +++ b/maester/models/gemma/model.py @@ -235,7 +235,7 @@ def __init__( self, config: ModelArgs, attn_type: str, - device_mesh = None + cp_device_mesh = None ): super().__init__() @@ -362,14 +362,14 @@ def __init__( self, config: ModelArgs, attn_type: str, - device_mesh: DeviceMesh | None + cp_device_mesh: DeviceMesh | None ): super().__init__() self.attn_type = attn_type self.self_attn = GemmaAttention( config=config, attn_type=attn_type, - device_mesh=device_mesh + cp_device_mesh=cp_device_mesh ) self.mlp = GemmaMLP( hidden_size=config.dim, @@ -434,7 +434,7 @@ def init_weights(self, init_std: float): self.mlp.init_weights(init_std) class GemmaModel(nn.Module): - def __init__(self, config: ModelArgs, device_mesh: DeviceMesh | None): + def __init__(self, config: ModelArgs, cp_device_mesh: DeviceMesh | None): super().__init__() self.config = config self.vocab_size = config.vocab_size @@ -446,7 +446,7 @@ def __init__(self, config: ModelArgs, device_mesh: DeviceMesh | None): if config.attn_types is not None else "global" ) - self.layers.append(Gemma2DecoderLayer(config, attn_type, device_mesh=device_mesh)) + self.layers.append(Gemma2DecoderLayer(config, attn_type, cp_device_mesh=cp_device_mesh)) self.norm = RMSNorm(config.dim, eps=config.rms_norm_eps) def forward( @@ -476,13 +476,13 @@ def init_weights(self, init_std: float): class GemmaTextModel(nn.Module): """Text-only Gemma model compatible with training setup.""" - def __init__(self, config: ModelArgs, device_mesh: DeviceMesh | None = None): + def __init__(self, config: ModelArgs, cp_device_mesh: DeviceMesh | None = None): super().__init__() self.config = config self.model_args = config # For compatibility with training code self.vocab_size = config.vocab_size self.n_layers = config.n_layers - self.device_mesh = device_mesh + self.cp_device_mesh = cp_device_mesh # Text embeddings self.tok_embeddings = Embedding( @@ -491,7 +491,7 @@ def __init__(self, config: ModelArgs, device_mesh: DeviceMesh | None = None): ) # Core transformer model - self.model = GemmaModel(config, device_mesh=device_mesh) + self.model = GemmaModel(config, cp_device_mesh=cp_device_mesh) # Precompute RoPE frequencies following multimodal pattern head_dim = config.head_dim @@ -668,9 +668,9 @@ def forward( return output @classmethod - def from_model_args(cls, model_args: ModelArgs, device_mesh: DeviceMesh | None = None) -> "GemmaTextModel": + def from_model_args(cls, model_args: ModelArgs, cp_device_mesh: DeviceMesh | None = None) -> "GemmaTextModel": """Initialize from model args (compatible with training loop).""" - return cls(model_args, device_mesh=device_mesh) + return cls(model_args, cp_device_mesh=cp_device_mesh) class Gemma3MultiModalModel(nn.Module): diff --git a/train.py b/train.py index a80baea..aacaa39 100644 --- a/train.py +++ b/train.py @@ -168,7 +168,7 @@ def main(): logger.info( f"Building {cfg.model_name} {cfg.flavor} with {model_config}" ) - model = model_cls.from_model_args(model_config, world_mesh["cp"]) + model = model_cls.from_model_args(model_config, cp_device_mesh=world_mesh["cp"] if parallel_dims.cp_enabled else None) # log model size model_param_count = get_num_params(model) From d9800e65cfcb721f3135d61f49151452be31ec54 Mon Sep 17 00:00:00 2001 From: Rasmus Larsen Date: Sun, 7 Sep 2025 15:28:59 +0200 Subject: [PATCH 4/7] small changes to debug gemma model --- maester/models/gemma/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/maester/models/gemma/__init__.py b/maester/models/gemma/__init__.py index 650c3e0..53b9f87 100644 --- a/maester/models/gemma/__init__.py +++ b/maester/models/gemma/__init__.py @@ -6,12 +6,12 @@ "debug": ModelArgs( vocab_size=262_144, # Actual size from google/gemma-3-1b-pt tokenizer dim=1152, - n_layers=4, + n_layers=5, n_heads=4, num_key_value_heads=1, head_dim=256, intermediate_size=6912, - attn_types=["local_sliding", "local_sliding", "local_sliding", "local_sliding", "local_sliding", "global"], + attn_types=["local_sliding", "local_sliding", "global", "local_sliding", "local_sliding"], use_post_ffw_norm=True, use_pre_ffw_norm=True, sliding_window_size=512, From c7969e9ca3bc2463cca83494ef80c8df5e7e3881 Mon Sep 17 00:00:00 2001 From: Oliver Kinch Date: Mon, 10 Nov 2025 12:48:14 +0100 Subject: [PATCH 5/7] llama context parallelism --- maester/models/llama/model.py | 5 ++- maester/parallelisms/parallelize_llama.py | 40 ++++++++++------------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/maester/models/llama/model.py b/maester/models/llama/model.py index f426ead..8746df2 100644 --- a/maester/models/llama/model.py +++ b/maester/models/llama/model.py @@ -20,6 +20,8 @@ from maester.models.norms import create_norm from maester.models.llama.tied_linear import TiedLinear +from torch.distributed.device_mesh import DeviceMesh + @dataclass class ModelArgs: @@ -490,12 +492,13 @@ def forward( return output @classmethod - def from_model_args(cls, model_args: ModelArgs) -> "Transformer": + def from_model_args(cls, model_args: ModelArgs, cp_device_mesh: Optional[DeviceMesh] = None) -> "Transformer": """ Initialize a Transformer model from a ModelArgs object. Args: model_args (ModelArgs): Model configuration arguments. + cp_device_mesh (Optional[DeviceMesh]): Device mesh for context parallelism. Returns: Transformer: Transformer model. diff --git a/maester/parallelisms/parallelize_llama.py b/maester/parallelisms/parallelize_llama.py index 41dc9d8..4eeb3d8 100644 --- a/maester/parallelisms/parallelize_llama.py +++ b/maester/parallelisms/parallelize_llama.py @@ -60,28 +60,24 @@ def parallelize_llama( "fused_rmsnorm is not compatible with torch.compile yet. " "Please use rmsnorm or layernorm." ) - apply_compile(model) + apply_compile(model, fullgraph=not parallel_dims.cp_enabled) - use_fsdp = parallel_dims.dp_shard_enabled or ( - world_mesh.ndim == 1 and world_mesh.size() == 1 - ) - - if use_fsdp: - if parallel_dims.dp_shard_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mesh = world_mesh["dp_replicate", "dp_shard"] + if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: + if parallel_dims.dp_replicate_enabled: + if parallel_dims.cp_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") else: - dp_mesh = world_mesh["dp"] + dp_mesh_dim_names = ("dp_replicate", "dp_shard") else: - dp_mesh = world_mesh if world_mesh.ndim == 1 else world_mesh["dp"] - - apply_fsdp( - model, - dp_mesh, - param_dtype=TORCH_DTYPE_MAP[config.mixed_precision_param], - reduce_dtype=TORCH_DTYPE_MAP[config.mixed_precision_reduce], - ) - if parallel_dims.dp_shard_enabled and parallel_dims.dp_replicate_enabled: + if parallel_dims.cp_enabled: + dp_mesh_dim_names = ("dp_shard_cp",) + else: + dp_mesh_dim_names = ("dp",) + + dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + apply_fsdp(model, dp_mesh, param_dtype=TORCH_DTYPE_MAP[config.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[config.mixed_precision_reduce]) + if parallel_dims.dp_replicate_enabled: logger.info("Applied HSDP to the model") else: logger.info("Applied FSDP to the model") @@ -250,16 +246,16 @@ def apply_ac(model: nn.Module, ac_config: Config): logger.info(f"Applied {ac_config.ac_mode} activation checkpointing to the model") -def apply_compile(model: nn.Module): +def apply_compile(model: nn.Module, fullgraph: bool = True): """ Apply torch.compile to each TransformerBlock, which makes compilation efficient due to repeated structure. Alternatively one can compile the whole model (after applying DP). """ for layer_id, transformer_block in model.layers.named_children(): - transformer_block = torch.compile(transformer_block, fullgraph=True) + transformer_block = torch.compile(transformer_block, fullgraph=fullgraph) model.layers.register_module(layer_id, transformer_block) - logger.info("Compiling each TransformerBlock with torch.compile") + logger.info(f"Compiling each TransformerBlock with torch.compile (fullgraph={fullgraph})") def apply_fsdp( From 9b7f76632afa367005b18081fcc4ca0ad9d1e46a Mon Sep 17 00:00:00 2001 From: Oliver Kinch Date: Tue, 11 Nov 2025 09:34:14 +0100 Subject: [PATCH 6/7] import devicemesh --- maester/models/gemma/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/maester/models/gemma/model.py b/maester/models/gemma/model.py index c577b60..1eb2ee3 100644 --- a/maester/models/gemma/model.py +++ b/maester/models/gemma/model.py @@ -10,6 +10,8 @@ from cut_cross_entropy import linear_cross_entropy, LinearCrossEntropyImpl +from torch.distributed import DeviceMesh + @dataclass class ModelArgs: """ From 39b5f2a6cf7786d453e6cd77cb3bbfed541dec4d Mon Sep 17 00:00:00 2001 From: Oliver Kinch Date: Wed, 12 Nov 2025 15:26:51 +0100 Subject: [PATCH 7/7] Ignore data and models folders --- .gitignore | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 5f8ee34..6d0e493 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,8 @@ logs/ separate-logs/ *.distcp -wandb/ \ No newline at end of file +wandb/ + +data/ + +models/