Skip to content

Commit 812f0ab

Browse files
authored
New pitch adjustment
1 parent 73b9837 commit 812f0ab

1 file changed

Lines changed: 57 additions & 69 deletions

File tree

controllable_talknet.py

Lines changed: 57 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
htmlFor="reference-dropdown",
9494
),
9595
dcc.Store(id="current-f0s"),
96+
dcc.Store(id="current-f0s-nosilence"),
9697
dcc.Store(id="current-filename"),
9798
dcc.Loading(
9899
id="audio-loading",
@@ -157,23 +158,30 @@
157158
dcc.Checklist(
158159
id="pitch-options",
159160
options=[
160-
# {"label": "Singing mode", "value": "pc"},
161-
# {"label": "Set pitch multiplier", "value": "pf"},
161+
{"label": "Change input pitch", "value": "pf"},
162+
{"label": "Auto-tune output", "value": "pc"},
162163
{"label": "Disable reference audio", "value": "dra"},
163164
],
164165
value=[],
165166
),
166-
dcc.Input(
167-
id="pitch-factor",
168-
type="number",
169-
value="1.0",
170-
# style={"width": "7em", "margin-left": "10px"},
171-
min=0.1,
172-
max=10.0,
173-
step=0.01,
174-
disabled=True,
167+
html.Div(
168+
[
169+
html.Label("Semitones", htmlFor="pitch-factor"),
170+
dcc.Input(
171+
id="pitch-factor",
172+
type="number",
173+
value="0",
174+
style={"width": "7em"},
175+
min=-11,
176+
max=11,
177+
step=1,
178+
disabled=True,
179+
),
180+
],
175181
style={
176-
"display": "none",
182+
"flex-direction": "column",
183+
"margin-left": "10px",
184+
"margin-bottom": "0.7em",
177185
},
178186
),
179187
],
@@ -564,7 +572,10 @@ def crepe_f0(wav_path, hop_length=256):
564572
# Hack to make f0 and mel lengths equal
565573
if len(audio) % hop_length == 0:
566574
freq_interp = np.pad(freq_interp, pad_width=[0, 1])
567-
return torch.from_numpy(freq_interp.astype(np.float32))
575+
return (
576+
torch.from_numpy(freq_interp.astype(np.float32)),
577+
torch.from_numpy(frequency.astype(np.float32)),
578+
)
568579

569580

570581
def f0_to_audio(f0s):
@@ -587,25 +598,15 @@ def f0_to_audio(f0s):
587598

588599

589600
@app.callback(
590-
[
591-
dash.dependencies.Output("custom-model", "style"),
592-
dash.dependencies.Output("pitch-options", "value"),
593-
],
601+
dash.dependencies.Output("custom-model", "style"),
594602
dash.dependencies.Input("model-dropdown", "value"),
595-
dash.dependencies.State("pitch-options", "value"),
596603
)
597-
def update_model(model, options):
604+
def update_model(model):
598605
if model is not None and model.split("|")[0] == "Custom":
599606
style = {"margin-bottom": "0.7em", "display": "block"}
600607
else:
601608
style = {"display": "none"}
602-
# new_options = options
603-
"""if model is not None:
604-
if "singing" in model.split("|")[1] and "pc" not in new_options:
605-
new_options.append("pc")
606-
elif "singing" not in model.split("|")[1] and "pc" in new_options:
607-
new_options.remove("pc")"""
608-
return [style, options]
609+
return style
609610

610611

611612
@app.callback(
@@ -654,6 +655,7 @@ def update_filelist(n_clicks):
654655
[
655656
dash.dependencies.Output("audio-loading-output", "children"),
656657
dash.dependencies.Output("current-f0s", "data"),
658+
dash.dependencies.Output("current-f0s-nosilence", "data"),
657659
dash.dependencies.Output("current-filename", "data"),
658660
],
659661
[
@@ -672,11 +674,13 @@ def select_file(dropdown_value):
672674
map_metadata="-1",
673675
fflags="+bitexact",
674676
).overwrite_output().run(quiet=True)
677+
fo_with_silence, f0_wo_silence = crepe_f0(
678+
os.path.join(UPLOAD_DIRECTORY, "output", dropdown_value + "_conv.wav")
679+
)
675680
return [
676681
"Analyzed " + dropdown_value,
677-
crepe_f0(
678-
os.path.join(UPLOAD_DIRECTORY, "output", dropdown_value + "_conv.wav")
679-
),
682+
fo_with_silence,
683+
f0_wo_silence,
680684
dropdown_value,
681685
]
682686
else:
@@ -762,7 +766,6 @@ def download_model(model, custom_model):
762766
dash.dependencies.Output("audio-out", "src"),
763767
dash.dependencies.Output("generated-info", "children"),
764768
dash.dependencies.Output("audio-out", "style"),
765-
dash.dependencies.Output("pitch-factor", "value"),
766769
dash.dependencies.Output("audio-out", "title"),
767770
],
768771
[dash.dependencies.Input("gen-button", "n_clicks")],
@@ -774,6 +777,7 @@ def download_model(model, custom_model):
774777
dash.dependencies.State("pitch-factor", "value"),
775778
dash.dependencies.State("current-filename", "data"),
776779
dash.dependencies.State("current-f0s", "data"),
780+
dash.dependencies.State("current-f0s-nosilence", "data"),
777781
],
778782
)
779783
def generate_audio(
@@ -785,27 +789,26 @@ def generate_audio(
785789
pitch_factor,
786790
wav_name,
787791
f0s,
792+
f0s_wo_silence,
788793
):
789794
global tnmodel, tnpath, tndurs, tnpitch, hifigan, h, denoiser, hifipath
790795

791796
if n_clicks is None:
792797
raise PreventUpdate
793798
if model is None:
794-
return [None, "No character selected", playback_hide, pitch_factor, None]
799+
return [None, "No character selected", playback_hide, None]
795800
if transcript is None or transcript.strip() == "":
796801
return [
797802
None,
798803
"No transcript entered",
799804
playback_hide,
800-
pitch_factor,
801805
None,
802806
]
803807
if wav_name is None and "dra" not in pitch_options:
804808
return [
805809
None,
806810
"No reference audio selected",
807811
playback_hide,
808-
pitch_factor,
809812
None,
810813
]
811814
load_error, talknet_path, hifigan_path = download_model(
@@ -816,7 +819,6 @@ def generate_audio(
816819
None,
817820
load_error,
818821
playback_hide,
819-
pitch_factor,
820822
None,
821823
]
822824

@@ -855,12 +857,18 @@ def generate_audio(
855857
None,
856858
"Model doesn't support pitch prediction",
857859
playback_hide,
858-
pitch_factor,
859860
None,
860861
]
861862
spect = tnmodel.generate_spectrogram(tokens=tokens)
862863
else:
863864
durs, arpa, t = get_duration(wav_name, transcript)
865+
866+
# Change pitch
867+
if "pf" in pitch_options:
868+
f0_factor = np.power(np.e, (0.0577623 * float(pitch_factor)))
869+
f0s = [x * f0_factor for x in f0s]
870+
f0s_wo_silence = [x * f0_factor for x in f0s_wo_silence]
871+
864872
spect = tnmodel.force_spectrogram(
865873
tokens=tokens,
866874
durs=torch.from_numpy(durs).view(1, -1).to("cuda:0"),
@@ -878,53 +886,34 @@ def generate_audio(
878886
audio_denoised.detach().cpu().numpy().reshape(-1).astype(np.int16)
879887
)
880888

881-
# Pitch correction
889+
# Auto-tuning
882890
if "pc" in pitch_options and "dra" not in pitch_options:
883-
884-
def get_f0(audio, sr):
885-
_, frequency, _, _ = crepe.predict(audio, sr, viterbi=True)
886-
return torch.from_numpy(frequency.astype(np.float32))
887-
888-
input_pitch = get_f0(audio_np, 22050)
889-
target_sr, target_audio = wavfile.read(
890-
os.path.join(UPLOAD_DIRECTORY, "output", wav_name + "_conv.wav")
891-
)
892-
target_pitch = get_f0(target_audio, target_sr)
893-
factor = torch.mean(input_pitch) / torch.mean(target_pitch)
894-
if (
895-
max(factor, float(pitch_factor)) / min(factor, float(pitch_factor))
896-
< 2.0
897-
and float(pitch_factor) != 1.0
898-
):
899-
factor = float(pitch_factor)
900-
if "pf" in pitch_options:
901-
factor = float(pitch_factor)
902-
target_pitch *= factor
903-
else:
904-
octaves = [0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0]
905-
nearest_octave = min(octaves, key=lambda x: abs(x - factor))
906-
target_pitch *= nearest_octave
907-
if len(target_pitch) < len(input_pitch):
891+
_, output_freq, _, _ = crepe.predict(audio_np, 22050, viterbi=True)
892+
output_pitch = torch.from_numpy(output_freq.astype(np.float32))
893+
target_pitch = torch.FloatTensor(f0s_wo_silence)
894+
factor = torch.mean(output_pitch) / torch.mean(target_pitch)
895+
896+
octaves = [0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0]
897+
nearest_octave = min(octaves, key=lambda x: abs(x - factor))
898+
target_pitch *= nearest_octave
899+
if len(target_pitch) < len(output_pitch):
908900
target_pitch = torch.nn.functional.pad(
909901
target_pitch,
910-
(0, list(input_pitch.shape)[0] - list(target_pitch.shape)[0]),
902+
(0, list(output_pitch.shape)[0] - list(target_pitch.shape)[0]),
911903
"constant",
912904
0,
913905
)
914-
if len(target_pitch) > len(input_pitch):
915-
target_pitch = target_pitch[0 : list(input_pitch.shape)[0]]
906+
if len(target_pitch) > len(output_pitch):
907+
target_pitch = target_pitch[0 : list(output_pitch.shape)[0]]
916908

917909
audio_np = psola.vocode(
918910
audio_np, 22050, target_pitch=target_pitch
919911
).astype(np.float32)
920912
normalize = (1.0 / np.max(np.abs(audio_np))) ** 0.9
921913
audio_np = audio_np * normalize * MAX_WAV_VALUE
922914
audio_np = audio_np.astype(np.int16)
923-
else:
924-
factor = pitch_factor
925915

926916
# Resample to 32k
927-
928917
wave = resampy.resample(
929918
audio_np,
930919
h.sampling_rate,
@@ -971,13 +960,12 @@ def get_f0(audio, sr):
971960
sound = "data:audio/x-wav;base64," + b64.decode("ascii")
972961

973962
output_name = "TalkNet_" + str(int(time.time()))
974-
return [sound, arpa, playback_style, factor, output_name]
963+
return [sound, arpa, playback_style, output_name]
975964
except Exception:
976965
return [
977966
None,
978967
str(traceback.format_exc()),
979968
playback_hide,
980-
pitch_factor,
981969
None,
982970
]
983971

0 commit comments

Comments (0)