@@ -2777,6 +2777,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
27772777 sample_params->scheduler = SCHEDULER_COUNT;
27782778 sample_params->sample_method = SAMPLE_METHOD_COUNT;
27792779 sample_params->sample_steps = 20 ;
2780+ sample_params->eta = INFINITY;
27802781 sample_params->custom_sigmas = nullptr ;
27812782 sample_params->custom_sigmas_count = 0 ;
27822783 sample_params->flow_shift = INFINITY;
@@ -2953,6 +2954,21 @@ enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
29532954 return EULER_A_SAMPLE_METHOD;
29542955}
29552956
2957+ static float sd_get_default_eta (enum sample_method_t sample_method) {
2958+ switch (sample_method) {
2959+ case DDIM_TRAILING_SAMPLE_METHOD:
2960+ case TCD_SAMPLE_METHOD:
2961+ case RES_MULTISTEP_SAMPLE_METHOD:
2962+ case RES_2S_SAMPLE_METHOD:
2963+ return 0 .0f ;
2964+ case EULER_A_SAMPLE_METHOD:
2965+ case DPMPP2S_A_SAMPLE_METHOD:
2966+ return 1 .0f ;
2967+ default :
2968+ return INFINITY;
2969+ }
2970+ }
2971+
29562972enum scheduler_t sd_get_default_scheduler (const sd_ctx_t * sd_ctx, enum sample_method_t sample_method) {
29572973 if (sd_ctx != nullptr && sd_ctx->sd != nullptr ) {
29582974 auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd ->denoiser );
@@ -3331,7 +3347,16 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
33313347 if (sample_method == SAMPLE_METHOD_COUNT) {
33323348 sample_method = sd_get_default_sample_method (sd_ctx);
33333349 }
3334- LOG_INFO (" sampling using %s method" , sampling_methods_str[sample_method]);
3350+ float eta = sd_img_gen_params->sample_params .eta ;
3351+ float default_eta = sd_get_default_eta (sample_method);
3352+ if (default_eta != INFINITY) {
3353+ if (eta == INFINITY) {
3354+ eta = default_eta;
3355+ }
3356+ LOG_INFO (" sampling using %s method (eta %g)" , sampling_methods_str[sample_method], eta);
3357+ } else {
3358+ LOG_INFO (" sampling using %s method" , sampling_methods_str[sample_method]);
3359+ }
33353360
33363361 int sample_steps = sd_img_gen_params->sample_params .sample_steps ;
33373362 std::vector<float > sigmas;
@@ -3546,7 +3571,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
35463571 SAFE_STR (sd_img_gen_params->negative_prompt ),
35473572 sd_img_gen_params->clip_skip ,
35483573 guidance,
3549- sd_img_gen_params-> sample_params . eta ,
3574+ eta,
35503575 sd_img_gen_params->sample_params .shifted_timestep ,
35513576 width,
35523577 height,
0 commit comments