This repository was archived by the owner on Nov 23, 2023. It is now read-only.

Commit 0014690

patch from exprnn branch
1 parent 5302653 commit 0014690

4 files changed: 90 additions & 83 deletions


notepredictor/notepredictor/model.py

Lines changed: 10 additions & 82 deletions
@@ -2,88 +2,12 @@
 from torch import nn
 import torch.nn.functional as F

-# import geotorch
-
-# shim torch RNN,GRU classes to have same API as LSTM
-def rnn_shim(cls):
-    """LSTM API for GRU and RNN.
-
-    hidden state is first element of state tuple"""
-    class shim(cls):
-        def forward(self, input, states):
-            assert len(states)==1
-            out, h = super().forward(input, *states)
-            return out, (h,)
-    return shim
-
-GRU = rnn_shim(nn.GRU)
-RNN = rnn_shim(nn.RNN)
-LSTM = nn.LSTM
-
-class ExpRNN(nn.Module):
-    pass
-
-# class ExpRNNCell(nn.Module):
-#     pass
-
-class GenericRNN(nn.Module):
-    kind_cls = {
-        'gru':GRU,
-        'lstm':LSTM,
-        'elman':RNN,
-        'exprnn':ExpRNN
-    }
-    # desiderata:
-    # support geotorch constraints
-    # clean API for multiple layers, multiple cell states (e.g. LSTM)
-    def __init__(self, kind, *a, **kw):
-        super().__init__()
-        if kw.get('bidirectional'): raise ValueError("""
-            bidirectional GenericRNN not supported.
-            """)
-        cls = GenericRNN.kind_cls[kind]
-        self.rnn = cls(*a, **kw)
-
-    def __getattr__(self, a):
-        try:
-            return super().__getattr__(a)
-        except AttributeError:
-            return getattr(self.rnn, a)
-
-    def forward(self, x, initial_state):
-        """
-        Args:
-            x: Tensor[batch x time x channel] if batch_first else [time x batch x channel]
-            initial_state: List[Tensor[layers x batch x hidden]], list of components
-                with 0 being hidden state (e.g. 1 is cell state for LSTM).
-        Returns:
-            hidden: hidden states of the top layer, Tensor[batch x time x hidden]
-                or [time x batch x hidden]
-            new_states: List[Tensor[layers x batch x hidden]]
-        """
-        hidden, final_state = self.rnn.forward(x, initial_state) # forward or __call__?
-        return hidden, final_state
-
-    ## NOTE: individual time-step API might be useful, not actually needed yet though
-    # def step(self, x, state):
-    #     """
-    #     Args:
-    #         x: Tensor[batch x channel]
-    #         state: List[Tensor[layers x batch x hidden]], list of components
-    #             with 0 being hidden state (e.g. 1 is cell state for LSTM).
-    #     Returns:
-    #         hidden: hidden state of top layer [batch x hidden]
-    #         new_states: List[Tensor[layers x batch x hidden]]
-    #     """
-    #     time_idx = 1 if self.rnn.batch_first else 0
-    #     x = x.unsqueeze(time_idx)
-    #     hidden, state = self.forward(x, state)
-    #     return hidden.squeeze(time_idx), state
-
+from .rnn import GenericRNN


 class PitchPredictor(nn.Module):
     # note: use named arguments only for benefit of training script
-    def __init__(self, emb_size=128, hidden_size=512, domain_size=128, num_layers=1):
+    def __init__(self, emb_size=128, hidden_size=512, domain_size=128,
+            num_layers=1, kind='gru', dropout=0):
         """
         """
         super().__init__()
@@ -93,14 +17,18 @@ def __init__(self, emb_size=128, hidden_size=512, domain_size=128, num_layers=1)

         self.emb = nn.Embedding(domain_size, emb_size)
         self.proj = nn.Linear(hidden_size, domain_size)
+        #### DEBUG
+        with torch.no_grad():
+            self.proj.weight.mul_(1e-2)

-        self.rnn = GenericRNN('gru', emb_size, hidden_size,
-            num_layers=num_layers, batch_first=True)
+        self.rnn = GenericRNN(kind, emb_size, hidden_size,
+            num_layers=num_layers, batch_first=True, dropout=dropout)

         # learnable initial state
         self.initial_state = nn.ParameterList([
             # layer x batch x hidden
-            nn.Parameter(torch.randn(num_layers,1,hidden_size)*hidden_size**-0.5),
+            nn.Parameter(torch.randn(num_layers,1,hidden_size)*hidden_size**-0.5)
+            for _ in range(2 if kind=='lstm' else 1)
         ])

         # persistent state for inference
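Net effect on PitchPredictor: the cell kind and dropout are now constructor arguments, the output projection is scaled down at init (marked DEBUG), and the learnable initial state holds one component for single-state cells but two (hidden and cell) for LSTM. A minimal sketch of the new behavior, assuming the package imports as notepredictor and the rest of __init__ is unchanged:

import torch
from notepredictor.model import PitchPredictor  # import path assumed

# 'gru' keeps a single learnable state component; 'lstm' gets (hidden, cell)
gru_model = PitchPredictor(kind='gru', num_layers=2)
lstm_model = PitchPredictor(kind='lstm', num_layers=2, dropout=0.1)

assert len(gru_model.initial_state) == 1
assert len(lstm_model.initial_state) == 2
# each component is [layers x batch x hidden]
print(lstm_model.initial_state[0].shape)  # torch.Size([2, 1, 512])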

notepredictor/notepredictor/rnn.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+import math
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+class ExpRNN(nn.Module):
+    def __init__(self, input_size, hidden_size, **kw):
+        raise NotImplementedError("see `exprnn` branch")
+
+def rnn_shim(cls):
+    """LSTM API for GRU and RNN.
+
+    hidden state is first element of state tuple"""
+    class shim(cls):
+        def forward(self, input, states=(None,)):
+            assert len(states)==1
+            out, h = super().forward(input, *states)
+            return out, (h,)
+    return shim
+
+GRU = rnn_shim(nn.GRU)
+RNN = rnn_shim(nn.RNN)
+LSTM = nn.LSTM
+
+
+class GenericRNN(nn.Module):
+    kind_cls = {
+        'gru':GRU,
+        'lstm':LSTM,
+        'elman':RNN,
+        'exprnn':ExpRNN
+    }
+    def __init__(self, kind, *a, **kw):
+        super().__init__()
+        if kw.get('bidirectional'): raise ValueError("""
+            bidirectional GenericRNN not supported.
+            """)
+        cls = GenericRNN.kind_cls[kind]
+        self.kind = kind
+        self.rnn = cls(*a, **kw)
+
+    def __getattr__(self, a):
+        try:
+            return super().__getattr__(a)
+        except AttributeError:
+            return getattr(self.rnn, a)
+
+    def forward(self, x, initial_state):
+        """
+        Args:
+            x: Tensor[batch x time x channel] if batch_first else [time x batch x channel]
+            initial_state: List[Tensor[layers x batch x hidden]], list of components
+                with 0 being hidden state (e.g. 1 is cell state for LSTM).
+        Returns:
+            hidden: hidden states of the top layer, Tensor[batch x time x hidden]
+                or [time x batch x hidden]
+            new_states: List[Tensor[layers x batch x hidden]]
+        """
+        hidden, final_state = self.rnn.forward(x, initial_state) # forward or __call__?
+        return hidden, final_state
+
+    ## NOTE: individual time-step API might be useful, not actually needed yet though
+    # def step(self, x, state):
+    #     """
+    #     Args:
+    #         x: Tensor[batch x channel]
+    #         state: List[Tensor[layers x batch x hidden]], list of components
+    #             with 0 being hidden state (e.g. 1 is cell state for LSTM).
+    #     Returns:
+    #         hidden: hidden state of top layer [batch x hidden]
+    #         new_states: List[Tensor[layers x batch x hidden]]
+    #     """
+    #     time_idx = 1 if self.rnn.batch_first else 0
+    #     x = x.unsqueeze(time_idx)
+    #     hidden, state = self.forward(x, state)
+    #     return hidden.squeeze(time_idx), state
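The point of rnn_shim and GenericRNN is a uniform state-tuple API: every cell kind takes and returns its recurrent state as a tuple with the hidden state first, so calling code never branches on the cell type. A short usage sketch against the new module (import path assumed from the file tree):

import torch
from notepredictor.rnn import GenericRNN  # import path assumed

x = torch.randn(8, 16, 32)  # batch x time x channel

gru = GenericRNN('gru', 32, 64, num_layers=1, batch_first=True)
h0 = (torch.zeros(1, 8, 64),)                         # 1-tuple: hidden only
out, (h,) = gru(x, h0)

lstm = GenericRNN('lstm', 32, 64, num_layers=1, batch_first=True)
s0 = (torch.zeros(1, 8, 64), torch.zeros(1, 8, 64))   # (hidden, cell)
out, (h, c) = lstm(x, s0)

print(out.shape)  # torch.Size([8, 16, 64]) in both cases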

pytorch-osc/pytorch-osc.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@ def predictor_handler(address, *args):
         print(f"/load {args}")
         global predictor
         predictor = PitchPredictor.from_checkpoint(*args)
+        predictor.eval()

     elif(address[2] == "predict"):
         print(f"/predict {args}")
@@ -60,6 +61,7 @@ def main(ip="127.0.0.1", send=57120, receive=9999, checkpoint=None):

     if checkpoint is not None:
         predictor = PitchPredictor.from_checkpoint(checkpoint)
+        predictor.eval()

     asyncio.run(init_main())
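Calling eval() right after loading matters here because the model can now be built with dropout, and dropout is stochastic in train mode but disabled in eval mode. A standalone illustration of the difference (not this repo's code):

import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
x = torch.ones(1, 4)

model.train()                  # default mode after construction
a, b = model(x), model(x)      # dropout active: outputs typically differ

model.eval()
with torch.no_grad():
    c, d = model(x), model(x)  # dropout disabled: outputs are identical
assert torch.equal(c, d)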

scripts/train_pitch.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def __init__(self,
         adam_eps = 1e-08,
         weight_decay = 0.01,
         seed = 0, # random seed
-        n_jobs = 0, # for dataloaders
+        n_jobs = 1, # for dataloaders
         device = 'cpu', # 'cuda:0'
         epoch_size = None, # in iterations, None for whole dataset
         ):
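Presumably n_jobs is forwarded to the DataLoader's num_workers (an assumption; the forwarding code is not in this diff): 0 loads batches in the main process, while 1 moves loading into a background worker process. A sketch of that usage:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(128, 8))
n_jobs = 1  # the new default
loader = DataLoader(ds, batch_size=16, num_workers=n_jobs)
for (batch,) in loader:
    pass  # training step would go here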
