Skip to content

Commit bcfd691

Browse files
committed
add stride trick to framesig and write corresponding tests.
Running time is reduced by a factor of ~4 for a long signal.
1 parent a350d7a commit bcfd691

2 files changed

Lines changed: 80 additions & 31 deletions

File tree

python_speech_features/sigproc.py

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,19 @@
66
import math
77
import logging
88

9+
910
def round_half_up(number):
1011
return int(decimal.Decimal(number).quantize(decimal.Decimal('1'), rounding=decimal.ROUND_HALF_UP))
1112

1213

13-
def framesig(sig,frame_len,frame_step,winfunc=lambda x:numpy.ones((x,))):
14+
def rolling_window(a, window, step=1):
15+
# http://ellisvalentiner.com/post/2017-03-21-np-strides-trick
16+
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
17+
strides = a.strides + (a.strides[-1],)
18+
return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::step]
19+
20+
21+
def framesig(sig, frame_len, frame_step, winfunc=lambda x: numpy.ones((x,)), stride_trick=True):
1422
"""Frame a signal into overlapping frames.
1523
1624
:param sig: the audio signal to frame.
@@ -25,21 +33,26 @@ def framesig(sig,frame_len,frame_step,winfunc=lambda x:numpy.ones((x,))):
2533
if slen <= frame_len:
2634
numframes = 1
2735
else:
28-
numframes = 1 + int(math.ceil((1.0*slen - frame_len)/frame_step))
36+
numframes = 1 + int(math.ceil((1.0 * slen - frame_len) / frame_step))
2937

30-
padlen = int((numframes-1)*frame_step + frame_len)
38+
padlen = int((numframes - 1) * frame_step + frame_len)
3139

3240
zeros = numpy.zeros((padlen - slen,))
33-
padsignal = numpy.concatenate((sig,zeros))
41+
padsignal = numpy.concatenate((sig, zeros))
42+
if stride_trick:
43+
win = winfunc(frame_len)
44+
frames = rolling_window(padsignal, window=frame_len, step=frame_step)
45+
else:
46+
indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(
47+
numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T
48+
indices = numpy.array(indices, dtype=numpy.int32)
49+
frames = padsignal[indices]
50+
win = numpy.tile(winfunc(frame_len), (numframes, 1))
3451

35-
indices = numpy.tile(numpy.arange(0,frame_len),(numframes,1)) + numpy.tile(numpy.arange(0,numframes*frame_step,frame_step),(frame_len,1)).T
36-
indices = numpy.array(indices,dtype=numpy.int32)
37-
frames = padsignal[indices]
38-
win = numpy.tile(winfunc(frame_len),(numframes,1))
39-
return frames*win
52+
return frames * win
4053

4154

42-
def deframesig(frames,siglen,frame_len,frame_step,winfunc=lambda x:numpy.ones((x,))):
55+
def deframesig(frames, siglen, frame_len, frame_step, winfunc=lambda x: numpy.ones((x,))):
4356
"""Does overlap-add procedure to undo the action of framesig.
4457
4558
:param frames: the array of frames.
@@ -54,68 +67,73 @@ def deframesig(frames,siglen,frame_len,frame_step,winfunc=lambda x:numpy.ones((x
5467
numframes = numpy.shape(frames)[0]
5568
assert numpy.shape(frames)[1] == frame_len, '"frames" matrix is wrong size, 2nd dim is not equal to frame_len'
5669

57-
indices = numpy.tile(numpy.arange(0,frame_len),(numframes,1)) + numpy.tile(numpy.arange(0,numframes*frame_step,frame_step),(frame_len,1)).T
58-
indices = numpy.array(indices,dtype=numpy.int32)
59-
padlen = (numframes-1)*frame_step + frame_len
70+
indices = numpy.tile(numpy.arange(0, frame_len), (numframes, 1)) + numpy.tile(
71+
numpy.arange(0, numframes * frame_step, frame_step), (frame_len, 1)).T
72+
indices = numpy.array(indices, dtype=numpy.int32)
73+
padlen = (numframes - 1) * frame_step + frame_len
6074

6175
if siglen <= 0: siglen = padlen
6276

6377
rec_signal = numpy.zeros((padlen,))
6478
window_correction = numpy.zeros((padlen,))
6579
win = winfunc(frame_len)
6680

67-
for i in range(0,numframes):
68-
window_correction[indices[i,:]] = window_correction[indices[i,:]] + win + 1e-15 #add a little bit so it is never zero
69-
rec_signal[indices[i,:]] = rec_signal[indices[i,:]] + frames[i,:]
81+
for i in range(0, numframes):
82+
window_correction[indices[i, :]] = window_correction[
83+
indices[i, :]] + win + 1e-15 # add a little bit so it is never zero
84+
rec_signal[indices[i, :]] = rec_signal[indices[i, :]] + frames[i, :]
7085

71-
rec_signal = rec_signal/window_correction
86+
rec_signal = rec_signal / window_correction
7287
return rec_signal[0:siglen]
7388

74-
def magspec(frames,NFFT):
89+
90+
def magspec(frames, NFFT):
7591
"""Compute the magnitude spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).
7692
7793
:param frames: the array of frames. Each row is a frame.
7894
:param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.
7995
:returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the magnitude spectrum of the corresponding frame.
8096
"""
8197
if numpy.shape(frames)[1] > NFFT:
82-
logging.warn('frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.', numpy.shape(frames)[1], NFFT)
83-
complex_spec = numpy.fft.rfft(frames,NFFT)
98+
logging.warn(
99+
'frame length (%d) is greater than FFT size (%d), frame will be truncated. Increase NFFT to avoid.',
100+
numpy.shape(frames)[1], NFFT)
101+
complex_spec = numpy.fft.rfft(frames, NFFT)
84102
return numpy.absolute(complex_spec)
85103

86-
def powspec(frames,NFFT):
104+
105+
def powspec(frames, NFFT):
87106
"""Compute the power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).
88107
89108
:param frames: the array of frames. Each row is a frame.
90109
:param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.
91110
:returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the power spectrum of the corresponding frame.
92111
"""
93-
return 1.0/NFFT * numpy.square(magspec(frames,NFFT))
112+
return 1.0 / NFFT * numpy.square(magspec(frames, NFFT))
94113

95-
def logpowspec(frames,NFFT,norm=1):
114+
115+
def logpowspec(frames, NFFT, norm=1):
96116
"""Compute the log power spectrum of each frame in frames. If frames is an NxD matrix, output will be Nx(NFFT/2+1).
97117
98118
:param frames: the array of frames. Each row is a frame.
99119
:param NFFT: the FFT length to use. If NFFT > frame_len, the frames are zero-padded.
100120
:param norm: If norm=1, the log power spectrum is normalised so that the max value (across all frames) is 0.
101121
:returns: If frames is an NxD matrix, output will be Nx(NFFT/2+1). Each row will be the log power spectrum of the corresponding frame.
102122
"""
103-
ps = powspec(frames,NFFT);
104-
ps[ps<=1e-30] = 1e-30
105-
lps = 10*numpy.log10(ps)
123+
ps = powspec(frames, NFFT);
124+
ps[ps <= 1e-30] = 1e-30
125+
lps = 10 * numpy.log10(ps)
106126
if norm:
107127
return lps - numpy.max(lps)
108128
else:
109129
return lps
110130

111-
def preemphasis(signal,coeff=0.95):
131+
132+
def preemphasis(signal, coeff=0.95):
112133
"""perform preemphasis on the input signal.
113134
114135
:param signal: The signal to filter.
115136
:param coeff: The preemphasis coefficient. 0 is no filter, default is 0.95.
116137
:returns: the filtered signal.
117138
"""
118-
return numpy.append(signal[0],signal[1:]-coeff*signal[:-1])
119-
120-
121-
139+
return numpy.append(signal[0], signal[1:] - coeff * signal[:-1])

test/test_sigproc.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from python_speech_features import sigproc
2+
import unittest
3+
import numpy as np
4+
import time
5+
6+
7+
class test_case(unittest.TestCase):
8+
def test_frame_sig(self):
9+
n = 10000124
10+
frame_len = 37
11+
frame_step = 13
12+
x = np.random.rand(n)
13+
t0 = time.time()
14+
y_old = sigproc.framesig(x, frame_len=frame_len, frame_step=frame_step, stride_trick=False)
15+
t1 = time.time()
16+
y_new = sigproc.framesig(x, frame_len=frame_len, frame_step=frame_step, stride_trick=True)
17+
t_new = time.time() - t1
18+
t_old = t1 - t0
19+
self.assertTupleEqual(y_old.shape, y_new.shape)
20+
np.testing.assert_array_equal(y_old, y_new)
21+
self.assertLess(t_new, t_old)
22+
print('new run time %3.2f < %3.2f sec' % (t_new, t_old))
23+
24+
def test_rolling(self):
25+
x = np.arange(10)
26+
y = sigproc.rolling_window(x, window=4, step=3)
27+
y_expected = np.array([[0, 1, 2, 3],
28+
[3, 4, 5, 6],
29+
[6, 7, 8, 9]]
30+
)
31+
y = np.testing.assert_array_equal(y, y_expected)

0 commit comments

Comments
 (0)