|
50 | 50 | const char* sd_vae_format_name(enum sd_vae_format_t format); |
51 | 51 | static SDVersion sd_vae_format_to_version(enum sd_vae_format_t format, SDVersion fallback); |
52 | 52 |
|
| 53 | +#include <atomic> |
| 54 | + |
53 | 55 | const char* model_version_to_str[] = { |
54 | 56 | "SD 1.x", |
55 | 57 | "SD 1.x Inpaint", |
@@ -155,6 +157,9 @@ static float get_cache_reuse_threshold(const sd_cache_params_t& params) { |
155 | 157 |
|
156 | 158 | /*=============================================== StableDiffusionGGML ================================================*/ |
157 | 159 |
|
| 160 | +static_assert(std::atomic<sd_cancel_mode_t>::is_always_lock_free,
| 161 | + "sd_cancel_mode_t must be lock-free"); /* compile-time guard: the build fails unless std::atomic<sd_cancel_mode_t> is always lock-free */ |
| 162 | + |
158 | 163 | class StableDiffusionGGML { |
159 | 164 | public: |
160 | 165 | std::vector<MmapTensorStore> mmap_tensor_store; |
@@ -225,6 +230,20 @@ class StableDiffusionGGML { |
225 | 230 | return module_backend; |
226 | 231 | } |
227 | 232 |
|
| 233 | + std::atomic<sd_cancel_mode_t> cancellation_flag{SD_CANCEL_RESET}; /* cancellation request polled while generating; explicitly starts at SD_CANCEL_RESET because a default-constructed std::atomic holds an indeterminate value before C++20 */ |
| 234 | + |
| 235 | + void set_cancel_flag(enum sd_cancel_mode_t flag) { /* publish a new cancellation mode (release store) */ |
| 236 | + cancellation_flag.store(flag, std::memory_order_release); |
| 237 | + } |
| 238 | + |
| 239 | + void reset_cancel_flag() { /* clear any pending cancellation request */ |
| 240 | + set_cancel_flag(SD_CANCEL_RESET); |
| 241 | + } |
| 242 | + |
| 243 | + enum sd_cancel_mode_t get_cancel_flag() const { /* read the current mode (acquire load); const so it can be called through const pointers/references */ |
| 244 | + return cancellation_flag.load(std::memory_order_acquire); |
| 245 | + } |
| 246 | + |
228 | 247 | bool ensure_backend_pair(SDBackendModule module) { |
229 | 248 | if (backend_for(module) == nullptr) { |
230 | 249 | return false; |
@@ -1968,6 +1987,12 @@ class StableDiffusionGGML { |
1968 | 1987 | SamplePreviewContext preview = prepare_sample_preview_context(); |
1969 | 1988 |
|
1970 | 1989 | auto denoise = [&](const sd::Tensor<float>& x, float sigma, int step) -> sd::guidance::GuiderOutput { |
| 1990 | + enum sd_cancel_mode_t cancel_flag = get_cancel_flag(); |
| 1991 | + if (cancel_flag != SD_CANCEL_RESET) { |
| 1992 | + LOG_DEBUG("cancelling generation"); |
| 1993 | + return {}; |
| 1994 | + } |
| 1995 | + |
1971 | 1996 | if (step == 1 || step == -1) { |
1972 | 1997 | pretty_progress(0, (int)steps, 0); |
1973 | 1998 | } |
@@ -3010,6 +3035,15 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { |
3010 | 3035 | free(sd_ctx); |
3011 | 3036 | } |
3012 | 3037 |
|
| 3038 | +SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode) { /* record a cancellation request on sd_ctx; does nothing when sd_ctx or sd_ctx->sd is null */ |
| 3039 | + if (sd_ctx && sd_ctx->sd) { |
| 3040 | + if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET) { /* NOTE(review): assumes SD_CANCEL_ALL is the lowest and SD_CANCEL_RESET the highest enumerator - confirm against the enum declaration */ |
| 3041 | + mode = SD_CANCEL_ALL; /* out-of-range values fall back to SD_CANCEL_ALL */ |
| 3042 | + } |
| 3043 | + sd_ctx->sd->set_cancel_flag(mode); |
| 3044 | + } |
| 3045 | +} |
| 3046 | + |
3013 | 3047 | static sd_audio_t* waveform_to_sd_audio(const StableDiffusionGGML* sd, |
3014 | 3048 | const sd::Tensor<float>& waveform) { |
3015 | 3049 | if (sd == nullptr || waveform.empty()) { |
@@ -4196,6 +4230,10 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx, |
4196 | 4230 | int64_t t0 = ggml_time_ms(); |
4197 | 4231 |
|
4198 | 4232 | for (size_t i = 0; i < final_latents.size(); i++) { |
| 4233 | + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { |
| 4234 | + LOG_ERROR("cancelling latent decodings"); |
| 4235 | + break; |
| 4236 | + } |
4199 | 4237 | int64_t t1 = ggml_time_ms(); |
4200 | 4238 | sd::Tensor<float> image = sd_ctx->sd->decode_first_stage(final_latents[i]); |
4201 | 4239 | if (image.empty()) { |
@@ -4410,6 +4448,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s |
4410 | 4448 | return nullptr; |
4411 | 4449 | } |
4412 | 4450 |
|
| 4451 | + sd_ctx->sd->reset_cancel_flag(); |
| 4452 | + |
4413 | 4453 | int64_t t0 = ggml_time_ms(); |
4414 | 4454 | sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params; |
4415 | 4455 | GenerationRequest request(sd_ctx, sd_img_gen_params); |
@@ -4445,6 +4485,12 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s |
4445 | 4485 | std::vector<sd::Tensor<float>> final_latents; |
4446 | 4486 | int64_t denoise_start = ggml_time_ms(); |
4447 | 4487 | for (int b = 0; b < request.batch_count; b++) { |
| 4488 | + sd_cancel_mode_t cancel = sd_ctx->sd->get_cancel_flag(); |
| 4489 | + if (cancel == SD_CANCEL_NEW_LATENTS || cancel == SD_CANCEL_ALL) { |
| 4490 | + LOG_ERROR("cancelling generation"); |
| 4491 | + break; |
| 4492 | + } |
| 4493 | + |
4448 | 4494 | int64_t sampling_start = ggml_time_ms(); |
4449 | 4495 | int64_t cur_seed = request.seed + b; |
4450 | 4496 | LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, request.batch_count, cur_seed); |
@@ -5218,6 +5264,9 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, |
5218 | 5264 | if (audio_out != nullptr) { |
5219 | 5265 | *audio_out = nullptr; |
5220 | 5266 | } |
| 5267 | + |
| 5268 | + sd_ctx->sd->reset_cancel_flag(); |
| 5269 | + |
5221 | 5270 | if (num_frames_out != nullptr) { |
5222 | 5271 | *num_frames_out = 0; |
5223 | 5272 | } |
|
0 commit comments