3636from denoiser import Denoiser
3737
3838app = JupyterDash (__name__ )
39- UPLOAD_DIRECTORY = "/content"
39+ RUN_PATH = os .path .dirname (os .path .realpath (__file__ ))
40+ if RUN_PATH == "/content" :
41+ UI_MODE = "colab"
42+ else :
43+ UI_MODE = "offline"
4044torch .set_grad_enabled (False )
4145
46+ app .title = "Controllable TalkNet"
4247app .layout = html .Div (
4348 children = [
4449 html .H1 (
8994 },
9095 ),
9196 html .Label (
92- "Upload reference audio to " + UPLOAD_DIRECTORY ,
97+ "Upload reference audio to " + RUN_PATH ,
9398 htmlFor = "reference-dropdown" ,
9499 ),
95100 dcc .Store (id = "current-f0s" ),
@@ -453,27 +458,23 @@ def preprocess_tokens(tokens, blank):
453458
454459
455460def get_duration (wav_name , transcript ):
456- if not os .path .exists (os .path .join (UPLOAD_DIRECTORY , "output " )):
457- os .mkdir (os .path .join (UPLOAD_DIRECTORY , "output " ))
461+ if not os .path .exists (os .path .join (RUN_PATH , "temp " )):
462+ os .mkdir (os .path .join (RUN_PATH , "temp " ))
458463 if "_" not in transcript :
459464 generate_json (
460- os .path .join (UPLOAD_DIRECTORY , "output " , wav_name + "_conv.wav" )
465+ os .path .join (RUN_PATH , "temp " , wav_name + "_conv.wav" )
461466 + "|"
462467 + transcript .strip (),
463- os .path .join (UPLOAD_DIRECTORY , "output " , wav_name + ".json" ),
468+ os .path .join (RUN_PATH , "temp " , wav_name + ".json" ),
464469 )
465470 else :
466471 generate_json (
467- os .path .join (UPLOAD_DIRECTORY , "output" , wav_name + "_conv.wav" )
468- + "|"
469- + "dummy" ,
470- os .path .join (UPLOAD_DIRECTORY , "output" , wav_name + ".json" ),
472+ os .path .join (RUN_PATH , "temp" , wav_name + "_conv.wav" ) + "|" + "dummy" ,
473+ os .path .join (RUN_PATH , "temp" , wav_name + ".json" ),
471474 )
472475
473476 data_config = {
474- "manifest_filepath" : os .path .join (
475- UPLOAD_DIRECTORY , "output" , wav_name + ".json"
476- ),
477+ "manifest_filepath" : os .path .join (RUN_PATH , "temp" , wav_name + ".json" ),
477478 "sample_rate" : 22050 ,
478479 "batch_size" : 1 ,
479480 }
@@ -645,7 +646,7 @@ def update_pitch_options(value):
645646def update_filelist (n_clicks ):
646647 filelist = []
647648 supported_formats = [".wav" , ".ogg" , ".mp3" , "flac" , ".aac" ]
648- for x in sorted (os .listdir (UPLOAD_DIRECTORY )):
649+ for x in sorted (os .listdir (RUN_PATH )):
649650 if x [- 4 :].lower () in supported_formats :
650651 filelist .append ({"label" : x , "value" : x })
651652 return filelist
@@ -664,18 +665,18 @@ def update_filelist(n_clicks):
664665)
665666def select_file (dropdown_value ):
666667 if dropdown_value is not None :
667- if not os .path .exists (os .path .join (UPLOAD_DIRECTORY , "output " )):
668- os .mkdir (os .path .join (UPLOAD_DIRECTORY , "output " ))
669- ffmpeg .input (os .path .join (UPLOAD_DIRECTORY , dropdown_value )).output (
670- os .path .join (UPLOAD_DIRECTORY , "output " , dropdown_value + "_conv.wav" ),
668+ if not os .path .exists (os .path .join (RUN_PATH , "temp " )):
669+ os .mkdir (os .path .join (RUN_PATH , "temp " ))
670+ ffmpeg .input (os .path .join (RUN_PATH , dropdown_value )).output (
671+ os .path .join (RUN_PATH , "temp " , dropdown_value + "_conv.wav" ),
671672 ar = "22050" ,
672673 ac = "1" ,
673674 acodec = "pcm_s16le" ,
674675 map_metadata = "-1" ,
675676 fflags = "+bitexact" ,
676677 ).overwrite_output ().run (quiet = True )
677678 fo_with_silence , f0_wo_silence = crepe_f0 (
678- os .path .join (UPLOAD_DIRECTORY , "output " , dropdown_value + "_conv.wav" )
679+ os .path .join (RUN_PATH , "temp " , dropdown_value + "_conv.wav" )
679680 )
680681 return [
681682 "Analyzed " + dropdown_value ,
@@ -724,25 +725,25 @@ def download_model(model, custom_model):
724725 drive_id = model
725726 if drive_id == "" or drive_id is None :
726727 return ("Missing Drive ID" , None , None )
727- if not os .path .exists (os .path .join (UPLOAD_DIRECTORY , "models" )):
728- os .mkdir (os .path .join (UPLOAD_DIRECTORY , "models" ))
729- if not os .path .exists (os .path .join (UPLOAD_DIRECTORY , "models" , drive_id )):
730- os .mkdir (os .path .join (UPLOAD_DIRECTORY , "models" , drive_id ))
731- zip_path = os .path .join (UPLOAD_DIRECTORY , "models" , drive_id , "model.zip" )
728+ if not os .path .exists (os .path .join (RUN_PATH , "models" )):
729+ os .mkdir (os .path .join (RUN_PATH , "models" ))
730+ if not os .path .exists (os .path .join (RUN_PATH , "models" , drive_id )):
731+ os .mkdir (os .path .join (RUN_PATH , "models" , drive_id ))
732+ zip_path = os .path .join (RUN_PATH , "models" , drive_id , "model.zip" )
732733 gdown .download (
733734 d + drive_id ,
734735 zip_path ,
735736 quiet = False ,
736737 )
737738 if not os .path .exists (zip_path ):
738- os .rmdir (os .path .join (UPLOAD_DIRECTORY , "models" , drive_id ))
739+ os .rmdir (os .path .join (RUN_PATH , "models" , drive_id ))
739740 return ("Model download failed" , None , None )
740741 if os .stat (zip_path ).st_size < 16 :
741742 os .remove (zip_path )
742- os .rmdir (os .path .join (UPLOAD_DIRECTORY , "models" , drive_id ))
743+ os .rmdir (os .path .join (RUN_PATH , "models" , drive_id ))
743744 return ("Model zip is empty" , None , None )
744745 with zipfile .ZipFile (zip_path , "r" ) as zip_ref :
745- zip_ref .extractall (os .path .join (UPLOAD_DIRECTORY , "models" , drive_id ))
746+ zip_ref .extractall (os .path .join (RUN_PATH , "models" , drive_id ))
746747 os .remove (zip_path )
747748
748749 # Download super-resolution HiFi-GAN
@@ -757,8 +758,8 @@ def download_model(model, custom_model):
757758
758759 return (
759760 None ,
760- os .path .join (UPLOAD_DIRECTORY , "models" , drive_id , "TalkNetSpect.nemo" ),
761- os .path .join (UPLOAD_DIRECTORY , "models" , drive_id , "hifiganmodel" ),
761+ os .path .join (RUN_PATH , "models" , drive_id , "TalkNetSpect.nemo" ),
762+ os .path .join (RUN_PATH , "models" , drive_id , "hifiganmodel" ),
762763 )
763764 except Exception as e :
764765 return (str (e ), None , None )
@@ -879,7 +880,10 @@ def generate_audio(
879880 spect = tnmodel .force_spectrogram (
880881 tokens = tokens ,
881882 durs = torch .from_numpy (durs ).view (1 , - 1 ).to ("cuda:0" ),
882- f0 = torch .FloatTensor (f0s ).view (1 , - 1 ).to ("cuda:0" ),
883+ f0 = torch .FloatTensor (f0s )
884+ .view (1 , - 1 )
885+ .type (torch .LongTensor )
886+ .to ("cuda:0" ),
883887 )
884888
885889 if hifipath != hifigan_path :
0 commit comments