Skip to content

Commit 92c9e0a

Browse files
committed
benchmark 腳本
1 parent da7357a commit 92c9e0a

1 file changed

Lines changed: 202 additions & 0 deletions

File tree

benchmarks/run_benchmark.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import timeit
2+
import glob
3+
import os
4+
import argparse
5+
import math
6+
import numpy as np
7+
8+
# 嘗試 import,如果失敗則提示
9+
try:
10+
import fast_jpeg_decoder as fjd
11+
except ImportError:
12+
print("❌ Error: Could not import 'fast_jpeg_decoder'. Make sure the C++ module is compiled and in path.")
13+
exit(1)
14+
15+
try:
16+
from python_implementations import numpy_decoder
17+
except ImportError:
18+
print("❌ Error: Could not import 'numpy_decoder' from 'python_implementations'.")
19+
exit(1)
20+
21+
# --- Configuration ---
22+
TEST_DATA_DIR = 'tests/test_data'
23+
OUTPUT_DIR = 'output'
24+
NUMBER_OF_RUNS = 10
25+
26+
def calculate_psnr(img1, img2):
27+
"""
28+
計算兩張圖片的 PSNR (峰值訊噪比)
29+
img1: 測試圖片
30+
img2: 參考圖片 (Ground Truth, 通常是 PIL)
31+
"""
32+
if img1.shape != img2.shape:
33+
return 0.0
34+
35+
# 計算 MSE (均方誤差)
36+
mse = np.mean((img1.astype(float) - img2.astype(float)) ** 2)
37+
if mse == 0:
38+
return float('inf') # 完全相同
39+
40+
max_pixel = 255.0
41+
psnr = 20 * math.log10(max_pixel / math.sqrt(mse))
42+
return psnr
43+
44+
def run_cpp_decoder(image_bytes):
45+
"""Wrapper for the C++ decoder."""
46+
return fjd.load_bytes(image_bytes)
47+
48+
def run_numpy_decoder(image_bytes):
49+
"""Wrapper for the NumPy decoder."""
50+
return numpy_decoder.decode(image_bytes)
51+
52+
53+
def benchmark_single_image(image_path):
54+
"""Benchmark a single image (Includes Verify, Output, and PSNR comparison)."""
55+
filename = os.path.basename(image_path)
56+
print(f"\n{'='*60}")
57+
print(f"Processing: {filename}")
58+
print(f"{'='*60}")
59+
60+
try:
61+
with open(image_path, 'rb') as f:
62+
image_bytes = f.read()
63+
except FileNotFoundError:
64+
print(f"❌ Error: Test image not found at '{image_path}'")
65+
return
66+
67+
file_size = len(image_bytes) / 1024
68+
print(f"File size: {file_size:.1f} KB")
69+
70+
# 1. 確保輸出目錄存在
71+
os.makedirs(OUTPUT_DIR, exist_ok=True)
72+
73+
# 2. 準備 "原始圖像" (Ground Truth / Reference)
74+
# 我們使用 PIL 解碼的結果作為 "標準答案"
75+
img_gt = None
76+
try:
77+
from PIL import Image
78+
import io
79+
img_gt = np.array(Image.open(io.BytesIO(image_bytes)).convert('RGB'))
80+
except ImportError:
81+
print("⚠️ PIL not installed, cannot calculate PSNR or save images.")
82+
return
83+
84+
# print(f"\n{'─'*60}")
85+
# print("Decoding & Saving:")
86+
# print(f"{'─'*60}")
87+
88+
# --- 收集所有解碼器的結果 ---
89+
decoders_result = {}
90+
91+
# 1. C++ Decoder
92+
avg_cpp_time = float('inf')
93+
try:
94+
img_cpp = run_cpp_decoder(image_bytes)
95+
if img_cpp.size > 0:
96+
decoders_result['C++ '] = img_cpp
97+
# 存檔
98+
out_name = os.path.join(OUTPUT_DIR, f"cpp_{filename}.png")
99+
Image.fromarray(img_cpp).save(out_name)
100+
# 測速
101+
cpp_time = timeit.timeit(lambda: run_cpp_decoder(image_bytes), number=NUMBER_OF_RUNS)
102+
avg_cpp_time = (cpp_time / NUMBER_OF_RUNS) * 1000
103+
else:
104+
print("❌ C++ Decoder returned empty image")
105+
except Exception as e:
106+
print(f"❌ C++ Decoder Error: {e}")
107+
108+
# 2. NumPy Decoder
109+
avg_numpy_time = float('inf')
110+
try:
111+
img_numpy = run_numpy_decoder(image_bytes)
112+
# Handle list return type if necessary
113+
if not isinstance(img_numpy, np.ndarray):
114+
img_numpy = np.array(img_numpy, dtype=np.uint8)
115+
116+
if img_numpy.size > 0:
117+
decoders_result['NumPy'] = img_numpy
118+
# 存檔
119+
out_name = os.path.join(OUTPUT_DIR, f"numpy_{filename}.png")
120+
Image.fromarray(img_numpy).save(out_name)
121+
# 測速
122+
numpy_time = timeit.timeit(lambda: run_numpy_decoder(image_bytes), number=NUMBER_OF_RUNS)
123+
avg_numpy_time = (numpy_time / NUMBER_OF_RUNS) * 1000
124+
else:
125+
print("❌ NumPy Decoder returned empty image")
126+
except Exception as e:
127+
print(f"❌ NumPy Decoder Error: {e}")
128+
129+
# 3. PIL Decoder (本身也加入比較列表,確認基準)
130+
decoders_result['PIL '] = img_gt
131+
# PIL 也可以存一份 png 當作對照組
132+
Image.fromarray(img_gt).save(os.path.join(OUTPUT_DIR, f"pil_{filename}.png"))
133+
134+
135+
# --- 統一計算 PSNR (全部 vs 原始圖像) ---
136+
print(f"\n{'─'*60}")
137+
print("Quality Metrics (vs Original/PIL):")
138+
print(f"{'─'*60}")
139+
140+
for name, img_test in decoders_result.items():
141+
try:
142+
# 形狀檢查與修正 (針對 flatten array)
143+
if img_test.shape != img_gt.shape:
144+
if img_test.size == img_gt.size:
145+
img_test = img_test.reshape(img_gt.shape)
146+
else:
147+
print(f"{name}: Shape mismatch {img_test.shape} vs {img_gt.shape}")
148+
continue
149+
150+
# 計算 PSNR
151+
psnr = calculate_psnr(img_test, img_gt)
152+
153+
# 計算平均像素差異
154+
mean_diff = np.abs(img_test.astype(float) - img_gt.astype(float)).mean()
155+
156+
# 格式化輸出
157+
status = "✅" if psnr > 30 or psnr == float('inf') else "⚠️ "
158+
psnr_str = "Infinity" if psnr == float('inf') else f"{psnr:6.2f} dB"
159+
160+
print(f"{name} Decoder:")
161+
print(f"PSNR: {status} {psnr_str}")
162+
#print(f" Mean Diff: {mean_diff:.2f}")
163+
print("-" * 20)
164+
165+
except Exception as e:
166+
print(f"{name} vs Reference: Error ({e})")
167+
168+
# --- 效能總結 ---
169+
print(f"\n{'─'*60}")
170+
print("Speed Benchmark Results:")
171+
print(f"{'─'*60}")
172+
173+
if avg_cpp_time != float('inf'):
174+
print(f"🚀 C++ Decoder: {avg_cpp_time:.2f} ms")
175+
if avg_numpy_time != float('inf'):
176+
print(f"🐍 NumPy Decoder: {avg_numpy_time:.2f} ms")
177+
178+
if avg_cpp_time != float('inf') and avg_numpy_time != float('inf'):
179+
speedup = avg_numpy_time / avg_cpp_time
180+
print(f"\n⚡ Speedup: C++ is {speedup:.2f}x faster than NumPy")
181+
182+
def main():
183+
# Find images
184+
image_patterns = [
185+
os.path.join(TEST_DATA_DIR, '*.jpg'),
186+
os.path.join(TEST_DATA_DIR, '*.jpeg'),
187+
]
188+
test_images = []
189+
for pattern in image_patterns:
190+
test_images.extend(glob.glob(pattern))
191+
192+
test_images.sort()
193+
194+
if not test_images:
195+
print(f"❌ No images found in {TEST_DATA_DIR}")
196+
return
197+
198+
for image_path in test_images:
199+
benchmark_single_image(image_path)
200+
201+
if __name__ == "__main__":
202+
main()

0 commit comments

Comments
 (0)