Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@
from torchvision.transforms.functional import pil_modes_mapping, to_pil_image
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes
from torchvision.transforms.v2.functional._geometry import (
_get_perspective_coeffs,
_largest_inscribed_crop_size,
_parallelogram_to_bounding_boxes,
)
from torchvision.transforms.v2.functional._meta import (
_cxcywh_to_xywh,
_cxcywh_to_xyxy,
Expand Down Expand Up @@ -2435,6 +2439,85 @@ def test_functional_image_fast_path_correctness(self, size, angle, expand):

torch.testing.assert_close(actual, expected)

@pytest.mark.parametrize("size", [(100, 100), (120, 80)])
@pytest.mark.parametrize("angle", [15.0, 30.0, 45.0])
def test_transform_crop_removes_fill(self, size, angle):
    # A uniformly non-zero image rotated with fill=0 and crop=True must come back
    # without any zero (fill) pixel, and smaller than the input along at least one axis.
    height, width = size
    solid = tv_tensors.Image(torch.full((3, height, width), 200, dtype=torch.uint8))
    rotate_and_crop = transforms.RandomRotation((angle, angle), fill=0, crop=True)
    result = rotate_and_crop(solid)
    out_h, out_w = result.shape[-2], result.shape[-1]
    assert result.min().item() > 0
    assert out_h < height or out_w < width

@pytest.mark.parametrize("size", [(100, 100), (120, 80)])
@pytest.mark.parametrize("angle", [15.0, 30.0, 45.0])
def test_transform_crop_consistent_across_inputs(self, size, angle):
    # Every input passed in one call (image, mask, boxes) must land on the same cropped canvas.
    height, width = size
    sample = (
        tv_tensors.Image(torch.full((3, height, width), 200, dtype=torch.uint8)),
        tv_tensors.Mask(torch.ones(1, height, width, dtype=torch.uint8)),
        tv_tensors.BoundingBoxes(
            torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
            format=tv_tensors.BoundingBoxFormat.XYXY,
            canvas_size=(height, width),
        ),
    )
    rotate_and_crop = transforms.RandomRotation((angle, angle), crop=True)
    rotated_image, rotated_mask, rotated_boxes = rotate_and_crop(*sample)
    image_hw = rotated_image.shape[-2:]
    assert image_hw == rotated_mask.shape[-2:]
    assert rotated_boxes.canvas_size == (image_hw[0], image_hw[1])

def test_transform_crop_and_expand_mutually_exclusive(self):
    # Requesting both options at construction time must be rejected with a ValueError.
    conflicting_options = dict(expand=True, crop=True)
    with pytest.raises(ValueError, match="crop and expand are mutually exclusive"):
        transforms.RandomRotation(30, **conflicting_options)

@pytest.mark.parametrize("angle", [0.0, 180.0])
@pytest.mark.parametrize("size", [(100, 100), (120, 80), (80, 120)])
def test_transform_crop_half_turn_preserves_size(self, size, angle):
    # Rotating by 0° or 180° adds no padding, so the crop must keep the full shape
    # whatever the aspect ratio.
    height, width = size
    blank = tv_tensors.Image(torch.zeros(3, height, width, dtype=torch.uint8))
    result = transforms.RandomRotation((angle, angle), crop=True)(blank)
    assert result.shape == blank.shape

@pytest.mark.parametrize("angle", [90.0, 270.0])
@pytest.mark.parametrize("size", [(100, 100), (120, 80), (80, 120)])
def test_transform_crop_quarter_turn(self, size, angle):
    # Without expansion, a quarter turn leaves only the central min(h, w) square of the
    # rotated content on the canvas, so that square is the expected fill-free crop.
    height, width = size
    solid = tv_tensors.Image(torch.full((3, height, width), 200, dtype=torch.uint8))
    rotate_and_crop = transforms.RandomRotation((angle, angle), fill=0, crop=True)
    result = rotate_and_crop(solid)
    side = min(height, width)
    assert result.shape == (3, side, side)
    assert result.min().item() > 0

def test_largest_inscribed_crop_size(self):
    # The helper takes (width, height, angle in degrees) and returns (crop_height, crop_width).
    inscribed = _largest_inscribed_crop_size

    # No rotation: the whole image is kept.
    assert inscribed(100, 100, 0) == (100, 100)
    assert inscribed(200, 100, 0) == (100, 200)

    # Square at 45°: inscribed square side is 100 / sqrt(2) ≈ 70.71, floored to 70.
    side_h, side_w = inscribed(100, 100, 45)
    assert side_h == side_w == 70

    # Quarter turn of a non-square image: limited to the central min(h, w) square.
    assert inscribed(120, 80, 90) == (80, 80)
    assert inscribed(80, 120, 90) == (80, 80)

    # Half turn keeps the full size for any aspect ratio.
    assert inscribed(120, 80, 180) == (80, 120)

    # The sign of the angle must not matter.
    for width, height, degrees in [(100, 100, 30), (200, 100, 15), (120, 80, 45)]:
        assert inscribed(width, height, -degrees) == inscribed(width, height, degrees)

    # The crop never exceeds the original dimensions.
    for width, height, degrees in [(200, 100, 20), (640, 480, 15), (50, 50, 37)]:
        crop_h, crop_w = inscribed(width, height, degrees)
        assert crop_h <= height and crop_w <= width


class TestContainerTransforms:
class BuiltinTransform(transforms.Transform):
Expand Down
29 changes: 26 additions & 3 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,9 @@ class RandomRotation(Transform):
Fill value can be also a dictionary mapping data type to the fill value, e.g.
``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
``Mask`` will be filled with 0.
crop (bool, optional): If ``True``, the rotated output is center-cropped to the largest axis-aligned
rectangle that fits entirely within the rotated image, removing any fill/padding regions introduced
by the rotation. Mutually exclusive with ``expand``. Default is ``False``.

.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

Expand All @@ -620,11 +623,15 @@ def __init__(
expand: bool = False,
center: Optional[list[float]] = None,
fill: Union[_FillType, dict[Union[type, str], _FillType]] = 0,
crop: bool = False,
) -> None:
super().__init__()
if crop and expand:
raise ValueError("crop and expand are mutually exclusive")
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
self.interpolation = interpolation
self.expand = expand
self.crop = crop

self.fill = fill
self._fill = _setup_fill_arg(fill)
Expand All @@ -634,21 +641,37 @@ def __init__(

self.center = center

def _extract_params_for_v1_transform(self) -> dict[str, Any]:
    """Return the parent's v1 parameter dict with the ``crop`` entry removed.

    Raises:
        ValueError: if the removed ``crop`` entry is truthy; the message states that
            the v1 transform does not support this feature.
    """
    v1_params = super()._extract_params_for_v1_transform()
    crop_requested = v1_params.pop("crop")
    if not crop_requested:
        return v1_params
    raise ValueError(
        f"{type(self).__name__}() cannot be scripted when crop=True, "
        "as this feature is not supported by the v1 transform."
    )

def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
    """Sample the rotation angle and, when cropping is enabled, the matching crop size.

    Args:
        flat_inputs: the flattened inputs of the current call; only queried for their
            spatial size, and only when ``self.crop`` is set.

    Returns:
        A dict holding ``angle``, drawn uniformly from ``[self.degrees[0], self.degrees[1]]``.
        When ``self.crop`` is truthy it also holds ``crop_hw``, the ``(height, width)``
        returned by ``_largest_inscribed_crop_size`` for the queried input size and that angle.
    """
    angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
    # Build the dict first rather than returning early, so the crop size can be added to it.
    params: dict[str, Any] = dict(angle=angle)
    if self.crop:
        # NOTE(review): the crop size is derived from the input size and angle only; it
        # presumably assumes rotation about the image centre — confirm the behaviour when
        # a custom ``center`` is configured.
        height, width = query_size(flat_inputs)
        params["crop_hw"] = F._geometry._largest_inscribed_crop_size(width, height, angle)
    return params

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
    """Rotate ``inpt`` by the sampled angle and, when cropping is enabled, center-crop it.

    Args:
        inpt: the input to transform; its type selects the fill value.
        params: the sampled parameters; ``angle`` is always read, ``crop_hw`` only
            when ``self.crop`` is truthy.

    Returns:
        The rotated input, center-cropped to ``params["crop_hw"]`` when ``self.crop`` is truthy.
    """
    fill = _get_fill(self._fill, type(inpt))
    output = self._call_kernel(
        F.rotate,
        inpt,
        # Forward only the angle: ``params`` may also carry ``crop_hw``, which belongs
        # to the crop step below and not to the rotation call.
        angle=params["angle"],
        interpolation=self.interpolation,
        expand=self.expand,
        center=self.center,
        fill=fill,
    )
    if self.crop:
        # ``crop_hw`` is a (height, width) tuple; it is handed over as a list.
        output = self._call_kernel(F.center_crop, output, output_size=list(params["crop_hw"]))
    return output


class RandomAffine(Transform):
Expand Down
43 changes: 43 additions & 0 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,6 +1335,49 @@ def affine_video(
)


def _largest_inscribed_crop_size(width: int, height: int, angle: float) -> tuple[int, int]:
    """Size of the largest axis-aligned rectangle inside a ``width`` x ``height`` rectangle rotated by ``angle``.

    ``angle`` is given in degrees; only the absolute values of its sine and cosine are used.
    The result is ``(crop_height, crop_width)``, floored to integers and never larger than
    ``(height, width)``.

    Approach taken from https://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders
    """
    theta = math.radians(angle)
    abs_sin = abs(math.sin(theta))
    abs_cos = abs(math.cos(theta))

    # Multiples of 90° are answered directly: pushing them through the general formulas
    # and flooring could shave a pixel off a rotation that introduces no padding.
    if abs_sin < 1e-10:
        # 0° / 180°: the full image is kept.
        return height, width
    if abs_cos < 1e-10:
        # 90° / 270°: the sides swap, so only a min(width, height) square is kept.
        return min(width, height), min(height, width)

    if width >= height:
        landscape, longer, shorter = True, width, height
    else:
        landscape, longer, shorter = False, height, width

    # At 45° sine and cosine coincide and the denominator of the fully constrained
    # solution would vanish, so that angle is routed through the half-constrained branch.
    diagonal = abs(abs_sin - abs_cos) < 1e-10
    if shorter <= 2.0 * abs_sin * abs_cos * longer or diagonal:
        # Half-constrained: two corners of the crop touch the longer side.
        half_short = 0.5 * shorter
        over_sin = half_short / abs_sin
        over_cos = half_short / abs_cos
        if landscape:
            crop_w, crop_h = over_sin, over_cos
        else:
            crop_w, crop_h = over_cos, over_sin
    else:
        # Fully constrained: the crop touches all four sides.
        cos_double = abs_cos * abs_cos - abs_sin * abs_sin
        crop_w = (width * abs_cos - height * abs_sin) / cos_double
        crop_h = (height * abs_cos - width * abs_sin) / cos_double

    # Round down so the crop never grows past the computed rectangle, then cap each side
    # at the original dimension (one side can exceed it, e.g. wide images rotated near 90°).
    return min(math.floor(crop_h), height), min(math.floor(crop_w), width)


def rotate(
inpt: torch.Tensor,
angle: float,
Expand Down