From 648bfb4a1dfe8984d829b829963fd4ceb08a9190 Mon Sep 17 00:00:00 2001 From: sg3-141-592 <5122866+sg3-141-592@users.noreply.github.com> Date: Sun, 29 Mar 2026 15:37:35 +0100 Subject: [PATCH] Add 'crop' option to RandomRotation that centre crops the image to remove any padding regions introduced by the rotation --- test/test_transforms_v2.py | 85 ++++++++++++++++++- torchvision/transforms/v2/_geometry.py | 29 ++++++- .../transforms/v2/functional/_geometry.py | 43 ++++++++++ 3 files changed, 153 insertions(+), 4 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 30d7ba69bea..9eafa35e37e 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -52,7 +52,11 @@ from torchvision.transforms.functional import pil_modes_mapping, to_pil_image from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2._utils import check_type, is_pure_tensor -from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes +from torchvision.transforms.v2.functional._geometry import ( + _get_perspective_coeffs, + _largest_inscribed_crop_size, + _parallelogram_to_bounding_boxes, +) from torchvision.transforms.v2.functional._meta import ( _cxcywh_to_xywh, _cxcywh_to_xyxy, @@ -2435,6 +2439,85 @@ def test_functional_image_fast_path_correctness(self, size, angle, expand): torch.testing.assert_close(actual, expected) + @pytest.mark.parametrize("size", [(100, 100), (120, 80)]) + @pytest.mark.parametrize("angle", [15.0, 30.0, 45.0]) + def test_transform_crop_removes_fill(self, size, angle): + # Output of crop=True should contain no fill pixels when input is fully non-zero + h, w = size + image = tv_tensors.Image(torch.full((3, h, w), 200, dtype=torch.uint8)) + transform = transforms.RandomRotation((angle, angle), fill=0, crop=True) + output = transform(image) + assert output.min().item() > 0 + assert output.shape[-2] < h or output.shape[-1] < w + 
@pytest.mark.parametrize("size", [(100, 100), (120, 80)]) + @pytest.mark.parametrize("angle", [15.0, 30.0, 45.0]) + def test_transform_crop_consistent_across_inputs(self, size, angle): + # Image, mask, and bounding boxes should all be cropped to the same canvas size + h, w = size + image = tv_tensors.Image(torch.full((3, h, w), 200, dtype=torch.uint8)) + mask = tv_tensors.Mask(torch.ones(1, h, w, dtype=torch.uint8)) + boxes = tv_tensors.BoundingBoxes( + torch.tensor([[10.0, 10.0, 50.0, 50.0]]), + format=tv_tensors.BoundingBoxFormat.XYXY, + canvas_size=(h, w), + ) + transform = transforms.RandomRotation((angle, angle), crop=True) + out_image, out_mask, out_boxes = transform(image, mask, boxes) + assert out_image.shape[-2:] == out_mask.shape[-2:] + assert out_boxes.canvas_size == (out_image.shape[-2], out_image.shape[-1]) + + def test_transform_crop_and_expand_mutually_exclusive(self): + with pytest.raises(ValueError, match="crop and expand are mutually exclusive"): + transforms.RandomRotation(30, expand=True, crop=True) + + @pytest.mark.parametrize("angle", [0.0, 180.0]) + @pytest.mark.parametrize("size", [(100, 100), (120, 80), (80, 120)]) + def test_transform_crop_half_turn_preserves_size(self, size, angle): + # 0° and 180° introduce no padding regardless of aspect ratio + h, w = size + image = tv_tensors.Image(torch.zeros(3, h, w, dtype=torch.uint8)) + transform = transforms.RandomRotation((angle, angle), crop=True) + output = transform(image) + assert output.shape == image.shape + + @pytest.mark.parametrize("angle", [90.0, 270.0]) + @pytest.mark.parametrize("size", [(100, 100), (120, 80), (80, 120)]) + def test_transform_crop_quarter_turn(self, size, angle): + # With expand=False, only the central min(h, w) square of the rotated content + # remains inside the canvas; that's the largest fill-free crop. 
+ h, w = size + image = tv_tensors.Image(torch.full((3, h, w), 200, dtype=torch.uint8)) + transform = transforms.RandomRotation((angle, angle), fill=0, crop=True) + output = transform(image) + expected = min(h, w) + assert output.shape == (3, expected, expected) + assert output.min().item() > 0 + + def test_largest_inscribed_crop_size(self): + assert _largest_inscribed_crop_size(100, 100, 0) == (100, 100) + assert _largest_inscribed_crop_size(200, 100, 0) == (100, 200) + + # 45° square: inscribed square has side = 100 / sqrt(2) ≈ 70.71 → floor to 70 + crop_h, crop_w = _largest_inscribed_crop_size(100, 100, 45) + assert crop_h == crop_w == 70 + + # 90° on non-square: clamped to the central min(h, w) square that fits in the canvas + assert _largest_inscribed_crop_size(120, 80, 90) == (80, 80) + assert _largest_inscribed_crop_size(80, 120, 90) == (80, 80) + + # 180° preserves full size on any aspect ratio + assert _largest_inscribed_crop_size(120, 80, 180) == (80, 120) + + # Negative angles are symmetric to positive ones (abs() on sin/cos) + for w, h, a in [(100, 100, 30), (200, 100, 15), (120, 80, 45)]: + assert _largest_inscribed_crop_size(w, h, -a) == _largest_inscribed_crop_size(w, h, a) + + # Crop is always smaller than or equal to original dimensions + for w, h, a in [(200, 100, 20), (640, 480, 15), (50, 50, 37)]: + ch, cw = _largest_inscribed_crop_size(w, h, a) + assert ch <= h and cw <= w + class TestContainerTransforms: class BuiltinTransform(transforms.Transform): diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 17ba073d310..3de09c4bea4 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -606,6 +606,9 @@ class RandomRotation(Transform): Fill value can be also a dictionary mapping data type to the fill value, e.g. ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and ``Mask`` will be filled with 0. 
+ crop (bool, optional): If ``True``, the rotated output is center-cropped to the largest axis-aligned + rectangle that fits entirely within the rotated image, removing any fill/padding regions introduced + by the rotation. Mutually exclusive with ``expand``. Default is ``False``. .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters @@ -620,11 +623,15 @@ def __init__( expand: bool = False, center: Optional[list[float]] = None, fill: Union[_FillType, dict[Union[type, str], _FillType]] = 0, + crop: bool = False, ) -> None: super().__init__() + if crop and expand: + raise ValueError("crop and expand are mutually exclusive") self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) self.interpolation = interpolation self.expand = expand + self.crop = crop self.fill = fill self._fill = _setup_fill_arg(fill) @@ -634,21 +641,37 @@ def __init__( self.center = center + def _extract_params_for_v1_transform(self) -> dict[str, Any]: + params = super()._extract_params_for_v1_transform() + if params.pop("crop"): + raise ValueError( + f"{type(self).__name__}() cannot be scripted when crop=True, " + "as this feature is not supported by the v1 transform." 
+ ) + return params + def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() - return dict(angle=angle) + params: dict[str, Any] = dict(angle=angle) + if self.crop: + height, width = query_size(flat_inputs) + params["crop_hw"] = F._geometry._largest_inscribed_crop_size(width, height, angle) + return params def transform(self, inpt: Any, params: dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) - return self._call_kernel( + output = self._call_kernel( F.rotate, inpt, - **params, + angle=params["angle"], interpolation=self.interpolation, expand=self.expand, center=self.center, fill=fill, ) + if self.crop: + output = self._call_kernel(F.center_crop, output, output_size=list(params["crop_hw"])) + return output class RandomAffine(Transform): diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index a0a2e610868..a1bb5f3aed4 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1335,6 +1335,49 @@ def affine_video( ) +def _largest_inscribed_crop_size(width: int, height: int, angle: float) -> tuple[int, int]: + """Compute the largest axis-aligned rectangle inscribed in a rotated width x height rectangle. + + Returns ``(crop_height, crop_width)`` as integers. + + Approach taken from https://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders + """ + angle_rad = math.radians(angle) + sin_a = abs(math.sin(angle_rad)) + cos_a = abs(math.cos(angle_rad)) + + # Guard against small values of sin_a or cos_a that will cause us to truncate a pixel + # off rotations that introduce no padding (e.g. 90° or 180°). + if sin_a < 1e-10: + return height, width + if cos_a < 1e-10: + # 90°/270°: rotated rectangle has dims swapped; clamp to canvas (h, w). 
+ return min(width, height), min(height, width) + + width_is_longer = width >= height + side_long = width if width_is_longer else height + side_short = height if width_is_longer else width + + if side_short <= 2.0 * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10: + # Half-constrained: two crop corners touch the longer side. + # Also handles the 45° case where sin_a == cos_a and the denominator + # in the fully constrained solution goes to zero. + x = 0.5 * side_short + if width_is_longer: + crop_w, crop_h = x / sin_a, x / cos_a + else: + crop_w, crop_h = x / cos_a, x / sin_a + else: + # Fully constrained: crop touches all four sides + cos_2a = cos_a * cos_a - sin_a * sin_a + crop_w = (width * cos_a - height * sin_a) / cos_2a + crop_h = (height * cos_a - width * sin_a) / cos_2a + + # Round down (math.floor) to guarantee the crop region contains no fill pixels. + # Clamp to image dimensions for edge cases like wide images rotated near 90°. + return min(math.floor(crop_h), height), min(math.floor(crop_w), width) + + def rotate( inpt: torch.Tensor, angle: float,