Commit 8a5abf7

Merge pull request #148 from jwcullen/decode_and_psnr
Add a script to decode and compare PSNR of the entire test suite.
2 parents e217d15 + 3ce17ae commit 8a5abf7

18 files changed

Lines changed: 917 additions & 140 deletions

tests/README.md

Lines changed: 71 additions & 14 deletions
@@ -34,7 +34,7 @@ Theses file describe metadata about the test vector to encode an
3434
- `base_test`: The recommended textproto to diff against.
3535
- Other fields refer to the OBUs and data within the test vector.
3636

37-
# Input WAV files
37+
## Input WAV files
3838

3939
Test vectors may have multiple substreams with several input .wav files. These
4040
.wav files may be shared with other test vectors. The .textproto file has a
@@ -68,7 +68,7 @@ Title | Summary
6868
`Transport_TOA_5s.wav` | Short clip of vehicles driving by using third-order ambisonics. | 16 | 48kHz | pcm_s16le | 5s
6969
`Transport_9.1.6_5s.wav` | Short clip of vehicles driving by using 9.1.6. | 16 | 48kHz | pcm_s16le | 5s
7070

71-
# Output WAV files
71+
## Output WAV files
7272

7373
Output wav files are based on the
7474
[layout](https://aomediacodec.github.io/iamf/#syntax-layout) in the mix
@@ -93,25 +93,82 @@ Sound System 12 | IAMF | C
9393
Sound System 13 | IAMF | FL, FR, FC, LFE, BL, BR, FLc, FRc, SiL, SiR, TpFL, TpFR, TpBL, TpBR, TpSiL, TpSiR
9494
Binaural Layout | IAMF | L2, R2
9595

96-
# Verification
96+
## Decode and Verification
9797

98-
For test cases using Opus or AAC codecs, the average PSNR value must exceed 30, and for the other codecs, an average PSNR value exceeding 80 is considered PASS.
99-
You can use `psnr_calc.py` file to calculate PSNR between reference and generated output.
98+
For test cases with lossy codecs (Opus or AAC), the average PSNR value must
99+
exceed 30; otherwise, the average PSNR must exceed 80.
100100

101-
- How to use `psnr_calc.py` script:
102-
```
103-
python psnr_calc.py
104-
--dir <directory path containing the target file and reference file>
105-
--target <target wav file name>
106-
--ref <reference wav file name>
107-
```
101+
`run_decode_and_psnr_tests` will run the decoder for all reference test cases
102+
and compare the PSNR between all outputs.
103+
104+
Prerequisites:
105+
106+
- The path to a built `iamfdec`, usually
107+
`libiamf/code/test/tools/iamfdec/iamfdec`
108+
- `protoc`, and compiled `libiamf/code/proto` files.
109+
- A Python environment with `scipy`, `protobuf`, `tqdm`, `numpy`, and `librosa`.
110+
111+
Note that example commands below assume a working directory of `libiamf/tests`.
112+
113+
To compile the proto files run
114+
115+
`protoc -I=proto/ --python_out=proto/ proto/*.proto`
116+
117+
To set up a python environment using pip
118+
119+
```
120+
python3 -m venv venv
121+
source venv/bin/activate
122+
pip install scipy protobuf tqdm numpy librosa
123+
```
124+
125+
Run the test suite.
126+
127+
Arguments:
108128

109-
- Calculate PSNR values of multiple wav files
129+
`iamfdec_path`, full path to the built `iamfdec` tool. `test_file_directory`,
130+
full path to folder containing `.textproto` and reference `.wav` files.
131+
`working_directory`, full path to write audio files produced by `iamfdec`.
132+
`csv_summary`, optionally included, full path and filename to write a summary of
133+
test results.
134+
135+
```
136+
python3 run_decode_and_psnr_test.py --iamfdec_path /your/full/path/to/libiamf/code/test/tools/iamfdec/iamfdec --test_file_directory /your/full/path/to/libiamf/tests/ --working_directory /your/path/for/scratch/wav/files --csv_summary /your/path/to/write/summary.csv
137+
```
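The long command above can also be assembled programmatically, which makes it easier to swap paths per machine. The following is a minimal sketch, not part of the commit: the flag names and script name are taken from the README text, while the helper `build_command` and every path shown are hypothetical placeholders.

```python
import shlex


def build_command(iamfdec_path, test_file_directory, working_directory,
                  csv_summary=None):
  """Builds the argv list for the test script; nothing is executed here."""
  cmd = [
      "python3", "run_decode_and_psnr_test.py",
      "--iamfdec_path", iamfdec_path,
      "--test_file_directory", test_file_directory,
      "--working_directory", working_directory,
  ]
  # csv_summary is optional, so only add the flag when a path is given.
  if csv_summary is not None:
    cmd += ["--csv_summary", csv_summary]
  return cmd


# Placeholder paths, for illustration only.
cmd = build_command(
    iamfdec_path="/path/to/iamfdec",
    test_file_directory="/path/to/libiamf/tests",
    working_directory="/tmp/iamf_scratch",
    csv_summary="/tmp/iamf_scratch/summary.csv",
)
print(shlex.join(cmd))
```

The resulting list could be passed to `subprocess.run(cmd, check=True)` from the `libiamf/tests` directory, assuming the prerequisites listed earlier are in place.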
138+
139+
For a simple configuration, this example will dump all files to the current
140+
working directory.
141+
142+
`python3 run_decode_and_psnr_test.py --iamfdec_path
143+
../code/test/tools/iamfdec/iamfdec --test_file_directory $PWD --csv_summary
144+
$PWD/summary.csv -w $PWD`
145+
146+
Extra arguments:
147+
148+
`regex_filter`, optionally included, regex to filter output files. For example
149+
`--regex_filter="000100"` will run a single file, or
150+
`--regex_filter="0001\d{2}"` will process files in the range [test_000100,
151+
test_000199]. `verbose_test_summary`, turns on verbose logging.
152+
`preserve_output_files`, set to keep the generated output `.wav` files,
153+
otherwise they are deleted.
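To see which test names the two `regex_filter` examples would select, here is a small illustrative sketch. It assumes the filter is applied with `re.search` against names of the form `test_NNNNNN`; the actual matching logic in the script may differ, and `filter_names` is a hypothetical helper.

```python
import re

# Hypothetical test names, assumed to follow the test_NNNNNN pattern.
names = [f"test_{n:06d}" for n in (99, 100, 150, 199, 200)]


def filter_names(names, pattern):
  """Keeps the names in which the pattern is found anywhere."""
  regex = re.compile(pattern)
  return [name for name in names if regex.search(name)]


single = filter_names(names, "000100")
ranged = filter_names(names, r"0001\d{2}")
print(single)  # only test_000100
print(ranged)  # test_000100, test_000150 and test_000199
```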
154+
155+
## Verification Only
156+
157+
For test cases using Opus or AAC codecs, the average PSNR value must exceed 30,
158+
and for the other codecs, an average PSNR value exceeding 80 is considered PASS.
159+
You can use the `psnr_calc.py` script to calculate PSNR between reference and
160+
generated output.
161+
162+
- How to use `psnr_calc.py` script: `python psnr_calc.py --dir <directory path
163+
containing the target file and reference file> --target <target wav file
164+
name> --ref <reference wav file name> --verbose`
165+
166+
- Calculate PSNR values of multiple wav files
110167

111168
Multiple files can be given by separating their names with `::`
112169

113170
```
114171
Example:
115172
116173
python psnr_calc.py --dir . --target target1.wav::target2.wav --ref ref1.wav::ref2.wav
117-
```
174+
```
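The pass criteria and the `::` convention described in this README can be sketched in a few lines of Python. This is an illustration under assumptions, not the code in `psnr_calc.py`: the codec labels, the helper names, and the positional pairing of targets with references are all hypothetical.

```python
# Assumed codec labels; the README only names Opus and AAC as lossy.
LOSSY_CODECS = {"opus", "aac"}


def psnr_threshold(codec):
  """Returns the minimum average PSNR (dB) stated in the README."""
  return 30.0 if codec.lower() in LOSSY_CODECS else 80.0


def is_pass(codec, average_psnr):
  """A test passes when the average PSNR exceeds the codec's threshold."""
  return average_psnr > psnr_threshold(codec)


def pair_files(target_arg, ref_arg, sep="::"):
  """Splits `::`-separated lists and pairs them by position (assumed)."""
  targets = target_arg.split(sep)
  refs = ref_arg.split(sep)
  if len(targets) != len(refs):
    raise ValueError("--target and --ref must list the same number of files")
  return list(zip(targets, refs))


pairs = pair_files("target1.wav::target2.wav", "ref1.wav::ref2.wav")
print(pairs)
print(is_pass("Opus", 35.2), is_pass("LPCM", 35.2))
```

A bit-exact decode yields an infinite PSNR, which passes either threshold.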

tests/dsp_utils.py

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
"""PSNR calculation utilities."""

import logging
import math
import wave
import numpy as np
import scipy.io.wavfile as wavfile
import librosa


def calc_per_channel_psnr_pcm(
    ref_signal: np.ndarray, signal: np.ndarray, sampwidth_bytes: int
):
  """Calculates the PSNR between two signals.

  Args:
    ref_signal: The reference signal as a numpy array.
    signal: The signal to compare as a numpy array.
    sampwidth_bytes: The sample width in bytes (e.g. 2 for 16-bit, 3 for
      24-bit).

  Returns:
    The per channel PSNR in dB.
  """
  assert (
      sampwidth_bytes > 1
  ), "Supports sample format: [pcm_s16le, pcm_s24le, pcm_s32le]"
  max_value = pow(2, sampwidth_bytes * 8) - 1

  # To prevent overflow
  ref_signal = ref_signal.astype("int64")
  signal = signal.astype("int64")

  mse = np.mean((ref_signal - signal) ** 2, axis=0, dtype="float64")

  psnr_list = list()

  # To support mono signal
  num_channels = 1 if ref_signal.shape[1:] == () else ref_signal.shape[1]
  for i in range(num_channels):
    mse_value = mse[i] if num_channels > 1 else mse
    if mse_value == 0:
      psnr_list.append(np.inf)
      logging.debug("ch#%d PSNR: inf", i)
    else:
      psnr_value = 10 * math.log10(max_value**2 / mse_value)
      psnr_list.append(psnr_value)
      logging.debug("ch#%d PSNR: %f dB", i, psnr_value)

  return psnr_list


def calc_per_channel_lsd_pcm(ref_signal: np.ndarray,
                             signal: np.ndarray,
                             sampling_rate: int):
  """Calculates the log spectral distance using Mel bins between two signals.

  Args:
    ref_signal: The reference signal as a numpy array.
    signal: The signal to compare as a numpy array.
    sampling_rate: The sampling rate of the signals in Hz.

  Returns:
    The per channel log spectral distance in dB.
  """
  eps = 1e-4

  # Convert to float
  ref_signal = ref_signal / np.iinfo(ref_signal.dtype).max
  signal = signal / np.iinfo(signal.dtype).max

  lsd_list = list()

  # To support mono channel
  num_channels = 1 if ref_signal.shape[1:] == () else ref_signal.shape[1]
  for i in range(num_channels):
    ref_channel = ref_signal[:, i] if num_channels > 1 else ref_signal
    signal_channel = signal[:, i] if num_channels > 1 else signal

    lsd_frames = list()

    # Compute mel spectrogram
    mel_ref = librosa.feature.melspectrogram(y=ref_channel, sr=sampling_rate)
    mel_signal = librosa.feature.melspectrogram(y=signal_channel,
                                                sr=sampling_rate)

    log_mel_ref = 10 * np.log10(mel_ref + eps)
    log_mel_signal = 10 * np.log10(mel_signal + eps)

    diff_squared = (log_mel_ref - log_mel_signal) ** 2

    # Average across mel bins, which is the 0th dimension
    lsd_per_frame = np.sqrt(np.mean(diff_squared, axis=0))

    # shape: (1, num_frames) -> (num_frames,)
    lsd_per_frame = np.squeeze(lsd_per_frame)

    lsd_value = np.mean(lsd_per_frame)
    lsd_list.append(lsd_value)
    logging.debug('ch#%d LSD: %f dB', i, lsd_value)

  return lsd_list


def calc_score_wav(ref_filepath: str, target_filepath: str, metric: str):
  """Calculates the score between two WAV files.

  Args:
    ref_filepath: Path to the reference WAV file.
    target_filepath: Path to the target WAV file to compare.
    metric: one of 'PSNR' or 'LSD'.

  Returns:
    The score in dB, averaged over all channels.

  Raises:
    ValueError: If the wav files have different samplerate, channels, bit-depth
      or number of samples.
  """
  ref_wav = wave.open(ref_filepath, "rb")
  target_wav = wave.open(target_filepath, "rb")

  # Check sampling rate
  if ref_wav.getframerate() != target_wav.getframerate():
    raise ValueError(
        "Sampling rate of reference file and comparison file are different:"
        f" {ref_filepath} vs {target_filepath}"
    )

  # Check number of channels
  if ref_wav.getnchannels() != target_wav.getnchannels():
    raise ValueError(
        "Number of channels of reference file and comparison file are"
        f" different: {ref_filepath} vs {target_filepath}"
    )

  # Check number of samples
  if ref_wav.getnframes() != target_wav.getnframes():
    raise ValueError(
        "Number of samples of reference file and comparison file are different:"
        f" {ref_filepath} vs {target_filepath}"
    )

  # Check bit depth
  if ref_wav.getsampwidth() != target_wav.getsampwidth():
    raise ValueError(
        "Bit depth of reference file and comparison file are different:"
        f" {ref_filepath} vs {target_filepath}"
    )

  # Open wav as a np array
  _, ref_data = wavfile.read(ref_filepath)
  _, target_data = wavfile.read(target_filepath)

  if metric == 'PSNR':
    scores_list = calc_per_channel_psnr_pcm(
        ref_data, target_data, ref_wav.getsampwidth()
    )
  elif metric == 'LSD':
    scores_list = calc_per_channel_lsd_pcm(ref_data, target_data,
                                           ref_wav.getframerate())
  else:
    return None

  return np.mean(scores_list)
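As a worked check of the PSNR formula in `calc_per_channel_psnr_pcm`, the sketch below recomputes it on a tiny synthetic signal using only `numpy`. It is a simplified stand-in rather than the function above: `psnr_db` and the sample values are made up for illustration, and it keeps the same peak value of `2**(8 * sampwidth_bytes) - 1`.

```python
import math
import numpy as np


def psnr_db(ref, test, sampwidth_bytes):
  """Per-channel PSNR in dB, using the same peak value as dsp_utils.py."""
  max_value = 2 ** (sampwidth_bytes * 8) - 1
  # Widen to int64 so the squared difference cannot overflow.
  diff = ref.astype(np.int64) - test.astype(np.int64)
  mse = np.atleast_1d(np.mean(diff ** 2, axis=0, dtype=np.float64))
  return [
      math.inf if m == 0 else 10 * math.log10(max_value**2 / m) for m in mse
  ]


# Three stereo frames of 16-bit PCM (values chosen arbitrarily).
ref = np.array([[0, 100], [1000, -100], [-1000, 50]], dtype=np.int16)
test = ref.copy()
test[:, 1] += 1  # channel 1 is off by one LSB on every sample

scores = psnr_db(ref, test, sampwidth_bytes=2)
print(scores)  # channel 0 is inf; channel 1 is about 96.3 dB
```

With an MSE of exactly 1, the second channel's PSNR reduces to `20 * log10(65535)`, comfortably above the threshold of 80 that the README sets for lossless codecs.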
