Skip to content

Commit b40aed9

Browse files
authored
Bug fixes for offline mode
1 parent ccbd69c commit b40aed9

3 files changed

Lines changed: 67 additions & 30 deletions

File tree

controllable_talknet.py

Lines changed: 34 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -36,9 +36,14 @@
3636
from denoiser import Denoiser
3737

3838
app = JupyterDash(__name__)
39-
UPLOAD_DIRECTORY = "/content"
39+
RUN_PATH = os.path.dirname(os.path.realpath(__file__))
40+
if RUN_PATH == "/content":
41+
UI_MODE = "colab"
42+
else:
43+
UI_MODE = "offline"
4044
torch.set_grad_enabled(False)
4145

46+
app.title = "Controllable TalkNet"
4247
app.layout = html.Div(
4348
children=[
4449
html.H1(
@@ -89,7 +94,7 @@
8994
},
9095
),
9196
html.Label(
92-
"Upload reference audio to " + UPLOAD_DIRECTORY,
97+
"Upload reference audio to " + RUN_PATH,
9398
htmlFor="reference-dropdown",
9499
),
95100
dcc.Store(id="current-f0s"),
@@ -453,27 +458,23 @@ def preprocess_tokens(tokens, blank):
453458

454459

455460
def get_duration(wav_name, transcript):
456-
if not os.path.exists(os.path.join(UPLOAD_DIRECTORY, "output")):
457-
os.mkdir(os.path.join(UPLOAD_DIRECTORY, "output"))
461+
if not os.path.exists(os.path.join(RUN_PATH, "temp")):
462+
os.mkdir(os.path.join(RUN_PATH, "temp"))
458463
if "_" not in transcript:
459464
generate_json(
460-
os.path.join(UPLOAD_DIRECTORY, "output", wav_name + "_conv.wav")
465+
os.path.join(RUN_PATH, "temp", wav_name + "_conv.wav")
461466
+ "|"
462467
+ transcript.strip(),
463-
os.path.join(UPLOAD_DIRECTORY, "output", wav_name + ".json"),
468+
os.path.join(RUN_PATH, "temp", wav_name + ".json"),
464469
)
465470
else:
466471
generate_json(
467-
os.path.join(UPLOAD_DIRECTORY, "output", wav_name + "_conv.wav")
468-
+ "|"
469-
+ "dummy",
470-
os.path.join(UPLOAD_DIRECTORY, "output", wav_name + ".json"),
472+
os.path.join(RUN_PATH, "temp", wav_name + "_conv.wav") + "|" + "dummy",
473+
os.path.join(RUN_PATH, "temp", wav_name + ".json"),
471474
)
472475

473476
data_config = {
474-
"manifest_filepath": os.path.join(
475-
UPLOAD_DIRECTORY, "output", wav_name + ".json"
476-
),
477+
"manifest_filepath": os.path.join(RUN_PATH, "temp", wav_name + ".json"),
477478
"sample_rate": 22050,
478479
"batch_size": 1,
479480
}
@@ -645,7 +646,7 @@ def update_pitch_options(value):
645646
def update_filelist(n_clicks):
646647
filelist = []
647648
supported_formats = [".wav", ".ogg", ".mp3", "flac", ".aac"]
648-
for x in sorted(os.listdir(UPLOAD_DIRECTORY)):
649+
for x in sorted(os.listdir(RUN_PATH)):
649650
if x[-4:].lower() in supported_formats:
650651
filelist.append({"label": x, "value": x})
651652
return filelist
@@ -664,18 +665,18 @@ def update_filelist(n_clicks):
664665
)
665666
def select_file(dropdown_value):
666667
if dropdown_value is not None:
667-
if not os.path.exists(os.path.join(UPLOAD_DIRECTORY, "output")):
668-
os.mkdir(os.path.join(UPLOAD_DIRECTORY, "output"))
669-
ffmpeg.input(os.path.join(UPLOAD_DIRECTORY, dropdown_value)).output(
670-
os.path.join(UPLOAD_DIRECTORY, "output", dropdown_value + "_conv.wav"),
668+
if not os.path.exists(os.path.join(RUN_PATH, "temp")):
669+
os.mkdir(os.path.join(RUN_PATH, "temp"))
670+
ffmpeg.input(os.path.join(RUN_PATH, dropdown_value)).output(
671+
os.path.join(RUN_PATH, "temp", dropdown_value + "_conv.wav"),
671672
ar="22050",
672673
ac="1",
673674
acodec="pcm_s16le",
674675
map_metadata="-1",
675676
fflags="+bitexact",
676677
).overwrite_output().run(quiet=True)
677678
fo_with_silence, f0_wo_silence = crepe_f0(
678-
os.path.join(UPLOAD_DIRECTORY, "output", dropdown_value + "_conv.wav")
679+
os.path.join(RUN_PATH, "temp", dropdown_value + "_conv.wav")
679680
)
680681
return [
681682
"Analyzed " + dropdown_value,
@@ -724,25 +725,25 @@ def download_model(model, custom_model):
724725
drive_id = model
725726
if drive_id == "" or drive_id is None:
726727
return ("Missing Drive ID", None, None)
727-
if not os.path.exists(os.path.join(UPLOAD_DIRECTORY, "models")):
728-
os.mkdir(os.path.join(UPLOAD_DIRECTORY, "models"))
729-
if not os.path.exists(os.path.join(UPLOAD_DIRECTORY, "models", drive_id)):
730-
os.mkdir(os.path.join(UPLOAD_DIRECTORY, "models", drive_id))
731-
zip_path = os.path.join(UPLOAD_DIRECTORY, "models", drive_id, "model.zip")
728+
if not os.path.exists(os.path.join(RUN_PATH, "models")):
729+
os.mkdir(os.path.join(RUN_PATH, "models"))
730+
if not os.path.exists(os.path.join(RUN_PATH, "models", drive_id)):
731+
os.mkdir(os.path.join(RUN_PATH, "models", drive_id))
732+
zip_path = os.path.join(RUN_PATH, "models", drive_id, "model.zip")
732733
gdown.download(
733734
d + drive_id,
734735
zip_path,
735736
quiet=False,
736737
)
737738
if not os.path.exists(zip_path):
738-
os.rmdir(os.path.join(UPLOAD_DIRECTORY, "models", drive_id))
739+
os.rmdir(os.path.join(RUN_PATH, "models", drive_id))
739740
return ("Model download failed", None, None)
740741
if os.stat(zip_path).st_size < 16:
741742
os.remove(zip_path)
742-
os.rmdir(os.path.join(UPLOAD_DIRECTORY, "models", drive_id))
743+
os.rmdir(os.path.join(RUN_PATH, "models", drive_id))
743744
return ("Model zip is empty", None, None)
744745
with zipfile.ZipFile(zip_path, "r") as zip_ref:
745-
zip_ref.extractall(os.path.join(UPLOAD_DIRECTORY, "models", drive_id))
746+
zip_ref.extractall(os.path.join(RUN_PATH, "models", drive_id))
746747
os.remove(zip_path)
747748

748749
# Download super-resolution HiFi-GAN
@@ -757,8 +758,8 @@ def download_model(model, custom_model):
757758

758759
return (
759760
None,
760-
os.path.join(UPLOAD_DIRECTORY, "models", drive_id, "TalkNetSpect.nemo"),
761-
os.path.join(UPLOAD_DIRECTORY, "models", drive_id, "hifiganmodel"),
761+
os.path.join(RUN_PATH, "models", drive_id, "TalkNetSpect.nemo"),
762+
os.path.join(RUN_PATH, "models", drive_id, "hifiganmodel"),
762763
)
763764
except Exception as e:
764765
return (str(e), None, None)
@@ -879,7 +880,10 @@ def generate_audio(
879880
spect = tnmodel.force_spectrogram(
880881
tokens=tokens,
881882
durs=torch.from_numpy(durs).view(1, -1).to("cuda:0"),
882-
f0=torch.FloatTensor(f0s).view(1, -1).to("cuda:0"),
883+
f0=torch.FloatTensor(f0s)
884+
.view(1, -1)
885+
.type(torch.LongTensor)
886+
.to("cuda:0"),
883887
)
884888

885889
if hifipath != hifigan_path:

requirements.txt

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,26 @@
1+
tensorflow==2.4.1
2+
tensorboard==2.4.1
3+
dash==1.21.0
4+
jupyter-dash==0.4.0
5+
psola==0.0.1
6+
wget==3.2
7+
unidecode==1.2.0
8+
pysptk==0.1.18
9+
frozendict==2.0.3
10+
torch==1.8.1+cu111
11+
torchvision==0.9.1+cu111
12+
torchaudio==0.8.1
13+
torchtext==0.9.1
14+
torch_stft==0.1.4
15+
kaldiio==2.17.2
16+
pydub==0.25.1
17+
pyannote.audio==1.1.2
18+
g2p_en==2.1.0
19+
pesq==0.0.2
20+
pystoi==0.3.3
21+
crepe==0.0.12
22+
resampy==0.2.2
23+
ffmpeg-python==0.2.0
24+
tqdm
25+
gdown==3.13.0
26+
editdistance==0.5.3

talknet_offline.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,7 @@
1+
if __name__ == '__main__':
2+
from controllable_talknet import *
3+
app.run_server(
4+
mode="external",
5+
debug=False,
6+
threaded=True,
7+
)

0 commit comments

Comments
 (0)