Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions spleeter_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Spleeter 4-stem separator in PyTorch (no TensorFlow required)."""
# Re-export the package-level names from the sibling ``separator`` module so
# callers can import them directly from ``spleeter_pytorch``.
from .separator import SpleeterPyTorch, INSTRUMENTS, SAMPLE_RATE

# Names exposed by ``from spleeter_pytorch import *``.
__all__ = ['SpleeterPyTorch', 'INSTRUMENTS', 'SAMPLE_RATE']
112 changes: 112 additions & 0 deletions spleeter_pytorch/checkpoint_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Read a TF1 checkpoint (.index + .data) without TensorFlow.

Parses the LevelDB SSTable .index file to extract variable names and their
offsets/sizes in the .data file, then reads the raw float32 tensors and
returns them as numpy arrays or PyTorch tensors.

Only supports float32 (dtype=1) tensors (all Spleeter weights are float32).
"""

import struct
import numpy as np
import torch


def _read_varint(data: bytes, pos: int):
result, shift = 0, 0
while pos < len(data):
b = data[pos]; pos += 1
result |= (b & 0x7f) << shift
if not (b & 0x80):
break
shift += 7
return result, pos


def _parse_bundle_entry(value: bytes):
    """Decode a BundleEntryProto to extract (dtype, offset, size).

    Walks the protobuf wire format of ``value``.  Only varint-encoded fields
    1 (dtype), 4 (offset) and 5 (size) are captured; every other field is
    skipped according to its wire type.  Parsing stops at the first
    unrecognised wire type.

    Returns:
        ``(dtype, offset, size)``.  Absent fields keep their defaults:
        ``dtype=1`` (float32), ``offset=0``, ``size=None``.
    """
    # field number -> captured value, pre-seeded with the defaults
    captured = {1: 1, 4: 0, 5: None}
    # wire type -> byte width of its fixed-size payload
    fixed_width = {5: 4, 1: 8}

    cursor = 0
    total = len(value)
    while cursor < total:
        key, cursor = _read_varint(value, cursor)
        number = key >> 3
        kind = key & 7
        if kind == 0:
            # Varint payload: the only kind that carries fields of interest.
            payload, cursor = _read_varint(value, cursor)
            if number in captured:
                captured[number] = payload
        elif kind == 2:
            # Length-delimited payload: skip over it.
            length, cursor = _read_varint(value, cursor)
            cursor += length
        elif kind in fixed_width:
            cursor += fixed_width[kind]
        else:
            break
    return captured[1], captured[4], captured[5]


def _parse_block(block_data: bytes):
    """Parse an SSTable data block: sequence of (key, value) pairs.

    The last 4 bytes of the block hold the little-endian restart count and
    are preceded by that many 4-byte restart offsets; everything before the
    restart array is entry data.  Each entry is three varints (shared key
    length, unshared key length, value length) followed by the unshared key
    bytes and the value bytes.  Keys are prefix-compressed against the
    previous key.

    Returns:
        A list of ``(key_bytes, value_bytes)`` tuples in block order.
    """
    total = len(block_data)
    (restart_count,) = struct.unpack_from('<I', block_data, total - 4)
    content_end = total - 4 * (restart_count + 1)

    pairs = []
    cursor = 0
    prev_key = b''
    while cursor < content_end:
        shared_len, cursor = _read_varint(block_data, cursor)
        suffix_len, cursor = _read_varint(block_data, cursor)
        value_len, cursor = _read_varint(block_data, cursor)

        value_start = cursor + suffix_len
        value_stop = value_start + value_len
        # NOTE(review): the guard tolerates an entry ending one byte past
        # content_end -- confirm whether that slack is intentional.
        if value_stop > content_end + 1:
            break

        # Rebuild the full key from the shared prefix of the previous key.
        prev_key = prev_key[:shared_len] + block_data[cursor:value_start]
        pairs.append((prev_key, block_data[value_start:value_stop]))
        cursor = value_stop
    return pairs


def read_checkpoint(index_path: str) -> dict:
    """Return a dict mapping variable_name → (offset_in_data, size_bytes, dtype).

    Reads the whole ``.index`` file, decodes the 48-byte SSTable footer to
    locate the index block, then parses every data block the index block
    points to.  Each entry value is decoded with ``_parse_bundle_entry``.

    The entry with an empty key is the bundle header rather than a variable,
    so it is not included in the result.

    Args:
        index_path: path to the checkpoint's ``.index`` file.

    Returns:
        ``{name: (offset, size, dtype)}``; ``size`` is ``None`` when the
        entry carries no size field.

    Raises:
        ValueError: if the file is too short to hold a footer, does not end
            with the SSTable magic number, or contains a block handle that
            points outside the file.
        OSError: if the file cannot be opened or read.
    """
    footer_size = 48                   # fixed-size footer at the end of the table
    table_magic = 0xdb4775248b80fb57   # stored little-endian in the last 8 bytes

    with open(index_path, 'rb') as f:
        raw = f.read()

    if len(raw) < footer_size:
        raise ValueError(
            f"'{index_path}' is too short ({len(raw)} bytes) to be a checkpoint index")
    footer = raw[-footer_size:]
    if int.from_bytes(footer[-8:], 'little') != table_magic:
        raise ValueError(
            f"'{index_path}' does not end with the SSTable magic number")

    def block_at(off: int, sz: int, what: str) -> bytes:
        # A block must at least hold its 4-byte restart count and must lie
        # inside the file; slicing would otherwise silently truncate.
        if sz < 4 or off + sz > len(raw):
            raise ValueError(
                f"'{index_path}': {what} block handle (offset={off}, size={sz}) "
                f"is out of range")
        return raw[off: off + sz]

    pos = 0
    _, pos = _read_varint(footer, pos)   # metaindex offset (skip)
    _, pos = _read_varint(footer, pos)   # metaindex size (skip)
    idx_off, pos = _read_varint(footer, pos)
    idx_sz, _ = _read_varint(footer, pos)

    # Index block → one (offset, size) handle per data block.
    var_info = {}
    for _, handle_bytes in _parse_block(block_at(idx_off, idx_sz, 'index')):
        blk_off, p2 = _read_varint(handle_bytes, 0)
        blk_sz, _ = _read_varint(handle_bytes, p2)
        for name_bytes, val in _parse_block(block_at(blk_off, blk_sz, 'data')):
            if not name_bytes:
                # Empty key = bundle header, not a tensor.
                continue
            name = name_bytes.decode('utf-8', errors='replace')
            dtype, off, sz = _parse_bundle_entry(val)
            var_info[name] = (off, sz, dtype)
    return var_info


def load_tensor(data_path: str, offset: int, size: int) -> torch.Tensor:
    """Load a float32 tensor from the .data file at the given offset.

    Args:
        data_path: path to the checkpoint ``.data`` shard.
        offset: byte offset of the tensor within the file.
        size: number of bytes to read.  ``None`` or a negative value reads
            to the end of the file (the behaviour of ``file.read``).

    Returns:
        A 1-D ``torch.float32`` tensor that owns its memory (the raw buffer
        is copied, so the result is writable).

    Raises:
        ValueError: if fewer than ``size`` bytes are available at ``offset``
            (truncated file or wrong offset), or if the byte count is not a
            multiple of 4.
        OSError: if the file cannot be opened or read.
    """
    with open(data_path, 'rb') as f:
        f.seek(offset)
        raw = f.read(size)

    # A short read would otherwise yield a silently smaller tensor.
    if size is not None and size >= 0 and len(raw) != size:
        raise ValueError(
            f"Expected {size} bytes at offset {offset} in '{data_path}', "
            f"got {len(raw)}")
    if len(raw) % 4:
        raise ValueError(
            f"Tensor byte length {len(raw)} is not a multiple of 4 (float32)")

    # frombuffer returns a read-only view of `raw`; copy so torch gets
    # writable memory of its own.
    arr = np.frombuffer(raw, dtype=np.float32).copy()
    return torch.from_numpy(arr)
201 changes: 201 additions & 0 deletions spleeter_pytorch/convert_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""Convert Spleeter TF1 checkpoint → PyTorch .pt files (4-stem or 5-stem).

Usage:
python convert_weights.py --checkpoint /path/to/4stems/model \\
--output-dir ./models4
python convert_weights.py --checkpoint /path/to/5stems/model \\
--output-dir ./models5 --stems 5

Instrument order in the checkpoints:
4stems (from spleeter/resources/4stems.json): vocals, drums, bass, other
5stems (from spleeter/resources/5stems.json): vocals, piano, drums, bass, other

Each instrument's U-Net occupies a consecutive block of TF variable indices:
4stems:
instrument 0 (vocals): conv2d_0..6, conv2d_transpose_0..5, bn_0..11
instrument 1 (drums): conv2d_7..13, conv2d_transpose_6..11, bn_12..23
instrument 2 (bass): conv2d_14..20, conv2d_transpose_12..17, bn_24..35
instrument 3 (other): conv2d_21..27, conv2d_transpose_18..23, bn_36..47
5stems (same per-instrument stride; piano is inserted after vocals, so the later stems shift by one):
instrument 0 (vocals): conv2d_0..6, conv2d_transpose_0..5, bn_0..11
instrument 1 (piano): conv2d_7..13, conv2d_transpose_6..11, bn_12..23
instrument 2 (drums): conv2d_14..20, conv2d_transpose_12..17, bn_24..35
instrument 3 (bass): conv2d_21..27, conv2d_transpose_18..23, bn_36..47
instrument 4 (other): conv2d_28..34, conv2d_transpose_24..29, bn_48..59
"""

import argparse
import os
import sys

import torch

# Allow running from the repo root
sys.path.insert(0, os.path.dirname(__file__))
from unet import UNet
from checkpoint_reader import read_checkpoint, load_tensor


# Stem names for each checkpoint flavour.  List position is significant:
# convert() enumerates the chosen list and passes the index to
# build_state_dict() as stem_idx, and each name becomes the output file
# name "<name>.pt".
INSTRUMENTS_4 = ['vocals', 'drums', 'bass', 'other']
INSTRUMENTS_5 = ['vocals', 'piano', 'drums', 'bass', 'other']


def build_state_dict(var_info: dict, data_path: str, stem_idx: int) -> dict:
    """Build a PyTorch state dict for one instrument's UNet.

    Every instrument owns a consecutive run of TF variable indices: 7
    ``conv2d`` layers, 6 ``conv2d_transpose`` layers and 12
    ``batch_normalization`` layers.  TF leaves the ``_0`` suffix off the
    first layer of each kind (``conv2d/kernel``, not ``conv2d_0/kernel``).

    Args:
        var_info: dict from read_checkpoint (name → (offset, size, dtype))
        data_path: path to .data-00000-of-00001 file
        stem_idx: position of the instrument in the checkpoint's instrument
            list (4-stem: 0=vocals, 1=drums, 2=bass, 3=other;
            5-stem: 0=vocals, 1=piano, 2=drums, 3=bass, 4=other)

    Returns:
        dict mapping PyTorch parameter/buffer names (``conv.weight``,
        ``bn.running_mean``, ``up1.bias``, ``up7.weight``, ...) to tensors.

    Raises:
        KeyError: if a required variable is missing from ``var_info``.
        ValueError: if a required variable is not float32 or has no
            recorded size.
    """
    def get(name: str) -> torch.Tensor:
        entry = var_info.get(name)
        if entry is None:
            raise KeyError(f"Variable '{name}' not found in checkpoint")
        off, sz, dtype = entry
        # load_tensor always interprets the bytes as float32 (dtype code 1).
        if dtype != 1:
            raise ValueError(
                f"Variable '{name}' has dtype code {dtype}; only float32 (1) is supported")
        if sz is None:
            raise ValueError(f"Variable '{name}' has no size recorded in the checkpoint")
        return load_tensor(data_path, off, sz)

    def kernel(tf_name: str, shape_hwio) -> torch.Tensor:
        """Load a TF Conv2D kernel (H, W, in, out) as PyTorch (out, in, H, W)."""
        h, w, i, o = shape_hwio
        return get(tf_name).reshape(h, w, i, o).permute(3, 2, 0, 1).contiguous()  # HWIO → OIHW

    def tkernel(tf_name: str, shape_hwoi) -> torch.Tensor:
        """Load a TF ConvTranspose2D kernel (H, W, out, in) as PyTorch (in, out, H, W)."""
        h, w, o, i = shape_hwoi
        return get(tf_name).reshape(h, w, o, i).permute(3, 2, 0, 1).contiguous()  # HWOI → IOHW

    def tf_var(layer: str, idx: int, attr: str) -> str:
        # Index 0 has no numeric suffix in the checkpoint.
        scope = f"{layer}_{idx}" if idx else layer
        return f"{scope}/{attr}"

    co = stem_idx * 7    # conv2d offset (0, 7, 14, 21, ...)
    to = stem_idx * 6    # transpose offset (0, 6, 12, 18, ...)
    bo = stem_idx * 12   # batch-norm offset (0, 12, 24, 36, ...)

    # PyTorch BatchNorm attribute → TF variable attribute
    bn_attrs = (
        ('weight', 'gamma'),
        ('bias', 'beta'),
        ('running_mean', 'moving_mean'),
        ('running_var', 'moving_variance'),
    )

    sd = {}

    def load_bn(py_name: str, tf_idx: int) -> None:
        for py_attr, tf_attr in bn_attrs:
            sd[f"{py_name}.{py_attr}"] = get(tf_var('batch_normalization', tf_idx, tf_attr))

    # Encoder conv shapes: (H, W, in, out)
    enc_shapes = [
        (5, 5, 2, 16),
        (5, 5, 16, 32),
        (5, 5, 32, 64),
        (5, 5, 64, 128),
        (5, 5, 128, 256),
        (5, 5, 256, 512),
    ]
    # Decoder conv-transpose shapes: (H, W, out, in)
    # PyTorch ConvTranspose2d(in, out) has weight (in, out, H, W)
    # TF has (H, W, out, in) where out=smaller side after transpose
    dec_shapes = [
        (5, 5, 256, 512),
        (5, 5, 128, 512),
        (5, 5, 64, 256),
        (5, 5, 32, 128),
        (5, 5, 16, 64),
        (5, 5, 1, 32),
    ]

    # Encoder: conv..conv5 use TF conv2d indices co+0..co+5.  The first five
    # take TF batch norms bo+0..bo+4; conv5 has no BN on the PyTorch side, so
    # TF index bo+5 is not loaded.
    enc_py_names = ['conv', 'conv1', 'conv2', 'conv3', 'conv4', 'conv5']
    enc_bn_py = ['bn', 'bn1', 'bn2', 'bn3', 'bn4', None]

    for i, (py_name, shape) in enumerate(zip(enc_py_names, enc_shapes)):
        sd[f"{py_name}.weight"] = kernel(tf_var('conv2d', co + i, 'kernel'), shape)
        sd[f"{py_name}.bias"] = get(tf_var('conv2d', co + i, 'bias'))
        if enc_bn_py[i] is not None:
            load_bn(enc_bn_py[i], bo + i)

    # Decoder: up1..up6 use TF conv2d_transpose indices to+0..to+5 and TF
    # batch norms bo+6..bo+11.
    dec_py_names = ['up1', 'up2', 'up3', 'up4', 'up5', 'up6']
    dec_bn_py = ['bn5', 'bn6', 'bn7', 'bn8', 'bn9', 'bn10']
    dec_bn_tf_base = bo + 6

    for i, (py_name, shape) in enumerate(zip(dec_py_names, dec_shapes)):
        sd[f"{py_name}.weight"] = tkernel(tf_var('conv2d_transpose', to + i, 'kernel'), shape)
        sd[f"{py_name}.bias"] = get(tf_var('conv2d_transpose', to + i, 'bias'))
        load_bn(dec_bn_py[i], dec_bn_tf_base + i)

    # up7: regular Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)
    # TF stores as conv2d_{co+6}/kernel with shape (4, 4, 1, 2)
    sd['up7.weight'] = kernel(tf_var('conv2d', co + 6, 'kernel'), (4, 4, 1, 2))
    sd['up7.bias'] = get(tf_var('conv2d', co + 6, 'bias'))

    return sd


def convert(checkpoint_prefix: str, output_dir: str, n_stems: int = 5):
    """Convert a Spleeter TF1 checkpoint to .pt files.

    For each instrument, builds the state dict from the checkpoint, loads it
    into a fresh ``UNet`` with ``strict=True`` (so any missing or unexpected
    key fails the conversion), and saves ``model.state_dict()`` to
    ``<output_dir>/<instrument>.pt``.

    Args:
        checkpoint_prefix: path without extension (e.g. /path/to/model)
        output_dir: directory to write <instrument>.pt files (created if
            missing)
        n_stems: 4 or 5

    Raises:
        ValueError: if ``n_stems`` is neither 4 nor 5.
    """
    stem_lists = {4: INSTRUMENTS_4, 5: INSTRUMENTS_5}
    if n_stems not in stem_lists:
        # Previously any other value silently selected the 4-stem layout.
        raise ValueError(f"n_stems must be 4 or 5, got {n_stems!r}")
    instruments = stem_lists[n_stems]

    index_path = checkpoint_prefix + '.index'
    data_path = checkpoint_prefix + '.data-00000-of-00001'

    print(f"[Spleeter] Reading checkpoint index: {index_path}")
    var_info = read_checkpoint(index_path)
    print(f"[Spleeter] Found {len(var_info)} variables ({n_stems}-stem)")

    os.makedirs(output_dir, exist_ok=True)

    for stem_idx, name in enumerate(instruments):
        print(f"[Spleeter] Converting stem {stem_idx}: {name}")
        sd = build_state_dict(var_info, data_path, stem_idx)

        model = UNet()
        model.load_state_dict(sd, strict=True)
        model.eval()

        out_path = os.path.join(output_dir, f"{name}.pt")
        torch.save(model.state_dict(), out_path)
        print(f"[Spleeter]   → {out_path}")

    print("[Spleeter] Conversion complete.")


if __name__ == '__main__':
    # Command-line entry point: parse the three options and run the conversion.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--checkpoint',
        required=True,
        help='Path prefix to TF checkpoint (e.g. /path/to/model)',
    )
    parser.add_argument(
        '--output-dir',
        required=True,
        help='Directory to write <instrument>.pt files',
    )
    parser.add_argument(
        '--stems',
        type=int,
        default=5,
        choices=[4, 5],
        help='Number of stems (4 or 5, default: 5)',
    )
    cli = parser.parse_args()
    convert(cli.checkpoint, cli.output_dir, n_stems=cli.stems)
Loading