diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 2f2851c2e..1c04367b1 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -452,6 +452,17 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params); SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params); SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); +enum sd_cancel_mode_t { + // Stop the current generation as soon as possible. + SD_CANCEL_ALL, + // Finish the current image sample, then skip additional batch latents and return completed images. + SD_CANCEL_NEW_LATENTS, + // Clear a pending cancellation request. + SD_CANCEL_RESET +}; + +SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode); + SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params); SD_API bool generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 1cc7edfce..c2a1974b2 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -53,6 +53,8 @@ const char* sd_vae_format_name(enum sd_vae_format_t format); static SDVersion sd_vae_format_to_version(enum sd_vae_format_t format, SDVersion fallback); +#include + const char* model_version_to_str[] = { "SD 1.x", "SD 1.x Inpaint", @@ -159,6 +161,9 @@ static float get_cache_reuse_threshold(const sd_cache_params_t& params) { /*=============================================== StableDiffusionGGML ================================================*/ +static_assert(std::atomic::is_always_lock_free, + "sd_cancel_mode_t must be lock-free"); + class StableDiffusionGGML { public: SDBackendManager backend_manager; @@ -222,6 +227,20 @@ class StableDiffusionGGML { return module_backend; } + std::atomic cancellation_flag = SD_CANCEL_RESET; + + void set_cancel_flag(enum sd_cancel_mode_t flag) { + cancellation_flag.store(flag, std::memory_order_release); + } + + void reset_cancel_flag() { + set_cancel_flag(SD_CANCEL_RESET); + } + + enum sd_cancel_mode_t get_cancel_flag() { + return cancellation_flag.load(std::memory_order_acquire); + } + size_t max_graph_vram_bytes_for_module(SDBackendModule module) { return max_vram_assignment.bytes_for_backend(backend_for(module)); } @@ -1941,6 +1960,11 @@ class StableDiffusionGGML { SamplePreviewContext preview = prepare_sample_preview_context(); auto denoise = [&](const sd::Tensor& x, float sigma, int step) -> sd::guidance::GuiderOutput { + if (get_cancel_flag() == SD_CANCEL_ALL) { + LOG_DEBUG("cancelling generation"); + return {}; + } + if (step == 1 || step == -1) { pretty_progress(0, (int)steps, 0); last_progress_us = ggml_time_us(); @@ -2963,6 +2987,15 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { free(sd_ctx); } +SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode) { + if (sd_ctx && sd_ctx->sd) { + if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET) { + mode = SD_CANCEL_ALL; + } + sd_ctx->sd->set_cancel_flag(mode); + } +} + static sd_audio_t* waveform_to_sd_audio(const StableDiffusionGGML* sd, const sd::Tensor& waveform) { if (sd == nullptr || waveform.empty()) { @@ -4150,15 +4183,29 @@ static std::optional prepare_image_generation_embeds(sd_c static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx, const GenerationRequest& request, const std::vector>& final_latents) { - if (final_latents.size() != static_cast(request.batch_count)) { - LOG_ERROR("expected %d latents, got %zu", request.batch_count, final_latents.size()); + if (final_latents.empty()) { + LOG_ERROR("no latent images to decode"); + return nullptr; + } + if (final_latents.size() > static_cast(request.batch_count)) { + LOG_ERROR("expected at most %d latents, got %zu", request.batch_count, final_latents.size()); return nullptr; } - LOG_INFO("decoding %zu latents", final_latents.size()); + if (final_latents.size() < static_cast(request.batch_count)) { + LOG_INFO("decoding %zu/%d latents", final_latents.size(), request.batch_count); + } else { + LOG_INFO("decoding %zu latents", final_latents.size()); + } std::vector> decoded_images; - int64_t t0 = ggml_time_ms(); + int64_t t0 = ggml_time_ms(); + bool cancelled = false; for (size_t i = 0; i < final_latents.size(); i++) { + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling latent decodings"); + cancelled = true; + break; + } int64_t t1 = ggml_time_ms(); sd::Tensor image = sd_ctx->sd->decode_first_stage(final_latents[i]); if (image.empty()) { @@ -4172,6 +4219,10 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx, int64_t t4 = ggml_time_ms(); LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t0) * 1.0f / 1000); + if (decoded_images.empty()) { + LOG_ERROR(cancelled ? "cancelled before any latent images were decoded" : "no decoded images"); + return nullptr; + } sd_image_t* result_images = (sd_image_t*)calloc(request.batch_count, sizeof(sd_image_t)); if (result_images == nullptr) { @@ -4190,6 +4241,11 @@ static sd::Tensor upscale_hires_latent(sd_ctx_t* sd_ctx, const sd::Tensor& latent, const GenerationRequest& request, UpscalerGGML* upscaler) { + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling hires latent upscale"); + return {}; + } + auto get_hires_latent_target_shape = [&]() { std::vector target_shape = latent.shape(); if (target_shape.size() < 2) { @@ -4262,6 +4318,10 @@ static sd::Tensor upscale_hires_latent(sd_ctx_t* sd_ctx, sd_hires_upscaler_name(request.hires.upscaler)); return {}; } + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling hires image upscale"); + return {}; + } sd::Tensor upscaled_tensor; if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) { @@ -4298,6 +4358,10 @@ static sd::Tensor upscale_hires_latent(sd_ctx_t* sd_ctx, upscaled_tensor = sd::ops::clamp(upscaled_tensor, 0.0f, 1.0f); } + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling hires latent encode"); + return {}; + } sd::Tensor upscaled_latent = sd_ctx->sd->encode_first_stage(upscaled_tensor); if (upscaled_latent.empty()) { LOG_ERROR("encode_first_stage failed after hires %s upscale", @@ -4362,6 +4426,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s return nullptr; } + sd_ctx->sd->reset_cancel_flag(); + int64_t t0 = ggml_time_ms(); sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params; GenerationRequest request(sd_ctx, sd_img_gen_params); @@ -4397,6 +4463,18 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s std::vector> final_latents; int64_t denoise_start = ggml_time_ms(); for (int b = 0; b < request.batch_count; b++) { + sd_cancel_mode_t cancel = sd_ctx->sd->get_cancel_flag(); + if (cancel == SD_CANCEL_ALL) { + LOG_ERROR("cancelling generation"); + return nullptr; + } + if (cancel == SD_CANCEL_NEW_LATENTS) { + LOG_INFO("cancelling new latent generation, returning %zu/%d completed latents", + final_latents.size(), + request.batch_count); + break; + } + int64_t sampling_start = ggml_time_ms(); int64_t cur_seed = request.seed + b; LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, request.batch_count, cur_seed); @@ -4446,12 +4524,24 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s LOG_INFO("generating %zu latent images completed, taking %.2fs", final_latents.size(), (denoise_end - denoise_start) * 1.0f / 1000); + if (final_latents.empty()) { + LOG_ERROR("no latent images generated"); + return nullptr; + } if (request.hires.enabled && request.hires.target_width > 0) { + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling generation before hires fix"); + return nullptr; + } LOG_INFO("hires fix: upscaling to %dx%d", request.hires.target_width, request.hires.target_height); std::unique_ptr hires_upscaler; if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) { + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling generation before hires model load"); + return nullptr; + } LOG_INFO("hires fix: loading model upscaler from '%s'", request.hires.model_path); hires_upscaler = std::make_unique(sd_ctx->sd->n_threads, false, @@ -4485,6 +4575,10 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s std::vector> hires_final_latents; int64_t hires_denoise_start = ggml_time_ms(); for (int b = 0; b < (int)final_latents.size(); b++) { + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling generation during hires fix"); + return nullptr; + } int64_t cur_seed = request.seed + b; sd_ctx->sd->rng->manual_seed(cur_seed); sd_ctx->sd->sampler_rng->manual_seed(cur_seed); @@ -4915,6 +5009,10 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx, LOG_ERROR("no latent video to decode"); return nullptr; } + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling video decode"); + return nullptr; + } sd::Tensor video_latent = final_latent; if (sd_version_is_ltxav(sd_ctx->sd->version) && video_latent.shape()[3] > sd_ctx->sd->get_latent_channel()) { @@ -5160,6 +5258,9 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, if (audio_out != nullptr) { *audio_out = nullptr; } + + sd_ctx->sd->reset_cancel_flag(); + if (num_frames_out != nullptr) { *num_frames_out = 0; } @@ -5221,6 +5322,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, sd::Tensor noise = sd::Tensor::randn_like(x_t, sd_ctx->sd->rng); if (plan.high_noise_sample_steps > 0) { + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling generation before high-noise sampling"); + return false; + } LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T); int64_t sampling_start = ggml_time_ms(); @@ -5263,6 +5368,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); } + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling generation before sampling"); + return false; + } LOG_DEBUG("sample %dx%dx%d", W, H, T); int64_t sampling_start = ggml_time_ms(); sd::Tensor final_latent = sd_ctx->sd->sample(sd_ctx->sd->diffusion_model, @@ -5299,6 +5408,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); if (latent_upscale_enabled) { + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling generation before latent upscale"); + return false; + } int64_t upscale_start = ggml_time_ms(); sd::Tensor upscaled_latent = upscale_ltx_spatial_video_latent(sd_ctx, request.hires.model_path, @@ -5358,6 +5471,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, } sd::Tensor hires_denoise_mask; sd::Tensor hires_video_positions; + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling generation before latent upscale refine"); + return false; + } if (!apply_ltxv_refine_image_conditioning(sd_ctx, sd_vid_gen_params, hires_request, @@ -5437,6 +5554,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, if (sd_version_is_ltxav(sd_ctx->sd->version) && latents.audio_length > 0 && sd_ctx->sd->audio_vae_model != nullptr) { + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling generation before audio decode"); + return false; + } int64_t audio_latent_decode_start = ggml_time_ms(); auto audio_latent = unpack_ltxav_audio_latent(final_latent, @@ -5469,6 +5590,11 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]); } + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling generation before video decode"); + free_sd_audio(generated_audio); + return false; + } auto result = decode_video_outputs(sd_ctx, latent_upscale_enabled ? hires_request : request, final_latent, num_frames_out); if (result == nullptr) { free_sd_audio(generated_audio);