@@ -761,6 +761,26 @@ struct Flux2FlowDenoiser : public FluxFlowDenoiser {
761761
762762typedef std::function<ggml_tensor*(ggml_tensor*, float , int )> denoise_cb_t ;
763763
764+ static void generate_ancestral_step (float & sigma_up, float & sigma_down, float sigma_from, float sigma_to, float eta = 1 .0f ) {
765+ // sigma_up = min(sigma_to, eta * √(sigma_to² * (sigma_from² - sigma_to²) / sigma_from²))
766+ // sigma_down = √(sigma_to² - sigma_sup²)
767+ sigma_up = 0 .0f ;
768+ sigma_down = sigma_to;
769+ if (eta > 0 .0f ) {
770+ float sigma_from_sq = sigma_from * sigma_from;
771+ float sigma_to_sq = sigma_to * sigma_to;
772+ if (sigma_from_sq > 0 .0f ) {
773+ float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
774+ if (term > 0 .0f ) {
775+ sigma_up = eta * std::sqrt (term);
776+ }
777+ }
778+ sigma_up = std::min (sigma_up, sigma_to);
779+ float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
780+ sigma_down = sigma_down_sq > 0 .0f ? std::sqrt (sigma_down_sq) : 0 .0f ;
781+ }
782+ }
783+
764784// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
765785static bool sample_k_diffusion (sample_method_t method,
766786 denoise_cb_t model,
@@ -797,9 +817,8 @@ static bool sample_k_diffusion(sample_method_t method,
797817 }
798818
799819 // get_ancestral_step
800- float sigma_up = std::min (sigmas[i + 1 ],
801- std::sqrt (sigmas[i + 1 ] * sigmas[i + 1 ] * (sigmas[i] * sigmas[i] - sigmas[i + 1 ] * sigmas[i + 1 ]) / (sigmas[i] * sigmas[i])));
802- float sigma_down = std::sqrt (sigmas[i + 1 ] * sigmas[i + 1 ] - sigma_up * sigma_up);
820+ float sigma_up, sigma_down;
821+ generate_ancestral_step (sigma_up, sigma_down, sigmas[i], sigmas[i + 1 ]);
803822
804823 // Euler method
805824 float dt = sigma_down - sigmas[i];
@@ -990,9 +1009,8 @@ static bool sample_k_diffusion(sample_method_t method,
9901009 }
9911010
9921011 // get_ancestral_step
993- float sigma_up = std::min (sigmas[i + 1 ],
994- std::sqrt (sigmas[i + 1 ] * sigmas[i + 1 ] * (sigmas[i] * sigmas[i] - sigmas[i + 1 ] * sigmas[i + 1 ]) / (sigmas[i] * sigmas[i])));
995- float sigma_down = std::sqrt (sigmas[i + 1 ] * sigmas[i + 1 ] - sigma_up * sigma_up);
1012+ float sigma_up, sigma_down;
1013+ generate_ancestral_step (sigma_up, sigma_down, sigmas[i], sigmas[i + 1 ]);
9961014 auto t_fn = [](float sigma) -> float { return -log (sigma); };
9971015 auto sigma_fn = [](float t) -> float { return exp (-t); };
9981016
@@ -1719,22 +1737,8 @@ static bool sample_k_diffusion(sample_method_t method,
17191737
17201738 float sigma_from = sigmas[i];
17211739 float sigma_to = sigmas[i + 1 ];
1722- float sigma_up = 0 .0f ;
1723- float sigma_down = sigma_to;
1724-
1725- if (eta > 0 .0f ) {
1726- float sigma_from_sq = sigma_from * sigma_from;
1727- float sigma_to_sq = sigma_to * sigma_to;
1728- if (sigma_from_sq > 0 .0f ) {
1729- float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
1730- if (term > 0 .0f ) {
1731- sigma_up = eta * std::sqrt (term);
1732- }
1733- }
1734- sigma_up = std::min (sigma_up, sigma_to);
1735- float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
1736- sigma_down = sigma_down_sq > 0 .0f ? std::sqrt (sigma_down_sq) : 0 .0f ;
1737- }
1740+ float sigma_up, sigma_down;
1741+ generate_ancestral_step (sigma_up, sigma_down, sigma_from, sigma_to, eta);
17381742
17391743 if (sigma_down == 0 .0f || !have_old_sigma) {
17401744 float dt = sigma_down - sigma_from;
@@ -1826,21 +1830,8 @@ static bool sample_k_diffusion(sample_method_t method,
18261830 return false ;
18271831 }
18281832
1829- float sigma_up = 0 .0f ;
1830- float sigma_down = sigma_to;
1831- if (eta > 0 .0f ) {
1832- float sigma_from_sq = sigma_from * sigma_from;
1833- float sigma_to_sq = sigma_to * sigma_to;
1834- if (sigma_from_sq > 0 .0f ) {
1835- float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
1836- if (term > 0 .0f ) {
1837- sigma_up = eta * std::sqrt (term);
1838- }
1839- }
1840- sigma_up = std::min (sigma_up, sigma_to);
1841- float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
1842- sigma_down = sigma_down_sq > 0 .0f ? std::sqrt (sigma_down_sq) : 0 .0f ;
1843- }
1833+ float sigma_up, sigma_down;
1834+ generate_ancestral_step (sigma_up, sigma_down, sigma_from, sigma_to, eta);
18441835
18451836 float * vec_x = (float *)x->data ;
18461837 float * vec_x0 = (float *)x0->data ;
0 commit comments