Skip to content

Commit 497f9dd

Browse files
committed
update
1 parent 9bacb26 commit 497f9dd

3 files changed

Lines changed: 3 additions & 6 deletions

File tree

0 Bytes
Binary file not shown.

inference.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,9 @@ def generate(output_directory, tensorboard_directory,
7777
end = torch.cuda.Event(enable_timing=True)
7878
start.record()
7979

80-
# inference
8180
generated_audio = sampling(net, (1,1,audio_length),
8281
diffusion_hyperparams,
83-
condition=ground_truth_mel_spectrogram,
84-
print_every_n_steps=diffusion_config["T"] // 5)
82+
condition=ground_truth_mel_spectrogram)
8583

8684
end.record()
8785
torch.cuda.synchronize()

util.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def calc_diffusion_hyperparams(T, beta_0, beta_T):
123123
return diffusion_hyperparams
124124

125125

126-
def sampling(net, size, diffusion_hyperparams, condition=None, print_every_n_steps=100):
126+
def sampling(net, size, diffusion_hyperparams, condition=None):
127127
"""
128128
Perform the complete sampling step according to p(x_0|x_T) = \prod_{t=1}^T p_{\theta}(x_{t-1}|x_t)
129129
@@ -134,8 +134,7 @@ def sampling(net, size, diffusion_hyperparams, condition=None, print_every_n_ste
134134
diffusion_hyperparams (dict): dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams
135135
note, the tensors need to be cuda tensors
136136
condition (torch.tensor): ground truth mel spectrogram read from disk
137-
None if used for unconditional generation
138-
print_every_n_steps (int): print status every this number of reverse steps
137+
None if used for unconditional generation
139138
140139
Returns:
141140
the generated audio(s) in torch.tensor, shape=size

0 commit comments

Comments
 (0)