This repository was archived by the owner on Feb 12, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 34
Expand file tree
/
Copy pathpredict.py
More file actions
99 lines (73 loc) · 2.19 KB
/
predict.py
File metadata and controls
99 lines (73 loc) · 2.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import argparse
import os
import keras
import sklearn
import librosa
from keras import backend as K
# Machine epsilon for float64; count() adds it to the per-frame norms so the
# normalizing divisor can never be zero. ``np.float`` (a deprecated alias of
# the builtin ``float``) was removed in NumPy 1.24, so use the builtin directly.
eps = np.finfo(float).eps
def class_mae(y_true, y_pred):
    """Mean absolute difference between predicted and true class indices.

    Both inputs are reduced with ``argmax`` over their last axis. They are
    presumably per-class score / one-hot tensors (one class per speaker
    count); confirm against the training code. This function is registered
    as a custom object in ``load_model``, so saved models that reference
    ``class_mae`` can be deserialized.

    Returns:
        Backend tensor with the mean absolute class-index error over the last
        remaining axis, of dtype ``K.floatx()``.
    """
    class_diff = K.argmax(y_pred, axis=-1) - K.argmax(y_true, axis=-1)
    # argmax yields integer tensors. Cast before averaging, otherwise
    # backends such as TensorFlow compute the mean with truncating integer
    # division and the metric is silently floored.
    class_diff = K.cast(class_diff, K.floatx())
    return K.mean(K.abs(class_diff), axis=-1)
def load_scaler():
    """Rebuild the feature standardizer from ``models/scaler.npz``.

    The archive is read positionally: ``arr_0`` becomes the per-feature mean
    and ``arr_1`` the per-feature scale (the default names ``np.savez`` gives
    to unnamed arrays). The path is relative to the current working directory.

    Returns:
        A ``StandardScaler`` whose ``mean_`` and ``scale_`` are set from the
        archive, so ``transform`` can be called without fitting.
    """
    # ``import sklearn`` alone does not load the ``preprocessing``
    # subpackage, so ``sklearn.preprocessing`` may raise AttributeError
    # unless another import happened to load it. Import it explicitly.
    from sklearn.preprocessing import StandardScaler

    scaler = StandardScaler()
    with np.load(os.path.join("models", 'scaler.npz')) as data:
        scaler.mean_ = data['arr_0']
        scaler.scale_ = data['arr_1']
    return scaler
def load_model(model_name):
    """Load the Keras model stored at ``models/<model_name>.h5``.

    ``class_mae`` and ``K.exp`` are passed as custom objects, so any metric
    or layer serialized under those names can be resolved during loading.
    The path is relative to the current working directory.
    """
    model_file = os.path.join('models', model_name + '.h5')
    extra_objects = dict(class_mae=class_mae, exp=K.exp)
    return keras.models.load_model(model_file, custom_objects=extra_objects)
def count(audio, model, scaler):
    """Estimate the number of concurrent speakers in a waveform.

    Args:
        audio: waveform samples, e.g. from ``librosa.load`` (the CLI below
            loads at 16 kHz).
        model: Keras model. When its ``input_shape`` has 4 dimensions, an
            extra axis is inserted at position 1 (presumably a channels-first
            convolutional input; confirm against the model definition).
        scaler: standardizer applied per STFT bin (see ``load_scaler``).

    Returns:
        Index of the highest-scoring output class for the single example,
        which the CLI reports as the speaker-count estimate.
    """
    # Magnitude STFT, transposed so frames run along axis 0:
    # shape (n_frames, 201) for n_fft=400.
    spec = np.abs(librosa.stft(audio, n_fft=400, hop_length=160)).T

    # Featurewise standardization (zero mean, unit variance per bin), then
    # keep at most the first 500 frames, the model's input length.
    feats = scaler.transform(spec)[:500, :]

    # Divide by the mean per-frame L2 norm. eps keeps the divisor non-zero
    # for silent input. The division is in place, so feats keeps its dtype.
    frame_norms = np.linalg.norm(feats, axis=1) + eps
    feats /= np.mean(frame_norms)

    # Add the batch axis, plus an extra axis for models taking 4-D input.
    batch = np.expand_dims(feats, 0)
    if len(model.input_shape) == 4:
        batch = np.expand_dims(batch, 1)

    scores = model.predict(batch, verbose=0)
    return np.argmax(scores, axis=1)[0]
if __name__ == '__main__':
    # Command-line entry point: print a speaker-count estimate per audio file.
    cli = argparse.ArgumentParser(
        description='Load keras model and predict speaker count'
    )
    cli.add_argument(
        'audio',
        nargs='+',
        help='audio file (samplerate 16 kHz) of 5 seconds duration',
    )
    cli.add_argument('--model', default='CRNN', help='model name')
    cli.add_argument('--print-summary', action='store_true')
    opts = cli.parse_args()

    network = load_model(opts.model)
    if opts.print_summary:
        # Show the layer-by-layer configuration of the loaded model.
        network.summary()

    # Standardization statistics consumed by count().
    standardizer = load_scaler()

    for audio_path in opts.audio:
        # Resample to 16 kHz on load; only the waveform is needed.
        waveform, _ = librosa.load(audio_path, sr=16000)
        estimate = count(waveform, network, standardizer)
        print("Speaker Count Estimate:", audio_path, estimate)