-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathtest_Woosh-AE.py
More file actions
60 lines (47 loc) · 1.41 KB
/
test_Woosh-AE.py
File metadata and controls
60 lines (47 loc) · 1.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# %%
import os
import time
import torch
import torchaudio
from woosh.inference.flowmatching_sampler import flowmatching_integrate
from woosh.components.autoencoders import AudioAutoEncoder
from woosh.components.base import LoadConfig
# Select the compute device: prefer CUDA, then Apple-silicon MPS, else CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
# %%
# Load the pretrained autoencoder checkpoint and prepare it for inference.
COMPONENT_PATH = "checkpoints/Woosh-AE"
load_cfg = LoadConfig(path=COMPONENT_PATH)
ae = AudioAutoEncoder(load_cfg)
ae.load_from_config()
# Inference only: switch off training-mode layers, then move to the device.
ae.eval().to(device)
# %%
# Prepare inputs
batch_size = 1
# Load the sample clip, keep only the first channel, and add a batch axis
# so the tensor is (batch, channels, samples) before moving to the device.
audio, fs = torchaudio.load("samples/810333__mokasza__glass-breaking.mp3")
audio = audio[:1].unsqueeze(0).detach().to(device=device)
# Round-trip the waveform through the autoencoder and time the full pass.
start_time = time.perf_counter()
x = ae.forward(audio)
audio_fake = ae.inverse(x).detach()
end_time = time.perf_counter()
print(f"Encoding/decoding took {end_time - start_time:.2f} seconds on {device}")
# Move results back to the host and write each batch item out as a WAV file.
audio_fake = audio_fake.cpu()
os.makedirs("outputs", exist_ok=True)
for i in range(batch_size):
    clip = audio_fake[i]
    # Peak-normalize only when the signal would clip (|sample| > 1.0);
    # clamp(min=1.0) leaves already in-range audio untouched.
    peak = clip.abs().max()
    clip = clip / peak.clamp(min=1.0)
    # NOTE(review): sample rate is hard-coded to 48 kHz while the input's
    # rate `fs` is available — confirm the autoencoder's native output rate.
    torchaudio.save(
        f"outputs/Woosh-AE_{i}.wav",
        clip,
        sample_rate=48000,
    )
# %%