-
Notifications
You must be signed in to change notification settings - Fork 321
Expand file tree
/
Copy pathvision_process.py
More file actions
340 lines (293 loc) · 12.3 KB
/
vision_process.py
File metadata and controls
340 lines (293 loc) · 12.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
from __future__ import annotations
import math
import os
import torch
import numpy as np
from PIL import Image
from collections import defaultdict
from typing import List, Optional, Union, Tuple
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
group_images_by_shape,
reorder_images,
)
from torchvision.transforms import InterpolationMode
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
PILImageResampling,
SizeDict,
)
from torchvision.transforms.v2 import functional as F
from lightllm.utils.log_utils import init_logger
# Module-level logger obtained from lightllm's logging helper (lightllm.utils.log_utils.init_logger).
logger = init_logger(__name__)
def closest_factor_pair(n):
    """Find the factor pair of ``n`` closest to ``sqrt(n)``.

    Searches downward from the integer square root of ``n`` for the largest
    divisor that does not exceed it, so the returned pair is as close to
    square as possible.

    Args:
        n: a non-negative number (normally an int).

    Returns:
        ``(smaller, larger)`` with ``smaller * larger == n``. For ``n == 0`` no
        divisor is found in the search and ``(1, 0)`` is returned.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"closest_factor_pair expects a non-negative number, got {n}")
    # math.isqrt is exact for arbitrarily large ints. int(math.sqrt(n)) goes
    # through a float: it loses precision above 2**53 and raises OverflowError
    # once n no longer fits in a float. floor(sqrt(x)) == isqrt(floor(x)) for
    # x >= 0, so int(n) keeps float inputs working as before.
    for i in range(math.isqrt(int(n)), 0, -1):
        if n % i == 0:
            return i, n // i
    return 1, n
@torch.no_grad()
def qwen_vl_check_max_len_infer(model, max_batch_size):
    """Fail fast if the vision model cannot run a worst-case batch on the GPU.

    Builds a uniform gray RGB image. Both of its sides are multiples of
    ``tile = patch_size * spatial_merge_size``, and it holds
    ``max(model.processor.max_pixels // tile**2, 1)`` tiles laid out as close to
    square as possible. The image is preprocessed once, tiled
    ``max_batch_size`` times, moved to CUDA and passed through ``model.forward``.

    The whole check is skipped when the ``DISABLE_CHECK_MAX_LEN_INFER``
    environment variable is set to any value (including the empty string).

    Args:
        model: vision model. The attributes read here are ``patch_size``,
            ``spatial_merge_size``, ``processor.max_pixels``,
            ``processor.preprocess``, ``data_type`` and ``forward``.
        max_batch_size: how many copies of the dummy image to run together.

    Raises:
        RuntimeError: when preprocessing or the forward pass raises
            ``RuntimeError``, ``torch.OutOfMemoryError`` or ``ValueError``.
            The original exception is logged before re-raising.
    """
    if "DISABLE_CHECK_MAX_LEN_INFER" in os.environ:
        return

    tile = model.patch_size * model.spatial_merge_size
    tile_count = max(model.processor.max_pixels // (tile * tile), 1)
    rows, cols = closest_factor_pair(tile_count)
    height, width = tile * rows, tile * cols

    try:
        probe = Image.new("RGB", (width, height), color=(128, 128, 128))
        pixel_values, grid_thw = model.processor.preprocess(probe)
        # NOTE(review): if preprocess returns a 2-D (num_patches, dim) tensor,
        # repeat(B, 1, 1) yields a 3-D (B, num_patches, dim) tensor instead of
        # concatenating along dim 0 -- confirm model.forward expects that layout.
        pixel_values = pixel_values.repeat(max_batch_size, 1, 1).to(
            "cuda", dtype=model.data_type, non_blocking=True
        )
        grid_thw = grid_thw.repeat(max_batch_size, 1).to("cuda", non_blocking=True)
        output = model.forward(pixel_values, grid_thw=grid_thw)
        del output, pixel_values, grid_thw
        torch.cuda.empty_cache()
        logger.info(f"vit check max_len {max_batch_size} infer ok")
    except (RuntimeError, torch.OutOfMemoryError, ValueError):
        logger.exception("Qwen VL check max len infer failed")
        message = "Vit check max len infer fail, you can try: 1.Set the --visual_infer_batch_size to a smaller value."
        logger.error(message)
        raise RuntimeError(message)
# Default snapping unit for smart_resize / resize_image: output sides are
# multiples of this value.
IMAGE_FACTOR = 28
# Default pixel-count bounds, expressed as a number of IMAGE_FACTOR x IMAGE_FACTOR tiles.
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR
MAX_PIXELS = 16384 * IMAGE_FACTOR * IMAGE_FACTOR
# smart_resize raises ValueError when the long/short side ratio exceeds this.
MAX_RATIO = 200
# NOTE(review): the constants below are not referenced in the code visible here;
# their names suggest video frame sampling -- confirm their consumers before changing.
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """Choose a resized ``(height, width)`` whose sides are multiples of ``factor``.

    Each side is first rounded to the nearest multiple of ``factor``.

    - If the resulting area exceeds ``max_pixels``, both original sides are
      divided by a common scale ``sqrt(height * width / max_pixels)``. They are
      then floored to multiples of ``factor``, never going below ``factor``.
    - If the area is below ``min_pixels``, both sides are multiplied by
      ``sqrt(min_pixels / (height * width))``. They are then ceiled to
      multiples of ``factor``.

    Returns:
        ``(new_height, new_width)``.

    Raises:
        ValueError: if ``max(height, width) / min(height, width)`` exceeds ``MAX_RATIO``.
        ZeroDivisionError: if either side is zero.
    """
    aspect = max(height, width) / min(height, width)
    if aspect > MAX_RATIO:
        raise ValueError(f"absolute aspect ratio must be smaller than MAX_RATIO, got {aspect}")

    new_h = factor * round(height / factor)
    new_w = factor * round(width / factor)
    area = new_h * new_w

    if area > max_pixels:
        # Shrink both sides by the same scale, flooring so the area stays within budget.
        scale = math.sqrt((height * width) / max_pixels)
        new_h = max(factor, factor * math.floor(height / scale / factor))
        new_w = max(factor, factor * math.floor(width / scale / factor))
    elif area < min_pixels:
        # Grow both sides by the same scale, ceiling so the area reaches the minimum.
        scale = math.sqrt(min_pixels / (height * width))
        new_h = factor * math.ceil(height * scale / factor)
        new_w = factor * math.ceil(width * scale / factor)

    return new_h, new_w
def resize_image(
    image_file: Image.Image, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> Image.Image:
    """Convert an image to RGB and resize it to the size chosen by ``smart_resize``.

    Args:
        image_file: source PIL image in any mode; it is converted to RGB first.
        factor: both output sides are multiples of this value.
        min_pixels: lower pixel-count target passed to ``smart_resize``.
        max_pixels: upper pixel-count target passed to ``smart_resize``.

    Returns:
        A single resized RGB ``PIL.Image.Image`` (not a tuple).

    Raises:
        ValueError: propagated from ``smart_resize`` when the aspect ratio
            exceeds ``MAX_RATIO``.
    """
    rgb_image = image_file.convert("RGB")
    width, height = rgb_image.size  # PIL reports size as (width, height)
    resized_height, resized_width = smart_resize(
        height,
        width,
        factor=factor,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
    )
    # PIL.Image.resize also takes (width, height).
    return rgb_image.resize((resized_width, resized_height))
class Qwen2VLImageProcessor(BaseImageProcessorFast):
    """Torch-based single-image preprocessor emitting flattened Qwen2-VL style patches.

    ``preprocess`` takes one PIL image and returns ``(pixel_values, image_grid_thw)``.
    Processing steps:

    - convert the image to RGB;
    - when ``do_resize`` is set, resize it with ``smart_resize`` so both sides
      are multiples of ``patch_size * merge_size``;
    - rescale and/or normalize it;
    - cut it into ``patch_size x patch_size`` patches.

    Each row of ``pixel_values`` is one patch of
    ``channels * temporal_patch_size * patch_size * patch_size`` values.
    ``image_grid_thw`` holds the ``[t, h, w]`` grid size (in patches) per image.
    CUDA is tried first, with a CPU fallback.
    """

    def __init__(
        self,
        size: Optional[dict] = None,
        do_resize: bool = True,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        disable_grouping: Optional[bool] = None,
        interpolation: Optional["F.InterpolationMode"] = InterpolationMode.BICUBIC,
        **kwargs,
    ) -> None:
        """Store preprocessing options.

        ``image_mean`` / ``image_std`` default to the OpenAI CLIP statistics.
        If ``size`` is a dict, its ``shortest_edge`` / ``longest_edge`` entries
        (when present) override ``min_pixels`` / ``max_pixels``.

        ``resample`` and ``do_convert_rgb`` are stored but not read by this
        class's own methods. Resizing uses ``interpolation``, and images are
        always converted to RGB.
        """
        super().__init__(**kwargs)
        self.size = size
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.do_convert_rgb = do_convert_rgb
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size
        self.disable_grouping = disable_grouping
        self.interpolation = interpolation
        self.data_format = ChannelDimension.FIRST
        # A HF-style size dict overrides the explicit pixel bounds.
        if isinstance(self.size, dict):
            shortest = self.size.get("shortest_edge", None)
            longest = self.size.get("longest_edge", None)
            if shortest is not None:
                self.min_pixels = shortest
            if longest is not None:
                self.max_pixels = longest
        # key: (do_norm, do_rescale, rescale_factor, mean, std, device)
        # value: (mean_tensor, std_tensor, do_rescale)
        self._fused_cache = {}

    @staticmethod
    def _stats_key(values) -> tuple:
        """Hashable snapshot of a mean/std value (scalar, sequence, array or tensor)."""
        return tuple(torch.as_tensor(values, dtype=torch.float64).reshape(-1).tolist())

    def _get_fused_mean_std(
        self,
        do_normalize: bool,
        image_mean: Union[float, list[float]],
        image_std: Union[float, list[float]],
        do_rescale: bool,
        rescale_factor: float,
        device: Optional["torch.device"],
    ) -> tuple[torch.Tensor, torch.Tensor, bool]:
        """Return ``(mean, std, do_rescale)`` tensors on ``device``, cached per configuration.

        When rescaling and normalization are both enabled, the rescale is folded
        into the statistics: ``(x * r - m) / s == (x - m / r) / (s / r)``. In that
        case mean and std are divided by ``rescale_factor`` and the returned
        ``do_rescale`` is False, so only one normalize op runs.
        """
        # The statistics are part of the key. Without them, a change to
        # image_mean/image_std would silently reuse stale cached tensors.
        key = (
            bool(do_normalize),
            bool(do_rescale),
            float(rescale_factor),
            self._stats_key(image_mean),
            self._stats_key(image_std),
            str(device),
        )
        if key not in self._fused_cache:
            if do_rescale and do_normalize:
                mean = torch.tensor(image_mean) * (1.0 / rescale_factor)
                std = torch.tensor(image_std) * (1.0 / rescale_factor)
                do_rescale = False
            else:
                mean = torch.tensor(image_mean)
                std = torch.tensor(image_std)
            self._fused_cache[key] = (mean.to(device=device), std.to(device=device), do_rescale)
        return self._fused_cache[key]

    def rescale_and_normalize(
        self,
        images: "torch.Tensor",
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Union[float, list[float]],
        image_std: Union[float, list[float]],
    ) -> "torch.Tensor":
        """
        Rescale and normalize images.

        When both steps are enabled they are applied as a single fused
        normalize on a float32 copy (see ``_get_fused_mean_std``).
        """
        image_mean, image_std, do_rescale = self._get_fused_mean_std(
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            device=images.device,
        )
        # if/elif as we use fused rescale and normalize if both are set to True
        if do_normalize:
            images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
        elif do_rescale:
            images = self.rescale(images, rescale_factor)
        return images

    @torch.inference_mode()
    def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
        """Preprocess one PIL image, trying CUDA first and falling back to CPU on any error.

        Returns:
            ``(pixel_values, image_grid_thw)`` on the device that succeeded.
        """
        try:
            return self._preprocess_bydevice(image, device="cuda")
        except Exception as e:
            logger.warning(f"Exception during image preprocessing on CUDA: {str(e)}")
            # Only synchronize when CUDA is usable. On a CPU-only host this call
            # would itself raise and defeat the CPU fallback below.
            if torch.cuda.is_available():
                torch.cuda.current_stream().synchronize()
            return self._preprocess_bydevice(image, device="cpu")

    def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the whole preprocessing pipeline for one PIL image on ``device``.

        Returns:
            ``pixel_values``: one row per patch, each of length
            ``channels * temporal_patch_size * patch_size * patch_size``.
            ``image_grid_thw``: ``[[grid_t, grid_h, grid_w]]`` as a tensor.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        image_arr = np.asarray(image, dtype=np.uint8)
        # Copy to ensure writable array (avoids PyTorch warning for read-only NumPy arrays).
        # HWC -> CHW before moving to the target device.
        image_data = (
            torch.from_numpy(image_arr.copy()).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True)
        )
        grouped_images, grouped_images_index = group_images_by_shape(
            [image_data], disable_grouping=self.disable_grouping
        )
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            height, width = stacked_images.shape[-2:]
            if self.do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=self.patch_size * self.merge_size,
                    min_pixels=self.min_pixels,
                    max_pixels=self.max_pixels,
                )
                stacked_images = self.resize(
                    image=stacked_images,
                    size=SizeDict(height=resized_height, width=resized_width),
                    interpolation=self.interpolation,
                )
            resized_images_grouped[shape] = stacked_images
        # Drop references to intermediates so their memory can be reclaimed early.
        grouped_images = None
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
        resized_images_grouped = None

        grouped_images, grouped_images_index = group_images_by_shape(
            resized_images, disable_grouping=self.disable_grouping
        )
        resized_images = None
        processed_images_grouped = {}
        processed_grids = {}
        for shape, stacked_images in grouped_images.items():
            # Honour the requested device. A hard-coded "cuda" here previously
            # made the CPU fallback in preprocess() touch the GPU anyway.
            stacked_images = stacked_images.to(device=device, non_blocking=True)
            resized_height, resized_width = stacked_images.shape[-2:]
            patches = self.rescale_and_normalize(
                stacked_images,
                self.do_rescale,
                self.rescale_factor,
                self.do_normalize,
                self.image_mean,
                self.image_std,
            )
            if patches.ndim == 4:
                # Still images have no time axis: add a singleton frame dimension.
                patches = patches.unsqueeze(1)
            # Pad the time axis up to a multiple of temporal_patch_size by repeating
            # the last frame. For a single frame this is temporal_patch_size - 1 copies.
            frame_pad = -patches.shape[1] % self.temporal_patch_size
            if frame_pad:
                repeats = patches[:, -1:].repeat(1, frame_pad, 1, 1, 1)
                patches = torch.cat([patches, repeats], dim=1)
            batch_size, grid_t, channel = patches.shape[:3]
            grid_t = grid_t // self.temporal_patch_size
            grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
            # Reorder to (B, t, h/m, w/m, m, m, C, temporal, p, p). The merge_size x merge_size
            # neighbouring patches become adjacent rows, and each patch vector is laid out
            # as (channel, temporal, patch_h, patch_w).
            patches = (
                patches.view(
                    batch_size,
                    grid_t,
                    self.temporal_patch_size,
                    channel,
                    grid_h // self.merge_size,
                    self.merge_size,
                    self.patch_size,
                    grid_w // self.merge_size,
                    self.merge_size,
                    self.patch_size,
                )
                .permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
                .contiguous()
            )
            flatten_patches = patches.view(
                batch_size,
                grid_t * grid_h * grid_w,
                channel * self.temporal_patch_size * self.patch_size * self.patch_size,
            )
            processed_images_grouped[shape] = flatten_patches
            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
        grouped_images = None
        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_grids = reorder_images(processed_grids, grouped_images_index)
        pixel_values = torch.cat(processed_images, dim=0)
        image_grid_thw = torch.as_tensor(processed_grids)
        return pixel_values, image_grid_thw