Skip to content

Commit 4105a77

Browse files
committed
refactor: group sigma_up and sigma_down calculations
1 parent 545fac4 commit 4105a77

1 file changed

Lines changed: 28 additions & 37 deletions

File tree

src/denoiser.hpp

Lines changed: 28 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,26 @@ struct Flux2FlowDenoiser : public FluxFlowDenoiser {
761761

762762
typedef std::function<ggml_tensor*(ggml_tensor*, float, int)> denoise_cb_t;
763763

764+
static void generate_ancestral_step(float& sigma_up, float& sigma_down, float sigma_from, float sigma_to, float eta = 1.0f) {
765+
// sigma_up = min(sigma_to, eta * √(sigma_to² * (sigma_from² - sigma_to²) / sigma_from²))
766+
// sigma_down = √(sigma_to² - sigma_sup²)
767+
sigma_up = 0.0f;
768+
sigma_down = sigma_to;
769+
if (eta > 0.0f) {
770+
float sigma_from_sq = sigma_from * sigma_from;
771+
float sigma_to_sq = sigma_to * sigma_to;
772+
if (sigma_from_sq > 0.0f) {
773+
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
774+
if (term > 0.0f) {
775+
sigma_up = eta * std::sqrt(term);
776+
}
777+
}
778+
sigma_up = std::min(sigma_up, sigma_to);
779+
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
780+
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
781+
}
782+
}
783+
764784
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
765785
static bool sample_k_diffusion(sample_method_t method,
766786
denoise_cb_t model,
@@ -797,9 +817,8 @@ static bool sample_k_diffusion(sample_method_t method,
797817
}
798818

799819
// get_ancestral_step
800-
float sigma_up = std::min(sigmas[i + 1],
801-
std::sqrt(sigmas[i + 1] * sigmas[i + 1] * (sigmas[i] * sigmas[i] - sigmas[i + 1] * sigmas[i + 1]) / (sigmas[i] * sigmas[i])));
802-
float sigma_down = std::sqrt(sigmas[i + 1] * sigmas[i + 1] - sigma_up * sigma_up);
820+
float sigma_up, sigma_down;
821+
generate_ancestral_step(sigma_up, sigma_down, sigmas[i], sigmas[i + 1]);
803822

804823
// Euler method
805824
float dt = sigma_down - sigmas[i];
@@ -990,9 +1009,8 @@ static bool sample_k_diffusion(sample_method_t method,
9901009
}
9911010

9921011
// get_ancestral_step
993-
float sigma_up = std::min(sigmas[i + 1],
994-
std::sqrt(sigmas[i + 1] * sigmas[i + 1] * (sigmas[i] * sigmas[i] - sigmas[i + 1] * sigmas[i + 1]) / (sigmas[i] * sigmas[i])));
995-
float sigma_down = std::sqrt(sigmas[i + 1] * sigmas[i + 1] - sigma_up * sigma_up);
1012+
float sigma_up, sigma_down;
1013+
generate_ancestral_step(sigma_up, sigma_down, sigmas[i], sigmas[i + 1]);
9961014
auto t_fn = [](float sigma) -> float { return -log(sigma); };
9971015
auto sigma_fn = [](float t) -> float { return exp(-t); };
9981016

@@ -1719,22 +1737,8 @@ static bool sample_k_diffusion(sample_method_t method,
17191737

17201738
float sigma_from = sigmas[i];
17211739
float sigma_to = sigmas[i + 1];
1722-
float sigma_up = 0.0f;
1723-
float sigma_down = sigma_to;
1724-
1725-
if (eta > 0.0f) {
1726-
float sigma_from_sq = sigma_from * sigma_from;
1727-
float sigma_to_sq = sigma_to * sigma_to;
1728-
if (sigma_from_sq > 0.0f) {
1729-
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
1730-
if (term > 0.0f) {
1731-
sigma_up = eta * std::sqrt(term);
1732-
}
1733-
}
1734-
sigma_up = std::min(sigma_up, sigma_to);
1735-
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
1736-
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
1737-
}
1740+
float sigma_up, sigma_down;
1741+
generate_ancestral_step(sigma_up, sigma_down, sigma_from, sigma_to, eta);
17381742

17391743
if (sigma_down == 0.0f || !have_old_sigma) {
17401744
float dt = sigma_down - sigma_from;
@@ -1826,21 +1830,8 @@ static bool sample_k_diffusion(sample_method_t method,
18261830
return false;
18271831
}
18281832

1829-
float sigma_up = 0.0f;
1830-
float sigma_down = sigma_to;
1831-
if (eta > 0.0f) {
1832-
float sigma_from_sq = sigma_from * sigma_from;
1833-
float sigma_to_sq = sigma_to * sigma_to;
1834-
if (sigma_from_sq > 0.0f) {
1835-
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
1836-
if (term > 0.0f) {
1837-
sigma_up = eta * std::sqrt(term);
1838-
}
1839-
}
1840-
sigma_up = std::min(sigma_up, sigma_to);
1841-
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
1842-
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
1843-
}
1833+
float sigma_up, sigma_down;
1834+
generate_ancestral_step(sigma_up, sigma_down, sigma_from, sigma_to, eta);
18441835

18451836
float* vec_x = (float*)x->data;
18461837
float* vec_x0 = (float*)x0->data;

0 commit comments

Comments
 (0)