diff --git a/models/backbones/swin_v1.py b/models/backbones/swin_v1.py index e762bba..3c6e824 100644 --- a/models/backbones/swin_v1.py +++ b/models/backbones/swin_v1.py @@ -64,9 +64,9 @@ def window_reverse(windows, window_size, H, W): Returns: x: (B, H, W, C) """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + C = int(windows.shape[-1]) + x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) return x