
RegistrationResidualConvBlock's skip connections work only if in_channels == out_channels #7591

@SomeUserName1

Describe the bug
In RegistrationResidualConvBlock, the skip connection adds the block's input to the output of the last layer, so the current implementation only works when in_channels == out_channels:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = x  # the block input still has in_channels channels
        for i, (conv, norm, act) in enumerate(zip(self.layers, self.norms, self.acts)):
            x = conv(x)
            x = norm(x)
            if i == self.num_layers - 1:
                # last block
                x = x + skip  # if in_channels != out_channels, this errors because the channel dimensions don't match
            x = act(x)
        return x

To Reproduce
block = monai.networks.blocks.RegistrationResidualConvBlock(spatial_dims=3, in_channels=6, out_channels=24, num_layers=3)
block(torch.rand(1, 6, 16, 16, 16))

Constructing the block succeeds, but the forward pass fails with:

  File "venv/lib/python3.11/site-packages/monai/networks/blocks/regunet_block.py", line 123, in forward
    x = x + skip
        ~~^~~~~~
RuntimeError: The size of tensor a (24) must match the size of tensor b (6) at non-singleton dimension 1

Expected behavior
The block should work when in_channels != out_channels.

One potential fix is to take the skip from the output of the first layer (the layer that changes the number of channels) and add it back at the last layer:

        for i, (conv, norm, act) in enumerate(zip(self.layers, self.norms, self.acts)):
            x = conv(x)
            x = norm(x)
            if i == 0:
                # first layer maps in_channels -> out_channels; use its output as the skip
                skip = x
            elif i == self.num_layers - 1:
                # last block
                x = x + skip
            x = act(x)
        return x
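The first option can be tried outside MONAI with a small standalone module. This is a hypothetical sketch: `FirstLayerSkipBlock`, `CONV` and `NORM` are illustrative names, not MONAI API. The layer layout approximately mirrors `RegistrationResidualConvBlock` (bias-free convolution with "same" padding, batch norm, ReLU), assuming an odd `kernel_size`; only the `forward` logic comes from the proposal above.

```python
import torch
from torch import nn

# Hypothetical helpers: pick the conv / batch-norm class for the number of spatial dims.
CONV = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
NORM = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}


class FirstLayerSkipBlock(nn.Module):
    """Hypothetical stand-in for RegistrationResidualConvBlock using the first proposed fix."""

    def __init__(self, spatial_dims, in_channels, out_channels, num_layers=2, kernel_size=3):
        super().__init__()
        self.num_layers = num_layers
        # Assumes an odd kernel_size so that padding=kernel_size // 2 keeps the spatial size.
        self.layers = nn.ModuleList(
            CONV[spatial_dims](in_channels if i == 0 else out_channels, out_channels,
                               kernel_size, padding=kernel_size // 2, bias=False)
            for i in range(num_layers)
        )
        self.norms = nn.ModuleList(NORM[spatial_dims](out_channels) for _ in range(num_layers))
        self.acts = nn.ModuleList(nn.ReLU() for _ in range(num_layers))

    def forward(self, x):
        skip = None
        for i, (conv, norm, act) in enumerate(zip(self.layers, self.norms, self.acts)):
            x = norm(conv(x))
            if i == 0:
                # Layer 0 maps in_channels -> out_channels, so its output can serve as the skip.
                skip = x
            elif i == self.num_layers - 1:
                x = x + skip  # both tensors now have out_channels channels
            x = act(x)
        return x


block = FirstLayerSkipBlock(spatial_dims=3, in_channels=6, out_channels=24, num_layers=3)
out = block(torch.rand(2, 6, 8, 8, 8))
print(tuple(out.shape))  # (2, 24, 8, 8, 8)
```

One consequence of this variant: with `num_layers=1` the `elif` branch never runs, so the block has no skip connection at all.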

Another option is to add a skip at every layer except the first, each time adding the previous layer's output:

        for i, (conv, norm, act) in enumerate(zip(self.layers, self.norms, self.acts)):
            x = conv(x)
            x = norm(x)
            if i > 0:
                # every layer after the first adds the previous layer's output
                x = x + skip
            x = act(x)
            skip = x
        return x
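The second option can be sketched the same way. Again, `PerLayerSkipBlock`, `CONV` and `NORM` are hypothetical names, and the layer layout (bias-free "same"-padded convolution with an odd `kernel_size`, batch norm, ReLU) is an approximation of MONAI's block; only `forward` follows the proposal.

```python
import torch
from torch import nn

# Hypothetical helpers: pick the conv / batch-norm class for the number of spatial dims.
CONV = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
NORM = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}


class PerLayerSkipBlock(nn.Module):
    """Hypothetical stand-in for RegistrationResidualConvBlock using the per-layer skip fix."""

    def __init__(self, spatial_dims, in_channels, out_channels, num_layers=2, kernel_size=3):
        super().__init__()
        self.num_layers = num_layers
        # Assumes an odd kernel_size so that padding=kernel_size // 2 keeps the spatial size.
        self.layers = nn.ModuleList(
            CONV[spatial_dims](in_channels if i == 0 else out_channels, out_channels,
                               kernel_size, padding=kernel_size // 2, bias=False)
            for i in range(num_layers)
        )
        self.norms = nn.ModuleList(NORM[spatial_dims](out_channels) for _ in range(num_layers))
        self.acts = nn.ModuleList(nn.ReLU() for _ in range(num_layers))

    def forward(self, x):
        skip = None
        for i, (conv, norm, act) in enumerate(zip(self.layers, self.norms, self.acts)):
            x = norm(conv(x))
            if i > 0:
                # Layers after the first keep out_channels, so the previous output can be added.
                x = x + skip
            x = act(x)
            skip = x  # post-activation output of this layer becomes the next skip
        return x


block = PerLayerSkipBlock(spatial_dims=3, in_channels=6, out_channels=24, num_layers=3)
out = block(torch.rand(2, 6, 8, 8, 8))
print(tuple(out.shape))  # (2, 24, 8, 8, 8)
```

Here every layer after the first gets its own identity shortcut, so the block behaves like a channel-changing convolution followed by a stack of residual layers.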

Both deviate from the usual residual principle of adding the block's input to its output. However, that is not directly possible when the number of channels changes, because the input and output tensors differ in the channel dimension.
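For comparison, a third option not proposed in this issue is the projection shortcut used in ResNet-style networks: keep the block input as the skip, but pass it through a 1×1 convolution when the channel counts differ. The sketch below is hypothetical (`ProjectedSkipBlock`, `CONV` and `NORM` are illustrative names) and uses the same approximate layer layout as MONAI's block. When `in_channels == out_channels` the projection is an identity, so the forward pass matches the current behavior; otherwise it adds `in_channels * out_channels` extra weights.

```python
import torch
from torch import nn

# Hypothetical helpers: pick the conv / batch-norm class for the number of spatial dims.
CONV = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
NORM = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}


class ProjectedSkipBlock(nn.Module):
    """Hypothetical variant: input skip, projected by a 1x1 conv when channel counts differ."""

    def __init__(self, spatial_dims, in_channels, out_channels, num_layers=2, kernel_size=3):
        super().__init__()
        self.num_layers = num_layers
        # Assumes an odd kernel_size so that padding=kernel_size // 2 keeps the spatial size.
        self.layers = nn.ModuleList(
            CONV[spatial_dims](in_channels if i == 0 else out_channels, out_channels,
                               kernel_size, padding=kernel_size // 2, bias=False)
            for i in range(num_layers)
        )
        self.norms = nn.ModuleList(NORM[spatial_dims](out_channels) for _ in range(num_layers))
        self.acts = nn.ModuleList(nn.ReLU() for _ in range(num_layers))
        # 1x1 projection only when the channel counts differ; otherwise the input is used as-is.
        self.proj = (
            CONV[spatial_dims](in_channels, out_channels, kernel_size=1, bias=False)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x):
        skip = self.proj(x)  # input mapped to out_channels (or unchanged)
        for i, (conv, norm, act) in enumerate(zip(self.layers, self.norms, self.acts)):
            x = norm(conv(x))
            if i == self.num_layers - 1:
                x = x + skip
            x = act(x)
        return x


block = ProjectedSkipBlock(spatial_dims=3, in_channels=6, out_channels=24, num_layers=3)
out = block(torch.rand(2, 6, 8, 8, 8))
print(tuple(out.shape))  # (2, 24, 8, 8, 8)
```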

Environment

================================
Printing MONAI config...
================================
MONAI version: 1.3.0
Numpy version: 1.26.4
Pytorch version: 2.2.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f
MONAI __file__: /home//workspace/UNet-bSSFP/unet-venv/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.0
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: 1.12.0
Pillow version: 10.2.0
Tensorboard version: 2.16.2
gdown version: 4.7.3
TorchVision version: 0.17.1+cu121
tqdm version: 4.66.2
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.8
pandas version: 2.0.3
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

================================
Printing system config...

System: Linux
Linux version: Arch Linux
Platform: Linux-6.7.4-zen1-1-zen-x86_64-with-glibc2.39
Processor:
Machine: x86_64
Python version: 3.11.7
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: []
Num physical CPUs: 24
Num logical CPUs: 32
Num usable CPUs: 32
CPU usage (%): [14.3, 12.1, 16.5, 11.2, 17.3, 12.1, 16.2, 11.1, 66.3, 11.3, 46.9, 12.5, 20.0, 13.1, 15.5, 24.5, 12.2, 11.3, 13.0, 11.2, 14.0, 11.1, 11.2, 11.2, 12.1, 12.1, 13.9, 11.2, 11.1, 12.1, 12.1, 11.1]
CPU freq. (MHz): 1788
Load avg. in last 1, 5, 15 mins (%): [8.5, 8.3, 12.2]
Disk usage (%): 95.5
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 62.6
Available memory (GB): 56.7
Used memory (GB): 5.2

================================
Printing GPU config...

Num GPUs: 1
Has CUDA: True
CUDA version: 12.1
cuDNN enabled: True
NVIDIA_TF32_OVERRIDE: None
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE: None
cuDNN version: 8902
Current device: 0
Library compiled for CUDA architectures: ['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
GPU 0 Name: NVIDIA GeForce RTX 4090 Laptop GPU
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 76
GPU 0 Total memory (GB): 15.7
GPU 0 CUDA capability (maj.min): 8.9
