Skip to content

Commit 2e4d795

Browse files
committed
Add the log spectral distance metric to the conformance test script.
1 parent 4f53a11 commit 2e4d795

4 files changed

Lines changed: 187 additions & 45 deletions

File tree

tests/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ To set up a python environment using pip
119119
```
120120
python3 -m venv venv
121121
source venv/bin/activate
122-
pip install scipy protobuf tqdm numpy
122+
pip install scipy protobuf tqdm numpy librosa
123123
```
124124

125125
Run the test suite.

tests/dsp_utils.py

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
import wave
66
import numpy as np
77
import scipy.io.wavfile as wavfile
8+
import librosa
89

910

10-
def calc_average_channel_psnr_pcm(
11+
def calc_per_channel_psnr_pcm(
1112
ref_signal: np.ndarray, signal: np.ndarray, sampwidth_bytes: int
1213
):
1314
"""Calculates the PSNR between two signals.
@@ -19,8 +20,7 @@ def calc_average_channel_psnr_pcm(
1920
24-bit).
2021
2122
Returns:
22-
The average PSNR in dB across all channels, or -1 if all channels are
23-
identical.
23+
The per channel PSNR in dB.
2424
"""
2525
assert (
2626
sampwidth_bytes > 1
@@ -40,25 +40,78 @@ def calc_average_channel_psnr_pcm(
4040
for i in range(num_channels):
4141
mse_value = mse[i] if num_channels > 1 else mse
4242
if mse_value == 0:
43+
psnr_list.append(np.inf)
4344
logging.debug("ch#%d PSNR: inf", i)
4445
else:
4546
psnr_value = 10 * math.log10(max_value**2 / mse_value)
4647
psnr_list.append(psnr_value)
4748
logging.debug("ch#%d PSNR: %f dB", i, psnr_value)
4849

49-
return -1 if len(psnr_list) == 0 else sum(psnr_list) / len(psnr_list)
50+
return psnr_list
5051

5152

52-
def calc_average_channel_psnr_wav(ref_filepath: str, target_filepath: str):
53-
"""Calculates the PSNR between two WAV files.
53+
def calc_per_channel_lsd_pcm(ref_signal: np.ndarray,
54+
signal: np.ndarray,
55+
sampling_rate: int):
56+
"""Calculates the log spectral distance using Mel bins between two signals.
57+
58+
Args:
59+
ref_signal: The reference signal as a numpy array.
60+
signal: The signal to compare as a numpy array.
61+
sampling rate: The sampling rate of the signals in Hz.
62+
63+
Returns:
64+
The per channel log spectral distance in dB.
65+
"""
66+
eps = 1e-4
67+
68+
# Convert to float
69+
ref_signal = ref_signal / np.iinfo(ref_signal.dtype).max
70+
signal = signal / np.iinfo(signal.dtype).max
71+
72+
lsd_list = list()
73+
74+
# To support mono channel
75+
num_channels = 1 if ref_signal.shape[1:] == () else ref_signal.shape[1]
76+
for i in range(num_channels):
77+
ref_channel = ref_signal[:, i] if num_channels > 1 else ref_signal
78+
signal_channel = signal[:, i] if num_channels > 1 else signal
79+
80+
lsd_frames = list()
81+
82+
# Compute mel spectrogram
83+
mel_ref = librosa.feature.melspectrogram(y=ref_channel, sr=sampling_rate)
84+
mel_signal = librosa.feature.melspectrogram(y=signal_channel,
85+
sr=sampling_rate)
86+
87+
log_mel_ref = 10 * np.log10(mel_ref + eps)
88+
log_mel_signal = 10 * np.log10(mel_signal + eps)
89+
90+
diff_squared = (log_mel_ref - log_mel_signal) ** 2
91+
92+
# Average across mel bins, which is the 0th dimension
93+
lsd_per_frame = np.sqrt(np.mean(diff_squared, axis=0))
94+
95+
# shape: (1, num_frames) -> (num_frames,)
96+
lsd_per_frame = np.squeeze(lsd_per_frame)
97+
98+
lsd_value = np.mean(lsd_per_frame)
99+
lsd_list.append(lsd_value)
100+
logging.debug('ch#d LSD: %f dB', i, lsd_value)
101+
102+
return lsd_list
103+
104+
105+
def calc_score_wav(ref_filepath: str, target_filepath: str, metric: str):
106+
"""Calculates the score between two WAV files.
54107
55108
Args:
56109
ref_filepath: Path to the reference WAV file.
57110
target_filepath: Path to the target WAV file to compare.
111+
metric: one of 'PSNR' or 'SNR'.
58112
59113
Returns:
60-
The average PSNR in dB across all channels. Or -1 if all channels are
61-
identical.
114+
The score in dB, averaged over all channels.
62115
63116
Raises:
64117
Exception: If the wav files have different samplerate, channels, bit-depth
@@ -99,6 +152,14 @@ def calc_average_channel_psnr_wav(ref_filepath: str, target_filepath: str):
99152
_, ref_data = wavfile.read(ref_filepath)
100153
_, target_data = wavfile.read(target_filepath)
101154

102-
return calc_average_channel_psnr_pcm(
103-
ref_data, target_data, ref_wav.getsampwidth()
104-
)
155+
if metric == 'PSNR':
156+
scores_list = calc_per_channel_psnr_pcm(
157+
ref_data, target_data, ref_wav.getsampwidth()
158+
)
159+
elif metric == 'LSD':
160+
scores_list = calc_per_channel_lsd_pcm(ref_data, target_data,
161+
ref_wav.getframerate())
162+
else:
163+
return None
164+
165+
return np.mean(scores_list)

tests/run_decode_and_psnr_test.py

Lines changed: 107 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,53 @@
3939
layout_index=1,
4040
reason='Extension layouts cannot be decoded.',
4141
),
42+
TestExclusions(
43+
file_name_prefix='test_000119',
44+
mix_presentation_id=42,
45+
layout_index=0,
46+
reason='Signal too short',
47+
),
48+
TestExclusions(
49+
file_name_prefix='test_000120',
50+
mix_presentation_id=42,
51+
layout_index=0,
52+
reason='Signal too short',
53+
),
54+
TestExclusions(
55+
file_name_prefix='test_000122',
56+
mix_presentation_id=42,
57+
layout_index=0,
58+
reason='Signal too short',
59+
),
60+
TestExclusions(
61+
file_name_prefix='test_000129',
62+
mix_presentation_id=42,
63+
layout_index=0,
64+
reason='Signal too short',
65+
),
66+
TestExclusions(
67+
file_name_prefix='test_000130',
68+
mix_presentation_id=42,
69+
layout_index=0,
70+
reason='Signal too short',
71+
),
4272
]
4373

4474
# Opus/AAC are lossy codecs, we allow a more lenient threshold for them.
45-
LOSSY_PSNR_THRESHOLD = 30
46-
LOSSLESS_PSNR_THRESHOLD = 80
47-
75+
# Some test files use stress signals, e.g. sawtooth. Allow a more lenient
76+
# threshold for them.
77+
THRESHOLDS = {
78+
# PSNR: larger is better
79+
'PSNR': {
80+
'lossless': {'base': 80},
81+
'lossy': {'base': 30}
82+
},
83+
# LSD: smaller is better
84+
'LSD': {
85+
'lossless': {'base': 1.0},
86+
'lossy': {'base': 2.5, 'stress': 10.0}
87+
}
88+
}
4889

4990
class ResultStatus(enum.Enum):
5091
SUCCESS = 1
@@ -60,16 +101,16 @@ class Result:
60101
mix_presentation_id: int
61102
sub_mix_index: int
62103
layout_index: int
63-
psnr_score: Optional[float] = None
104+
score: Optional[float] = None
64105
reason: Optional[str] = None
65106
iamfdec_command: Optional[str] = None
66107

67108
def log(self, status: ResultStatus):
68109
logging.debug(
69110
'%s: %s >= %s for %s',
70111
status.name,
71-
self.psnr_score,
72-
self.is_lossy,
112+
self.score,
113+
'lossy codec' if self.is_lossy else 'lossless codec',
73114
self.test_prefix,
74115
)
75116
logging.debug('')
@@ -81,11 +122,14 @@ class TestSummary:
81122
default_factory=lambda: defaultdict(list)
82123
)
83124

84-
def print_test_summary(self, csv_summary_file=None):
125+
def print_test_summary(self,
126+
csv_summary_file: str,
127+
metric: str):
85128
"""Prints test summary to console and optionally a CSV file.
86129
87130
Args:
88131
csv_summary_file: Path to CSV file to log test results.
132+
metric: Name of the metric used.
89133
"""
90134
logging.info('\n-----------------SUMMARY-----------------')
91135
for status in ResultStatus:
@@ -102,7 +146,7 @@ def print_test_summary(self, csv_summary_file=None):
102146
'Submix Index',
103147
'Layout Index',
104148
'Status',
105-
'PSNR',
149+
metric,
106150
'Is Lossy',
107151
'Reason',
108152
'Command',
@@ -115,7 +159,7 @@ def print_test_summary(self, csv_summary_file=None):
115159
item.sub_mix_index,
116160
item.layout_index,
117161
status.name,
118-
item.psnr_score if item.psnr_score is not None else '',
162+
item.score if item.score is not None else '',
119163
'lossy' if item.is_lossy else 'lossless',
120164
item.reason if item.reason is not None else '',
121165
item.iamfdec_command
@@ -148,15 +192,35 @@ def run_decoder(args, metadata):
148192
return True, cmd_str
149193

150194

151-
def run_psnr_test(args, metadata):
152-
"""Gets PSNR score, returns None if calculation fails.
195+
def get_threshold(metric: str, is_lossy: bool, is_stress_signal: bool):
196+
"""Gets threshold value.
197+
198+
Args:
199+
metric: Metric name.
200+
is_lossy: Whether the codec used is lossy or lossless.
201+
is_stress_signal: Whether the signal used is a stress signal type.
202+
203+
Returns:
204+
The threshold value.
205+
"""
206+
codec_type = 'lossy' if is_lossy else 'lossless'
207+
threshold_options = THRESHOLDS.get(metric, {}).get(codec_type,{})
208+
209+
if is_stress_signal and 'stress' in threshold_options:
210+
return threshold_options['stress']
211+
212+
return threshold_options['base']
213+
214+
215+
def compute_metrics(args, metadata):
216+
"""Gets output score, returns None if calculation fails.
153217
154218
Args:
155219
args: Command line arguments.
156220
metadata: Metadata for the test vector.
157221
158222
Returns:
159-
A tuple of (ResultStatus, reason, psnr_score).
223+
A tuple of (ResultStatus, reason, score).
160224
"""
161225
ref_file = os.path.join(
162226
args.test_file_directory, metadata.golden_wav_file_name
@@ -168,25 +232,29 @@ def run_psnr_test(args, metadata):
168232
assert os.path.exists(test_file), f'Test file {test_file} does not exist.'
169233
logging.debug('ref_file: %s', ref_file)
170234
logging.debug('test_file: %s', test_file)
235+
171236
try:
172-
raw_psnr_score = dsp_utils.calc_average_channel_psnr_wav(
173-
ref_file, test_file
174-
)
237+
score = dsp_utils.calc_score_wav(ref_file, test_file, args.metric)
175238
except ValueError as e:
176-
print(f'Failed to calculate PSNR: {e}')
177-
return ResultStatus.CRASH, 'PSNR calculation failed', None
239+
print(f'Failed to calculate {args.metric}: {e}')
240+
return ResultStatus.CRASH, f'{args.metric} calculation failed', None
178241

179-
psnr_score = 100 if raw_psnr_score == -1 else raw_psnr_score
180-
# Check if this PSNR is a pass or a fail, it depends on whether the test
242+
# Check if this score is a pass or a fail, it depends on whether the test
181243
# represents a lossy or lossless codec.
182-
logging.debug('psnr score: %s', psnr_score)
183-
threshold = (
184-
LOSSY_PSNR_THRESHOLD if metadata.is_lossy else LOSSLESS_PSNR_THRESHOLD
185-
)
186-
if psnr_score >= threshold:
187-
return ResultStatus.SUCCESS, None, psnr_score
244+
logging.debug('%s score: %s', args.metric, score)
245+
threshold = get_threshold(args.metric, metadata.is_lossy,
246+
metadata.is_stress_signal)
247+
248+
if args.metric == 'LSD':
249+
if score <= threshold:
250+
return ResultStatus.SUCCESS, None, score
251+
else:
252+
return ResultStatus.FAILURE, 'Score above threshold.', score
188253
else:
189-
return ResultStatus.FAILURE, 'PSNR score below threshold.', psnr_score
254+
if score >= threshold:
255+
return ResultStatus.SUCCESS, None, score
256+
else:
257+
return ResultStatus.FAILURE, 'Score below threshold.', score
190258

191259

192260
def _is_excluded(metadata, exclusions, args):
@@ -241,17 +309,17 @@ def run_tests(args, text_proto_files) -> TestSummary:
241309
cleanup_after_decode = (
242310
generated_file_is_new and not args.preserve_output_files
243311
)
244-
status, reason, psnr_score, cmd_str = None, None, None, None
312+
status, reason, score, cmd_str = None, None, None, None
245313
if skip_reason := _is_excluded(metadata, exclusions, args):
246314
# Test was intentionally excluded.
247315
reason = skip_reason
248316
status = ResultStatus.SKIPPED
249317
else:
250318
decoder_success, cmd_str = run_decoder(args, metadata)
251319
if decoder_success:
252-
# Run the PSNR test, this could crash, or be better than or worse than
253-
# the threshold PSNR.
254-
status, reason, psnr_score = run_psnr_test(args, metadata)
320+
# Compute the score, this could crash, or be better than or worse than
321+
# the threshold score.
322+
status, reason, score = compute_metrics(args, metadata)
255323
else:
256324
# Decoder crashed.
257325
reason = 'iamfdec crash'
@@ -264,7 +332,7 @@ def run_tests(args, text_proto_files) -> TestSummary:
264332
metadata.mix_presentation_id,
265333
metadata.sub_mix_index,
266334
metadata.layout_index,
267-
psnr_score,
335+
score,
268336
reason,
269337
iamfdec_command=cmd_str,
270338
)
@@ -308,6 +376,13 @@ def main():
308376
action='store_true',
309377
help='Enable testing binaural layouts.',
310378
)
379+
parser.add_argument(
380+
'-m',
381+
'--metric',
382+
default='PSNR',
383+
choices=['PSNR', 'LSD'],
384+
help='The metric to use: PSNR or LSD.',
385+
)
311386
args = parser.parse_args()
312387

313388
logging.basicConfig(
@@ -329,7 +404,7 @@ def main():
329404
text_proto_files.sort()
330405

331406
test_summary = run_tests(args, text_proto_files)
332-
test_summary.print_test_summary(args.csv_summary_file)
407+
test_summary.print_test_summary(args.csv_summary_file, args.metric)
333408

334409

335410
if __name__ == '__main__':

0 commit comments

Comments
 (0)