@@ -91,65 +91,6 @@ def get_norm():
# forward: passes the input `x` to the wrapped callable `self.net` and returns its result as-is.
# NOTE(review): the enclosing class and the construction of `self.net` are outside this view — confirm the expected type/shape of `x` there.
9191 def forward (self , x ):
9292 return self .net (x )
9393
94- # class ModalityTransformer(nn.Module):
95- # """
96- # Model joint distribution of note modalities (e.g. pitch, time, velocity).
97-
98- # This is an autoregressive Transformer model for the *internal* structure of notes.
99- # It is *not* autoregressive in time, but in modality.
100- # At training time, it executes in parallel over all timesteps and modalities, with
101- # time dependencies provided via the RNN backbone.
102-
103- # At sampling time it is called serially, one modality at a time,
104- # repeatedly at each time step.
105-
106- # Inspired by XLNet: http://arxiv.org/abs/1906.08237
107- # """
108- # def __init__(self, input_size, hidden_size, heads=4, layers=1):
109- # super().__init__()
110- # self.net = nn.TransformerDecoder(
111- # nn.TransformerDecoderLayer(
112- # input_size, heads, hidden_size, norm_first=False
113- # ), layers)
114-
115- # def forward(self, ctx, h_ctx, h_tgt):
116- # """
117- # Args:
118- # ctx: list of Tensor[batch x time x input_size], length note_dim-1
119- # these are the embedded ground truth values
120- # h_ctx: Tensor[batch x time x input_size]
121- # projection of RNN state (need something to attend to when ctx is empty)
122- # h_tgt: list of Tensor[batch x time x input_size], length note_dim
123- # these are projections of the RNN state for each target,
124- # which the Transformer will map to distribution parameters.
125- # """
126- # # explicitly broadcast
127- # h_ctx, *ctx = torch.broadcast_tensors(h_ctx, *ctx)
128- # h_ctx, *h_tgt = torch.broadcast_tensors(h_ctx, *h_tgt)
129-
130- # # h_tgt is 'target' w.r.t TransformerDecoder
131- # # h_ctx and context are 'memory'
132- # batch_size = h_ctx.shape[0]*h_ctx.shape[1]
133- # # fold time into batch, stack modes
134- # tgt = torch.stack([
135- # item.reshape(batch_size,-1)
136- # for item in h_tgt
137- # ],0)
138- # mem = torch.stack([
139- # item.reshape(batch_size,-1)
140- # for item in [h_ctx, *ctx]
141- # ],0)
142- # # now "time"(mode) x "batch"(+time) x channel
143-
144- # # generate a mask
145- # # this is both the target and memory mask
146- # # masking is such that each target can only depend on "previous" context
147- # n = len(h_tgt)
148- # mask = ~tgt.new_ones((n,n), dtype=bool).tril()
149-
150- # x = self.net(tgt, mem, mask, mask)
151- # return list(x.reshape(n, *h_ctx.shape).unbind(0))
152-
15394
15495class Notochord (nn .Module ):
15596 # note: use named arguments only for benefit of training script
0 commit comments