Skip to content

Commit b867e85

Browse files
committed
Added utility to convert images to greyscale using luminance weights (better for the current CNN structure).
1 parent ae467c3 commit b867e85

1 file changed

Lines changed: 61 additions & 0 deletions

File tree

framework/data_utils.py

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,72 @@
66
from datasets import load_from_disk
77
from sklearn.model_selection import train_test_split
88
from torch.utils.data import DataLoader
9+
from PIL import Image
910

1011
from framework import utils
1112
from framework.datasets import CIFAR10Dataset
1213

1314

15+
def convert_to_grayscale(image: np.ndarray) -> np.ndarray:
    """Convert an RGB/RGBA image to grayscale.

    Colour images are reduced to a single channel with the luminance
    weights 0.299 (R), 0.587 (G) and 0.114 (B). An alpha channel, if
    present, is ignored.

    Args:
        image: Image array (or array-like) with shape (H, W) or (H, W, C),
            where C is 1, 3 (RGB) or 4 (RGBA).

    Returns:
        Grayscale image with shape (H, W). Inputs that are already 2-D are
        returned unchanged and single-channel inputs are only squeezed, so
        both keep their dtype. RGB/RGBA inputs come back as floating point
        because of the weighted sum.

    Raises:
        ValueError: If the shape is not one of the supported layouts.
    """
    # Accept any array-like; an ndarray passes through without a copy.
    image = np.asarray(image)

    if image.ndim == 2:
        # Already grayscale
        return image

    if image.ndim == 3:
        channels = image.shape[2]
        if channels == 1:
            # Single channel, just drop the trailing axis
            return image.squeeze(axis=2)
        if channels in (3, 4):
            # RGB(A) -> grayscale using luminance weights; slicing to the
            # first three channels discards alpha for RGBA input.
            # Source: https://www.songho.ca/dsp/luminance/luminance.html
            return np.dot(image[..., :3], [0.299, 0.587, 0.114])

    raise ValueError(f"Unsupported image shape: {image.shape}")
40+
41+
42+
def preprocess_images_to_grayscale(images: List[np.ndarray]) -> List[np.ndarray]:
    """Apply grayscale conversion to every image in a list.

    Args:
        images: List of image arrays; each one is passed to
            ``convert_to_grayscale``.

    Returns:
        A new list with the converted images, in the same order as the
        input.
    """
    return list(map(convert_to_grayscale, images))
52+
53+
54+
def convert_dataset_to_grayscale(dataset):
    """Build grayscale images and a label array from a HuggingFace dataset.

    The dataset is walked once; each item's 'image' is turned into a NumPy
    array and passed through ``convert_to_grayscale``, and its 'label' is
    collected alongside.

    Args:
        dataset: HuggingFace dataset whose items expose 'image' and 'label'

    Returns:
        Tuple ``(images, labels)``: a list of grayscale images and a NumPy
        array of the corresponding labels.
    """
    # Single pass so every item is read exactly once, image first.
    converted = [
        (convert_to_grayscale(np.array(row['image'])), row['label'])
        for row in dataset
    ]

    gray_images = [gray for gray, _ in converted]
    label_values = [label for _, label in converted]

    return gray_images, np.array(label_values)
73+
74+
1475
def load_cifar10_data():
1576
"""Load CIFAR-10 dataset (grayscale from processed datasets)."""
1677
repo_root = Path(__file__).resolve().parents[1]

0 commit comments

Comments
 (0)