1+ import timeit
2+ import glob
3+ import os
4+ import argparse
5+ import math
6+ import numpy as np
7+
8+ # 嘗試 import,如果失敗則提示
9+ try :
10+ import fast_jpeg_decoder as fjd
11+ except ImportError :
12+ print ("❌ Error: Could not import 'fast_jpeg_decoder'. Make sure the C++ module is compiled and in path." )
13+ exit (1 )
14+
15+ try :
16+ from python_implementations import numpy_decoder
17+ except ImportError :
18+ print ("❌ Error: Could not import 'numpy_decoder' from 'python_implementations'." )
19+ exit (1 )
20+
21+ # --- Configuration ---
22+ TEST_DATA_DIR = 'tests/test_data'
23+ OUTPUT_DIR = 'output'
24+ NUMBER_OF_RUNS = 10
25+
26+ def calculate_psnr (img1 , img2 ):
27+ """
28+ 計算兩張圖片的 PSNR (峰值訊噪比)
29+ img1: 測試圖片
30+ img2: 參考圖片 (Ground Truth, 通常是 PIL)
31+ """
32+ if img1 .shape != img2 .shape :
33+ return 0.0
34+
35+ # 計算 MSE (均方誤差)
36+ mse = np .mean ((img1 .astype (float ) - img2 .astype (float )) ** 2 )
37+ if mse == 0 :
38+ return float ('inf' ) # 完全相同
39+
40+ max_pixel = 255.0
41+ psnr = 20 * math .log10 (max_pixel / math .sqrt (mse ))
42+ return psnr
43+
44+ def run_cpp_decoder (image_bytes ):
45+ """Wrapper for the C++ decoder."""
46+ return fjd .load_bytes (image_bytes )
47+
48+ def run_numpy_decoder (image_bytes ):
49+ """Wrapper for the NumPy decoder."""
50+ return numpy_decoder .decode (image_bytes )
51+
52+
53+ def benchmark_single_image (image_path ):
54+ """Benchmark a single image (Includes Verify, Output, and PSNR comparison)."""
55+ filename = os .path .basename (image_path )
56+ print (f"\n { '=' * 60 } " )
57+ print (f"Processing: { filename } " )
58+ print (f"{ '=' * 60 } " )
59+
60+ try :
61+ with open (image_path , 'rb' ) as f :
62+ image_bytes = f .read ()
63+ except FileNotFoundError :
64+ print (f"❌ Error: Test image not found at '{ image_path } '" )
65+ return
66+
67+ file_size = len (image_bytes ) / 1024
68+ print (f"File size: { file_size :.1f} KB" )
69+
70+ # 1. 確保輸出目錄存在
71+ os .makedirs (OUTPUT_DIR , exist_ok = True )
72+
73+ # 2. 準備 "原始圖像" (Ground Truth / Reference)
74+ # 我們使用 PIL 解碼的結果作為 "標準答案"
75+ img_gt = None
76+ try :
77+ from PIL import Image
78+ import io
79+ img_gt = np .array (Image .open (io .BytesIO (image_bytes )).convert ('RGB' ))
80+ except ImportError :
81+ print ("⚠️ PIL not installed, cannot calculate PSNR or save images." )
82+ return
83+
84+ # print(f"\n{'─'*60}")
85+ # print("Decoding & Saving:")
86+ # print(f"{'─'*60}")
87+
88+ # --- 收集所有解碼器的結果 ---
89+ decoders_result = {}
90+
91+ # 1. C++ Decoder
92+ avg_cpp_time = float ('inf' )
93+ try :
94+ img_cpp = run_cpp_decoder (image_bytes )
95+ if img_cpp .size > 0 :
96+ decoders_result ['C++ ' ] = img_cpp
97+ # 存檔
98+ out_name = os .path .join (OUTPUT_DIR , f"cpp_{ filename } .png" )
99+ Image .fromarray (img_cpp ).save (out_name )
100+ # 測速
101+ cpp_time = timeit .timeit (lambda : run_cpp_decoder (image_bytes ), number = NUMBER_OF_RUNS )
102+ avg_cpp_time = (cpp_time / NUMBER_OF_RUNS ) * 1000
103+ else :
104+ print ("❌ C++ Decoder returned empty image" )
105+ except Exception as e :
106+ print (f"❌ C++ Decoder Error: { e } " )
107+
108+ # 2. NumPy Decoder
109+ avg_numpy_time = float ('inf' )
110+ try :
111+ img_numpy = run_numpy_decoder (image_bytes )
112+ # Handle list return type if necessary
113+ if not isinstance (img_numpy , np .ndarray ):
114+ img_numpy = np .array (img_numpy , dtype = np .uint8 )
115+
116+ if img_numpy .size > 0 :
117+ decoders_result ['NumPy' ] = img_numpy
118+ # 存檔
119+ out_name = os .path .join (OUTPUT_DIR , f"numpy_{ filename } .png" )
120+ Image .fromarray (img_numpy ).save (out_name )
121+ # 測速
122+ numpy_time = timeit .timeit (lambda : run_numpy_decoder (image_bytes ), number = NUMBER_OF_RUNS )
123+ avg_numpy_time = (numpy_time / NUMBER_OF_RUNS ) * 1000
124+ else :
125+ print ("❌ NumPy Decoder returned empty image" )
126+ except Exception as e :
127+ print (f"❌ NumPy Decoder Error: { e } " )
128+
129+ # 3. PIL Decoder (本身也加入比較列表,確認基準)
130+ decoders_result ['PIL ' ] = img_gt
131+ # PIL 也可以存一份 png 當作對照組
132+ Image .fromarray (img_gt ).save (os .path .join (OUTPUT_DIR , f"pil_{ filename } .png" ))
133+
134+
135+ # --- 統一計算 PSNR (全部 vs 原始圖像) ---
136+ print (f"\n { '─' * 60 } " )
137+ print ("Quality Metrics (vs Original/PIL):" )
138+ print (f"{ '─' * 60 } " )
139+
140+ for name , img_test in decoders_result .items ():
141+ try :
142+ # 形狀檢查與修正 (針對 flatten array)
143+ if img_test .shape != img_gt .shape :
144+ if img_test .size == img_gt .size :
145+ img_test = img_test .reshape (img_gt .shape )
146+ else :
147+ print (f"{ name } : Shape mismatch { img_test .shape } vs { img_gt .shape } " )
148+ continue
149+
150+ # 計算 PSNR
151+ psnr = calculate_psnr (img_test , img_gt )
152+
153+ # 計算平均像素差異
154+ mean_diff = np .abs (img_test .astype (float ) - img_gt .astype (float )).mean ()
155+
156+ # 格式化輸出
157+ status = "✅" if psnr > 30 or psnr == float ('inf' ) else "⚠️ "
158+ psnr_str = "Infinity" if psnr == float ('inf' ) else f"{ psnr :6.2f} dB"
159+
160+ print (f"{ name } Decoder:" )
161+ print (f"PSNR: { status } { psnr_str } " )
162+ #print(f" Mean Diff: {mean_diff:.2f}")
163+ print ("-" * 20 )
164+
165+ except Exception as e :
166+ print (f"{ name } vs Reference: Error ({ e } )" )
167+
168+ # --- 效能總結 ---
169+ print (f"\n { '─' * 60 } " )
170+ print ("Speed Benchmark Results:" )
171+ print (f"{ '─' * 60 } " )
172+
173+ if avg_cpp_time != float ('inf' ):
174+ print (f"🚀 C++ Decoder: { avg_cpp_time :.2f} ms" )
175+ if avg_numpy_time != float ('inf' ):
176+ print (f"🐍 NumPy Decoder: { avg_numpy_time :.2f} ms" )
177+
178+ if avg_cpp_time != float ('inf' ) and avg_numpy_time != float ('inf' ):
179+ speedup = avg_numpy_time / avg_cpp_time
180+ print (f"\n ⚡ Speedup: C++ is { speedup :.2f} x faster than NumPy" )
181+
182+ def main ():
183+ # Find images
184+ image_patterns = [
185+ os .path .join (TEST_DATA_DIR , '*.jpg' ),
186+ os .path .join (TEST_DATA_DIR , '*.jpeg' ),
187+ ]
188+ test_images = []
189+ for pattern in image_patterns :
190+ test_images .extend (glob .glob (pattern ))
191+
192+ test_images .sort ()
193+
194+ if not test_images :
195+ print (f"❌ No images found in { TEST_DATA_DIR } " )
196+ return
197+
198+ for image_path in test_images :
199+ benchmark_single_image (image_path )
200+
201+ if __name__ == "__main__" :
202+ main ()
0 commit comments