diff --git a/spleeter_pytorch/__init__.py b/spleeter_pytorch/__init__.py new file mode 100644 index 000000000..55268c8af --- /dev/null +++ b/spleeter_pytorch/__init__.py @@ -0,0 +1,4 @@ +"""Spleeter 4-stem separator in PyTorch (no TensorFlow required).""" +from .separator import SpleeterPyTorch, INSTRUMENTS, SAMPLE_RATE + +__all__ = ['SpleeterPyTorch', 'INSTRUMENTS', 'SAMPLE_RATE'] diff --git a/spleeter_pytorch/checkpoint_reader.py b/spleeter_pytorch/checkpoint_reader.py new file mode 100644 index 000000000..e1cf8032c --- /dev/null +++ b/spleeter_pytorch/checkpoint_reader.py @@ -0,0 +1,112 @@ +"""Read a TF1 checkpoint (.index + .data) without TensorFlow. + +Parses the LevelDB SSTable .index file to extract variable names and their +offsets/sizes in the .data file, then reads the raw float32 tensors and +returns them as numpy arrays or PyTorch tensors. + +Only supports float32 (dtype=1) tensors (all Spleeter weights are float32). +""" + +import struct +import numpy as np +import torch + + +def _read_varint(data: bytes, pos: int): + result, shift = 0, 0 + while pos < len(data): + b = data[pos]; pos += 1 + result |= (b & 0x7f) << shift + if not (b & 0x80): + break + shift += 7 + return result, pos + + +def _parse_bundle_entry(value: bytes): + """Decode a BundleEntryProto to extract (dtype, offset, size).""" + pos = 0 + dtype = 1 # default: float32 + offset = 0 # default: start of file + size = None + while pos < len(value): + tag, pos = _read_varint(value, pos) + field_num = tag >> 3 + wire_type = tag & 7 + if wire_type == 0: + v, pos = _read_varint(value, pos) + if field_num == 1: dtype = v + if field_num == 4: offset = v + if field_num == 5: size = v + elif wire_type == 2: + l, pos = _read_varint(value, pos) + pos += l + elif wire_type == 5: + pos += 4 + elif wire_type == 1: + pos += 8 + else: + break + return dtype, offset, size + + +def _parse_block(block_data: bytes): + """Parse an SSTable data block: sequence of (key, value) pairs.""" + n = len(block_data) + num_restarts = struct.unpack_from(' content_end + 1: + break + key = last_key[:shared] + block_data[pos: pos + unshared] + pos += unshared + value = block_data[pos: pos + vlen] + pos += vlen + last_key = key + entries.append((key, value)) + return entries + + +def read_checkpoint(index_path: str) -> dict: + """Return a dict mapping variable_name → (offset_in_data, size_bytes, dtype).""" + with open(index_path, 'rb') as f: + raw = f.read() + + footer = raw[-48:] + pos = 0 + _, pos = _read_varint(footer, pos) # metaindex offset (skip) + _, pos = _read_varint(footer, pos) # metaindex size (skip) + idx_off, pos = _read_varint(footer, pos) + idx_sz, _ = _read_varint(footer, pos) + + # Index block → find all data block locations + idx_block = raw[idx_off: idx_off + idx_sz] + idx_entries = _parse_block(idx_block) + + # Parse data blocks + var_info = {} + for _, handle_bytes in idx_entries: + blk_off, p2 = _read_varint(handle_bytes, 0) + blk_sz, _ = _read_varint(handle_bytes, p2) + blk = raw[blk_off: blk_off + blk_sz] + for name_bytes, val in _parse_block(blk): + name = name_bytes.decode('utf-8', errors='replace') + dtype, off, sz = _parse_bundle_entry(val) + var_info[name] = (off, sz, dtype) + return var_info + + +def load_tensor(data_path: str, offset: int, size: int) -> torch.Tensor: + """Load a float32 tensor from the .data file at the given offset.""" + with open(data_path, 'rb') as f: + f.seek(offset) + raw = f.read(size) + arr = np.frombuffer(raw, dtype=np.float32).copy() + return torch.from_numpy(arr) diff --git a/spleeter_pytorch/convert_weights.py b/spleeter_pytorch/convert_weights.py new file mode 100644 index 000000000..26148af0c --- /dev/null +++ b/spleeter_pytorch/convert_weights.py @@ -0,0 +1,201 @@ +"""Convert Spleeter TF1 checkpoint → PyTorch .pt files (4-stem or 5-stem). + +Usage: + python convert_weights.py --checkpoint /path/to/4stems/model \\ + --output-dir ./models4 + python convert_weights.py --checkpoint /path/to/5stems/model \\ + --output-dir ./models5 --stems 5 + +Instrument order in the checkpoints: + 4stems (from spleeter/resources/4stems.json): vocals, drums, bass, other + 5stems (from spleeter/resources/5stems.json): vocals, piano, drums, bass, other + +Each instrument's U-Net occupies a consecutive block of TF variable indices: + 4stems: + instrument 0 (vocals): conv2d_0..6, conv2d_transpose_0..5, bn_0..11 + instrument 1 (drums): conv2d_7..13, conv2d_transpose_6..11, bn_12..23 + instrument 2 (bass): conv2d_14..20, conv2d_transpose_12..17, bn_24..35 + instrument 3 (other): conv2d_21..27, conv2d_transpose_18..23, bn_36..47 + 5stems (same offsets, one more instrument appended): + instrument 0 (vocals): conv2d_0..6, conv2d_transpose_0..5, bn_0..11 + instrument 1 (piano): conv2d_7..13, conv2d_transpose_6..11, bn_12..23 + instrument 2 (drums): conv2d_14..20, conv2d_transpose_12..17, bn_24..35 + instrument 3 (bass): conv2d_21..27, conv2d_transpose_18..23, bn_36..47 + instrument 4 (other): conv2d_28..34, conv2d_transpose_24..29, bn_48..59 +""" + +import argparse +import os +import sys + +import torch + +# Allow running from the repo root +sys.path.insert(0, os.path.dirname(__file__)) +from unet import UNet +from checkpoint_reader import read_checkpoint, load_tensor + + +INSTRUMENTS_4 = ['vocals', 'drums', 'bass', 'other'] +INSTRUMENTS_5 = ['vocals', 'piano', 'drums', 'bass', 'other'] + + +def build_state_dict(var_info: dict, data_path: str, stem_idx: int) -> dict: + """Build a PyTorch state dict for one instrument's UNet. + + Args: + var_info: dict from read_checkpoint (name → (offset, size, dtype)) + data_path: path to .data-00000-of-00001 file + stem_idx: 0=vocals, 1=drums, 2=bass, 3=other + """ + def get(name: str) -> torch.Tensor: + entry = var_info.get(name) + if entry is None: + raise KeyError(f"Variable '{name}' not found in checkpoint") + off, sz, _ = entry + return load_tensor(data_path, off, sz) + + def conv(tf_name: str) -> torch.Tensor: + # TF Conv2D kernel: (H, W, in_ch, out_ch) → PyTorch (out_ch, in_ch, H, W) + return get(tf_name).reshape(-1).view(0, 0, 0, 0) # placeholder; real reshape below + + def kernel(tf_name: str, shape_hwio) -> torch.Tensor: + """Load a TF Conv2D/ConvTranspose2D kernel and permute to PyTorch layout.""" + t = get(tf_name) + H, W, I, O = shape_hwio + t = t.reshape(H, W, I, O) + return t.permute(3, 2, 0, 1).contiguous() # HWIO → OIHW + + def tkernel(tf_name: str, shape_hwoi) -> torch.Tensor: + """Load a TF ConvTranspose2D kernel: (H, W, out_ch, in_ch) → PyTorch (in_ch, out_ch, H, W).""" + t = get(tf_name) + H, W, O, I = shape_hwoi + t = t.reshape(H, W, O, I) + return t.permute(3, 2, 0, 1).contiguous() # HWOI → IOIHW + + co = stem_idx * 7 # conv2d offset (0, 7, 14, 21) + to = stem_idx * 6 # transpose offset (0, 6, 12, 18) + bo = stem_idx * 12 # batch-norm offset + + def cname(i): return f"conv2d_{i}/kernel" if i else "conv2d/kernel" + def cbname(i): return f"conv2d_{i}/bias" if i else "conv2d/bias" + def tnname(i): return f"conv2d_transpose_{i}/kernel" if i else "conv2d_transpose/kernel" + def tbname(i): return f"conv2d_transpose_{i}/bias" if i else "conv2d_transpose/bias" + def bnname(i, attr): return f"batch_normalization_{i}/{attr}" if i else f"batch_normalization/{attr}" + + # Encoder conv shapes: (H, W, in, out) + enc_shapes = [ + (5, 5, 2, 16), + (5, 5, 16, 32), + (5, 5, 32, 64), + (5, 5, 64, 128), + (5, 5, 128, 256), + (5, 5, 256, 512), + ] + # Decoder conv-transpose shapes: (H, W, out, in) + # PyTorch ConvTranspose2d(in, out) has weight (in, out, H, W) + # TF has (H, W, out, in) where out=smaller side after transpose + dec_shapes = [ + (5, 5, 256, 512), + (5, 5, 128, 512), + (5, 5, 64, 256), + (5, 5, 32, 128), + (5, 5, 16, 64), + (5, 5, 1, 32), + ] + + sd = {} + + # Encoder + enc_py_names = ['conv', 'conv1', 'conv2', 'conv3', 'conv4', 'conv5'] + enc_bn_py = ['bn', 'bn1', 'bn2', 'bn3', 'bn4', None] + # TF bn indices for encoder (skipping conv5's BN since PyTorch UNet doesn't use it) + enc_bn_tf = [bo+0, bo+1, bo+2, bo+3, bo+4, None] + + for i, (py_name, shape) in enumerate(zip(enc_py_names, enc_shapes)): + tf_k = cname(co + i) + tf_b = cbname(co + i) + sd[f"{py_name}.weight"] = kernel(tf_k, shape) + sd[f"{py_name}.bias"] = get(tf_b) + + bn_py = enc_bn_py[i] + bn_tf = enc_bn_tf[i] + if bn_py is not None: + bn_i = bn_tf + sd[f"{bn_py}.weight"] = get(bnname(bn_i, 'gamma')) + sd[f"{bn_py}.bias"] = get(bnname(bn_i, 'beta')) + sd[f"{bn_py}.running_mean"] = get(bnname(bn_i, 'moving_mean')) + sd[f"{bn_py}.running_var"] = get(bnname(bn_i, 'moving_variance')) + + # Decoder conv-transpose (up1–up6) and their BNs + dec_py_names = ['up1', 'up2', 'up3', 'up4', 'up5', 'up6'] + dec_bn_py = ['bn5', 'bn6', 'bn7', 'bn8', 'bn9', 'bn10'] + # TF BN indices for decoder: bn_6, bn_7, ... bn_11 for stem 0 (and +12 per stem) + dec_bn_tf_base = bo + 6 # stem0=6, stem1=18, stem2=30, stem3=42 + + for i, (py_name, shape) in enumerate(zip(dec_py_names, dec_shapes)): + tf_k = tnname(to + i) + tf_b = tbname(to + i) + sd[f"{py_name}.weight"] = tkernel(tf_k, shape) + sd[f"{py_name}.bias"] = get(tf_b) + + bn_py = dec_bn_py[i] + bn_i = dec_bn_tf_base + i + sd[f"{bn_py}.weight"] = get(bnname(bn_i, 'gamma')) + sd[f"{bn_py}.bias"] = get(bnname(bn_i, 'beta')) + sd[f"{bn_py}.running_mean"] = get(bnname(bn_i, 'moving_mean')) + sd[f"{bn_py}.running_var"] = get(bnname(bn_i, 'moving_variance')) + + # up7: regular Conv2d(1, 2, kernel_size=4, dilation=2, padding=3) + # TF stores as conv2d_{co+6}/kernel with shape (4, 4, 1, 2) + up7_tf = cname(co + 6) + sd['up7.weight'] = kernel(up7_tf, (4, 4, 1, 2)) + sd['up7.bias'] = get(cbname(co + 6)) + + return sd + + +def convert(checkpoint_prefix: str, output_dir: str, n_stems: int = 5): + """Convert a Spleeter TF1 checkpoint to .pt files. + + Args: + checkpoint_prefix: path without extension (e.g. /path/to/model) + output_dir: directory to write .pt files + n_stems: 4 or 5 + """ + instruments = INSTRUMENTS_5 if n_stems == 5 else INSTRUMENTS_4 + + index_path = checkpoint_prefix + '.index' + data_path = checkpoint_prefix + '.data-00000-of-00001' + + print(f"[Spleeter] Reading checkpoint index: {index_path}") + var_info = read_checkpoint(index_path) + print(f"[Spleeter] Found {len(var_info)} variables ({n_stems}-stem)") + + os.makedirs(output_dir, exist_ok=True) + + for stem_idx, name in enumerate(instruments): + print(f"[Spleeter] Converting stem {stem_idx}: {name}") + sd = build_state_dict(var_info, data_path, stem_idx) + + model = UNet() + model.load_state_dict(sd, strict=True) + model.eval() + + out_path = os.path.join(output_dir, f"{name}.pt") + torch.save(model.state_dict(), out_path) + print(f"[Spleeter] → {out_path}") + + print("[Spleeter] Conversion complete.") + + +if __name__ == '__main__': + ap = argparse.ArgumentParser() + ap.add_argument('--checkpoint', required=True, + help='Path prefix to TF checkpoint (e.g. /path/to/model)') + ap.add_argument('--output-dir', required=True, + help='Directory to write .pt files') + ap.add_argument('--stems', type=int, default=5, choices=[4, 5], + help='Number of stems (4 or 5, default: 5)') + args = ap.parse_args() + convert(args.checkpoint, args.output_dir, n_stems=args.stems) diff --git a/spleeter_pytorch/separator.py b/spleeter_pytorch/separator.py new file mode 100644 index 000000000..cc886f9b3 --- /dev/null +++ b/spleeter_pytorch/separator.py @@ -0,0 +1,223 @@ +"""PyTorch-based Spleeter 5-stem separator (4-stem also supported). No TensorFlow required. + +Implements the exact same separation pipeline as Spleeter's librosa backend: + 1. STFT (n_fft=4096, hop=1024, periodic Hann window) + 2. Magnitude spectrogram → UNet per instrument → Wiener masks + 3. Masked complex STFT → iSTFT + +Models are loaded from .pt files produced by convert_weights.py. +Default configuration: 5-stem (vocals, piano, drums, bass, other). +""" + +import os +import sys +from typing import Dict, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from librosa.core import stft as librosa_stft, istft as librosa_istft +from scipy.signal.windows import hann as scipy_hann + +sys.path.insert(0, os.path.dirname(__file__)) +from unet import UNet + +# Spleeter 4-stems parameters (from 4stems.json) +SAMPLE_RATE = 44100 +FRAME_LENGTH = 4096 # n_fft +FRAME_STEP = 1024 # hop_length +T = 512 # time frames per UNet chunk +F_BINS = 1024 # frequency bins fed to UNet (first F of n_fft//2+1=2049) +N_CHANNELS = 2 # stereo +SEP_EXPONENT = 2 # Wiener mask exponent +EPSILON = 1e-10 +WIN_COMP = 2.0 / 3.0 # window compensation factor +INSTRUMENTS = ['vocals', 'piano', 'drums', 'bass', 'other'] # 5-stem default + + +def _stft(waveform: np.ndarray) -> np.ndarray: + """Compute STFT of a stereo waveform, matching Spleeter's librosa backend. + + Args: + waveform: (n_samples, 2) float32 array. + + Returns: + Complex spectrogram (n_frames, n_fft//2+1, 2) = (n_frames, 2049, 2). + """ + N = FRAME_LENGTH + H = FRAME_STEP + win = scipy_hann(N, sym=False) + out = [] + for c in range(N_CHANNELS): + # Prepend and append N zeros (matching Spleeter's padding) + d = np.concatenate([np.zeros(N, dtype=np.float32), + waveform[:, c].astype(np.float32), + np.zeros(N, dtype=np.float32)]) + s = librosa_stft(d, n_fft=N, hop_length=H, window=win, center=False) + # s: (n_fft//2+1, n_frames) complex + s = s.T # (n_frames, 2049) + out.append(s[:, :, np.newaxis]) + return np.concatenate(out, axis=2) # (n_frames, 2049, 2) + + +def _istft(spec: np.ndarray, length: int) -> np.ndarray: + """Inverse STFT of a complex stereo spectrogram. + + Args: + spec: (n_frames, n_fft//2+1, 2) complex array. + length: Original waveform length in samples. + + Returns: + Waveform (length, 2) float32. + """ + N = FRAME_LENGTH + H = FRAME_STEP + win = scipy_hann(N, sym=False) + out = [] + for c in range(N_CHANNELS): + s = spec[:, :, c].T # (2049, n_frames) + x = librosa_istft(s, hop_length=H, win_length=N, window=win, + center=False, length=length + N) + # Trim pre-pended N zeros + x = x[N : N + length] + out.append(x[:, np.newaxis]) + waveform = np.concatenate(out, axis=1) # (length, 2) + return (waveform * WIN_COMP).astype(np.float32) + + +def _pad_and_chunk(spec_mag: np.ndarray) -> np.ndarray: + """Pad time dimension to multiple of T, then chunk into (n_chunks, T, F, 2). + + Args: + spec_mag: (n_frames, F_BINS, 2) magnitude spectrogram. + + Returns: + (n_chunks, T, F_BINS, 2) array. + """ + n_frames = spec_mag.shape[0] + pad_len = (T - n_frames % T) % T + if pad_len: + spec_mag = np.pad(spec_mag, ((0, pad_len), (0, 0), (0, 0))) + n_chunks = spec_mag.shape[0] // T + return spec_mag.reshape(n_chunks, T, F_BINS, N_CHANNELS) + + +def _run_unets(chunks: np.ndarray, models: Dict[str, UNet], + device: torch.device) -> Dict[str, np.ndarray]: + """Run each instrument's UNet on the magnitude chunks. + + Args: + chunks: (n_chunks, T, F_BINS, 2) magnitude array. + models: dict instrument → UNet (eval mode). + device: torch device. + + Returns: + Dict instrument → (n_chunks, T, F_BINS, 2) output magnitude. + """ + # PyTorch: (n_chunks, channels, T, F_BINS) = NCHW + x = torch.from_numpy(chunks).permute(0, 3, 1, 2).to(device) + outputs = {} + with torch.no_grad(): + for name, model in models.items(): + y = model(x) # (n_chunks, 2, T, F_BINS) + outputs[name] = y.permute(0, 2, 3, 1).cpu().numpy() # (n_chunks, T, F_BINS, 2) + return outputs + + +class SpleeterPyTorch: + """Spleeter 4-stem separator using PyTorch UNets (no TensorFlow).""" + + def __init__(self, model_dir: str, device: Optional[str] = None): + """ + Args: + model_dir: directory containing vocals.pt, drums.pt, bass.pt, other.pt. + device: 'cpu', 'cuda', or None for auto-detect. + """ + if device is None: + self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self._device = torch.device(device) + + # Detect which stems are available from .pt files + available = [n for n in INSTRUMENTS + if os.path.exists(os.path.join(model_dir, f"{n}.pt"))] + if not available: + raise FileNotFoundError( + f"No .pt model files found in {model_dir}. " + "Run convert_weights.py first." + ) + self._instruments = available + + self._models: Dict[str, UNet] = {} + for name in self._instruments: + pt_path = os.path.join(model_dir, f"{name}.pt") + if not os.path.exists(pt_path): + raise FileNotFoundError( + f"Model file not found: {pt_path}\n" + f"Run convert_weights.py first to generate .pt files." + ) + model = UNet() + model.load_state_dict(torch.load(pt_path, map_location='cpu')) + model.eval() + model.to(self._device) + self._models[name] = model + print(f"[Spleeter] Loaded {len(self._models)}-stem UNets ({', '.join(self._instruments)}) on {self._device}") + + def separate(self, waveform: np.ndarray) -> Dict[str, np.ndarray]: + """Separate a stereo waveform into instrument stems. + + Args: + waveform: (n_samples, 2) float32. Sample rate must be 44100 Hz. + + Returns: + Dict instrument → (n_samples, 2) float32 separated waveform. + """ + assert waveform.ndim == 2 and waveform.shape[1] == 2, \ + "waveform must be (n_samples, 2)" + n_samples = waveform.shape[0] + + # 1. STFT → complex spec (n_frames, 2049, 2) + mix_stft = _stft(waveform) + n_frames = mix_stft.shape[0] + + # 2. Magnitude, slice to F_BINS + mix_mag = np.abs(mix_stft[:, :F_BINS, :]) # (n_frames, 1024, 2) + + # 3. Pad and chunk + chunks = _pad_and_chunk(mix_mag) # (n_chunks, 512, 1024, 2) + n_chunks = chunks.shape[0] + + # 4. Run UNets + raw_outputs = _run_unets(chunks, self._models, self._device) + n_instruments = len(raw_outputs) + + # 5. Wiener masking + # Compute sum of outputs^SEP_EXPONENT for normalization + output_sum = np.zeros_like(chunks) + EPSILON + for out in raw_outputs.values(): + output_sum += out ** SEP_EXPONENT + + separated = {} + for name, raw_out in raw_outputs.items(): + # Wiener mask (n_chunks, T, F_BINS, 2) + mask = (raw_out ** SEP_EXPONENT + EPSILON / n_instruments) / output_sum + + # Reshape chunks back to frames: (n_chunks*T, F_BINS, 2) + mask_frames = mask.reshape(n_chunks * T, F_BINS, N_CHANNELS) + + # Extend mask to full frequency range (n_chunks*T, 2049, 2) + n_extra = FRAME_LENGTH // 2 + 1 - F_BINS # = 1025 + mask_full = np.zeros((n_chunks * T, FRAME_LENGTH // 2 + 1, N_CHANNELS), + dtype=np.float32) + mask_full[:, :F_BINS, :] = mask_frames + + # Trim to actual n_frames + mask_full = mask_full[:n_frames, :, :] + + # Apply mask to complex STFT + masked_stft = mix_stft * mask_full.astype(mix_stft.dtype) + + # iSTFT → waveform + separated[name] = _istft(masked_stft, n_samples) + + return separated diff --git a/spleeter_pytorch/unet.py b/spleeter_pytorch/unet.py new file mode 100644 index 000000000..0b4409108 --- /dev/null +++ b/spleeter_pytorch/unet.py @@ -0,0 +1,60 @@ +# PyTorch U-Net for Spleeter stem separation. +# Architecture ported from the original Deezer TF model by Fangjun Kuang (Xiaomi Corp, 2023). +# Weights are loaded from the original TF checkpoint via checkpoint_reader.py. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UNet(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(2, 16, kernel_size=5, stride=2, padding=0) + self.bn = nn.BatchNorm2d(16, track_running_stats=True, eps=1e-3, momentum=0.01) + self.conv1 = nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=0) + self.bn1 = nn.BatchNorm2d(32, track_running_stats=True, eps=1e-3, momentum=0.01) + self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0) + self.bn2 = nn.BatchNorm2d(64, track_running_stats=True, eps=1e-3, momentum=0.01) + self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=0) + self.bn3 = nn.BatchNorm2d(128, track_running_stats=True, eps=1e-3, momentum=0.01) + self.conv4 = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=0) + self.bn4 = nn.BatchNorm2d(256, track_running_stats=True, eps=1e-3, momentum=0.01) + self.conv5 = nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=0) + + self.up1 = nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2) + self.bn5 = nn.BatchNorm2d(256, track_running_stats=True, eps=1e-3, momentum=0.01) + self.up2 = nn.ConvTranspose2d(512, 128, kernel_size=5, stride=2) + self.bn6 = nn.BatchNorm2d(128, track_running_stats=True, eps=1e-3, momentum=0.01) + self.up3 = nn.ConvTranspose2d(256, 64, kernel_size=5, stride=2) + self.bn7 = nn.BatchNorm2d(64, track_running_stats=True, eps=1e-3, momentum=0.01) + self.up4 = nn.ConvTranspose2d(128, 32, kernel_size=5, stride=2) + self.bn8 = nn.BatchNorm2d(32, track_running_stats=True, eps=1e-3, momentum=0.01) + self.up5 = nn.ConvTranspose2d(64, 16, kernel_size=5, stride=2) + self.bn9 = nn.BatchNorm2d(16, track_running_stats=True, eps=1e-3, momentum=0.01) + self.up6 = nn.ConvTranspose2d(32, 1, kernel_size=5, stride=2) + self.bn10 = nn.BatchNorm2d(1, track_running_stats=True, eps=1e-3, momentum=0.01) + self.up7 = nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3) + + def forward(self, x): + in_x = x # (T, 2, 512, 1024) + def pad(t): return F.pad(t, (1, 2, 1, 2), "constant", 0) + def crop(t): return t[:, :, 1:-2, 1:-2] + + act = F.elu # 4-stems model trained with ELU activations + + # Encoder: conv → BN → ELU (same order as TF model) + c1 = act(self.bn(self.conv(pad(x)))) + c2 = act(self.bn1(self.conv1(pad(c1)))) + c3 = act(self.bn2(self.conv2(pad(c2)))) + c4 = act(self.bn3(self.conv3(pad(c3)))) + c5 = act(self.bn4(self.conv4(pad(c4)))) + c6 = self.conv5(pad(c5)) # no BN/act for bottleneck (TF output is discarded) + + u1 = self.bn5(act(crop(self.up1(c6)))) + u2 = self.bn6(act(crop(self.up2(torch.cat([c5, u1], 1))))) + u3 = self.bn7(act(crop(self.up3(torch.cat([c4, u2], 1))))) + u4 = self.bn8(act(crop(self.up4(torch.cat([c3, u3], 1))))) + u5 = self.bn9(act(crop(self.up5(torch.cat([c2, u4], 1))))) + u6 = self.bn10(act(crop(self.up6(torch.cat([c1, u5], 1))))) + return torch.sigmoid(self.up7(u6)) * in_x