From 4533c37cf5857fdc29b86d520a2ff62519f93591 Mon Sep 17 00:00:00 2001 From: rmatif Date: Sat, 2 Aug 2025 00:49:19 +0000 Subject: [PATCH 1/3] feat: Add timestep shift and two new schedulers --- denoiser.hpp | 84 +++++++++- examples/cli/main.cpp | 21 ++- stable-diffusion.cpp | 377 ++++++++++++++++++++++++------------------ stable-diffusion.h | 3 + 4 files changed, 322 insertions(+), 163 deletions(-) diff --git a/denoiser.hpp b/denoiser.hpp index d4bcec59..1304e632 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -232,6 +232,25 @@ struct GITSSchedule : SigmaSchedule { } }; +struct SGMUniformSchedule : SigmaSchedule { + std::vector get_sigmas(uint32_t n, float sigma_min_in, float sigma_max_in, t_to_sigma_t t_to_sigma_func) override { + + std::vector result; + if (n == 0) { + result.push_back(0.0f); + return result; + } + result.reserve(n + 1); + int t_max = TIMESTEPS -1; + float step = static_cast(t_max) / static_cast(n > 1 ? (n -1) : 1) ; + for(uint32_t i=0; i get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) { // These *COULD* be function arguments here, @@ -251,6 +270,36 @@ struct KarrasSchedule : SigmaSchedule { } }; +struct SimpleSchedule : SigmaSchedule { + std::vector get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override { + std::vector result_sigmas; + + if (n == 0) { + return result_sigmas; + } + + result_sigmas.reserve(n + 1); + + int model_sigmas_len = TIMESTEPS; + + float step_factor = static_cast(model_sigmas_len) / static_cast(n); + + for (uint32_t i = 0; i < n; ++i) { + + int offset_from_start_of_py_array = static_cast(static_cast(i) * step_factor); + int timestep_index = model_sigmas_len - 1 - offset_from_start_of_py_array; + + if (timestep_index < 0) { + timestep_index = 0; + } + + result_sigmas.push_back(t_to_sigma(static_cast(timestep_index))); + } + result_sigmas.push_back(0.0f); + return result_sigmas; + } +}; + struct Denoiser { std::shared_ptr schedule = std::make_shared(); virtual float sigma_min() = 0; @@ -262,8 +311,39 @@ struct Denoiser { virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0; virtual std::vector get_sigmas(uint32_t n) { - auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1); - return schedule->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma); + // Check if the current schedule is SGMUniformSchedule + if (std::dynamic_pointer_cast(schedule)) { + std::vector sigs; + sigs.reserve(n + 1); + + if (n == 0) { + sigs.push_back(0.0f); + return sigs; + } + + // Use the Denoiser's own sigma_to_t and t_to_sigma methods + float start_t_val = this->sigma_to_t(this->sigma_max()); + float end_t_val = this->sigma_to_t(this->sigma_min()); + + float dt_per_step; + if (n > 0) { + dt_per_step = (end_t_val - start_t_val) / static_cast(n); + } else { + dt_per_step = 0.0f; + } + + for (uint32_t i = 0; i < n; ++i) { + float current_t = start_t_val + static_cast(i) * dt_per_step; + sigs.push_back(this->t_to_sigma(current_t)); + } + + sigs.push_back(0.0f); + return sigs; + + } else { // For all other schedules, use the existing virtual dispatch + auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1); + return schedule->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma); + } } }; diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 140e3843..d9dec605 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -105,6 +105,7 @@ struct SDParams { float slg_scale = 0.f; float skip_layer_start = 0.01f; float skip_layer_end = 0.2f; + int shifted_timestep = -1; bool chroma_use_dit_mask = true; bool chroma_use_t5_mask = false; @@ -163,6 +164,7 @@ void print_params(SDParams params) { printf(" batch_count: %d\n", params.batch_count); printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false"); printf(" upscale_repeats: %d\n", params.upscale_repeats); + printf(" timestep_shift: %d\n", params.shifted_timestep); printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false"); printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false"); printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad); @@ -223,7 +225,7 @@ void print_usage(int argc, const char* argv[]) { printf(" --rng {std_default, cuda} RNG (default: cuda)\n"); printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n"); printf(" -b, --batch-count COUNT number of images to generate\n"); - printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n"); + printf(" --schedule {discrete, karras, exponential, ays, gits, sgm_uniform, simple} Denoiser sigma schedule (default: discrete)\n"); printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n"); printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n"); printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); @@ -235,6 +237,7 @@ void print_usage(int argc, const char* argv[]) { printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n"); printf(" --canny apply canny preprocessor (edge detection)\n"); printf(" --color colors the logging tags according to level\n"); + printf(" --timestep-shift N shift timestep for NitroFusion models, default: -1 off, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant\n"); printf(" --chroma-disable-dit-mask disable dit mask for chroma\n"); printf(" --chroma-enable-t5-mask enable t5 mask for chroma\n"); printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n"); @@ -487,7 +490,7 @@ void parse_args(int argc, const char** argv, SDParams& params) { const char* arg = argv[index]; params.schedule = str_to_schedule(arg); if (params.schedule == SCHEDULE_COUNT) { - fprintf(stderr, "error: invalid schedule %s\n", + fprintf(stderr, "error: invalid schedule %s, must be one of [discrete, karras, exponential, ays, gits, sgm_uniform, simple]\n", arg); return -1; } @@ -568,7 +571,18 @@ void parse_args(int argc, const char** argv, SDParams& params) { {"-r", "--ref-image", "", on_ref_image_arg}, {"-h", "--help", "", on_help_arg}, }; - + auto on_timestep_shift_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + params.shifted_timestep = std::stoi(argv[index]); + if (params.shifted_timestep != -1 && (params.shifted_timestep < 1 || params.shifted_timestep > 1000)) { + fprintf(stderr, "error: timestep-shift must be between 1 and 1000, or -1 to disable\n"); + return -1; + } + return 1; + }; + options.manual_options.push_back({"", "--timestep-shift", "", on_timestep_shift_arg}); if (!parse_options(argc, argv, options)) { print_usage(argc, argv); exit(1); @@ -979,6 +993,7 @@ int main(int argc, const char* argv[]) { params.style_ratio, params.normalize_input, params.input_id_images_path.c_str(), + params.shifted_timestep, }; results = generate_image(sd_ctx, &img_gen_params); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 2594ba2b..258c419a 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -628,6 +628,16 @@ class StableDiffusionGGML { denoiser->schedule = std::make_shared(); denoiser->schedule->version = version; break; + case SGM_UNIFORM: + LOG_INFO("Running with SGM Uniform schedule"); + denoiser->schedule = std::make_shared(); + denoiser->schedule->version = version; + break; + case SIMPLE: + LOG_INFO("Running with Simple schedule"); + denoiser->schedule = std::make_shared(); + denoiser->schedule->version = version; + break; case DEFAULT: // Don't touch anything. break; @@ -848,7 +858,13 @@ class StableDiffusionGGML { int start_merge_step, SDCondition id_cond, std::vector ref_latents = {}, - ggml_tensor* denoise_mask = nullptr) { + ggml_tensor* denoise_mask = nullptr, + int shifted_timestep = -1) { + + if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) { + LOG_WARN("Timestep shifting is only supported for SDXL models. Ignoring --timestep-shift."); + shifted_timestep = -1; + } std::vector skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count); float cfg_scale = guidance.txt_cfg; @@ -907,181 +923,218 @@ class StableDiffusionGGML { if (has_img_cond) { out_img_cond = ggml_dup_tensor(work_ctx, x); } - struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x); - - auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* { - if (step == 1) { - pretty_progress(0, (int)steps, 0); - } - int64_t t0 = ggml_time_us(); - - std::vector scaling = denoiser->get_scalings(sigma); - GGML_ASSERT(scaling.size() == 3); - float c_skip = scaling[0]; - float c_out = scaling[1]; - float c_in = scaling[2]; - - float t = denoiser->sigma_to_t(sigma); - std::vector timesteps_vec(x->ne[3], t); // [N, ] - auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); - std::vector guidance_vec(x->ne[3], guidance.distilled_guidance); - auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec); + struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x); - copy_ggml_tensor(noised_input, input); - // noised_input = noised_input * c_in - ggml_tensor_scale(noised_input, c_in); + auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* { + if (step == 1) { + pretty_progress(0, (int)steps, 0); + } + int64_t t0 = ggml_time_us(); - std::vector controls; + std::vector scaling = denoiser->get_scalings(sigma); + GGML_ASSERT(scaling.size() == 3); + float c_skip = scaling[0]; + float c_out = scaling[1]; + float c_in = scaling[2]; + float t = denoiser->sigma_to_t(sigma); + std::vector timesteps_vec; + if (shifted_timestep > 0 && sd_version_is_sdxl(version)) { + float shifted_t_float = t * (float(shifted_timestep) / float(TIMESTEPS)); + int64_t shifted_t = static_cast(roundf(shifted_t_float)); + shifted_t = std::max((int64_t)0, std::min((int64_t)(TIMESTEPS - 1), shifted_t)); + LOG_DEBUG("Shifting timestep from %.2f to %" PRId64 " (sigma: %.4f)", t, shifted_t, sigma); + timesteps_vec.assign(x->ne[3], (float)shifted_t); + } else { + timesteps_vec.assign(x->ne[3], t); + } + + auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); + std::vector guidance_vec(x->ne[3], guidance.distilled_guidance); + auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec); + + copy_ggml_tensor(noised_input, input); + // noised_input = noised_input * c_in + ggml_tensor_scale(noised_input, c_in); + + std::vector controls; + + if (control_hint != NULL) { + control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector); + controls = control_net->controls; + // print_ggml_tensor(controls[12]); + // GGML_ASSERT(0); + } + + if (start_merge_step == -1 || step <= start_merge_step) { + // cond + diffusion_model->compute(n_threads, + noised_input, + timesteps, + cond.c_crossattn, + cond.c_concat, + cond.c_vector, + guidance_tensor, + ref_latents, + -1, + controls, + control_strength, + &out_cond); + } else { + diffusion_model->compute(n_threads, + noised_input, + timesteps, + id_cond.c_crossattn, + cond.c_concat, + id_cond.c_vector, + guidance_tensor, + ref_latents, + -1, + controls, + control_strength, + &out_cond); + } + + float* negative_data = NULL; + if (has_unconditioned) { + // uncond if (control_hint != NULL) { - control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector); + control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector); controls = control_net->controls; - // print_ggml_tensor(controls[12]); - // GGML_ASSERT(0); - } - - if (start_merge_step == -1 || step <= start_merge_step) { - // cond - diffusion_model->compute(n_threads, - noised_input, - timesteps, - cond.c_crossattn, - cond.c_concat, - cond.c_vector, - guidance_tensor, - ref_latents, - -1, - controls, - control_strength, - &out_cond); - } else { - diffusion_model->compute(n_threads, - noised_input, - timesteps, - id_cond.c_crossattn, - cond.c_concat, - id_cond.c_vector, - guidance_tensor, - ref_latents, - -1, - controls, - control_strength, - &out_cond); } + diffusion_model->compute(n_threads, + noised_input, + timesteps, + uncond.c_crossattn, + uncond.c_concat, + uncond.c_vector, + guidance_tensor, + ref_latents, + -1, + controls, + control_strength, + &out_uncond); + negative_data = (float*)out_uncond->data; + } + + float* img_cond_data = NULL; + if (has_img_cond) { + diffusion_model->compute(n_threads, + noised_input, + timesteps, + img_cond.c_crossattn, + img_cond.c_concat, + img_cond.c_vector, + guidance_tensor, + ref_latents, + -1, + controls, + control_strength, + &out_img_cond); + img_cond_data = (float*)out_img_cond->data; + } + + int step_count = sigmas.size(); + bool is_skiplayer_step = has_skiplayer && step > (int)(guidance.slg.layer_start * step_count) && step < (int)(guidance.slg.layer_end * step_count); + float* skip_layer_data = NULL; + if (is_skiplayer_step) { + LOG_DEBUG("Skipping layers at step %d\n", step); + // skip layer (same as conditionned) + diffusion_model->compute(n_threads, + noised_input, + timesteps, + cond.c_crossattn, + cond.c_concat, + cond.c_vector, + guidance_tensor, + ref_latents, + -1, + controls, + control_strength, + &out_skip, + NULL, + skip_layers); + skip_layer_data = (float*)out_skip->data; + } + float* vec_denoised = (float*)denoised->data; + float* vec_input = (float*)input->data; + float* positive_data = (float*)out_cond->data; + float* negative_data_ptr = has_unconditioned ? (float*)out_uncond->data : nullptr; + float* skip_layer_data_ptr = is_skiplayer_step ? (float*)out_skip->data : nullptr; + int ne_elements = (int)ggml_nelements(denoised); + + if (shifted_timestep > 0 && sd_version_is_sdxl(version)) { + int64_t shifted_t_idx = static_cast(roundf(timesteps_vec[0])); + + float shifted_sigma = denoiser->t_to_sigma((float)shifted_t_idx); + std::vector shifted_scaling = denoiser->get_scalings(shifted_sigma); + float shifted_c_skip = shifted_scaling[0]; + float shifted_c_out = shifted_scaling[1]; + auto compvis_denoiser_ptr = std::dynamic_pointer_cast(denoiser); + float sigma_data = compvis_denoiser_ptr ? compvis_denoiser_ptr->sigma_data : 1.0f; + + float sigma_sq = sigma * sigma; + float shifted_sigma_sq = shifted_sigma * shifted_sigma; + float sigma_data_sq = sigma_data * sigma_data; + + float input_scale_factor = sqrtf((shifted_sigma_sq + sigma_data_sq) / (sigma_sq + sigma_data_sq)); - float* negative_data = NULL; - if (has_unconditioned) { - // uncond - if (control_hint != NULL) { - control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector); - controls = control_net->controls; + for (int i = 0; i < ne_elements; i++) { + float model_output_result = positive_data[i]; + if (has_unconditioned) { + if (has_img_cond) { + model_output_result = negative_data_ptr[i] + img_cfg_scale * (img_cond_data[i] - negative_data_ptr[i]) + cfg_scale * (positive_data[i] - img_cond_data[i]); + } else { + model_output_result = negative_data_ptr[i] + cfg_scale * (positive_data[i] - negative_data_ptr[i]); + } + } else if (has_img_cond) { + model_output_result = img_cond_data[i] + cfg_scale * (positive_data[i] - img_cond_data[i]); } - diffusion_model->compute(n_threads, - noised_input, - timesteps, - uncond.c_crossattn, - uncond.c_concat, - uncond.c_vector, - guidance_tensor, - ref_latents, - -1, - controls, - control_strength, - &out_uncond); - negative_data = (float*)out_uncond->data; - } - - float* img_cond_data = NULL; - if (has_img_cond) { - diffusion_model->compute(n_threads, - noised_input, - timesteps, - img_cond.c_crossattn, - img_cond.c_concat, - img_cond.c_vector, - guidance_tensor, - ref_latents, - -1, - controls, - control_strength, - &out_img_cond); - img_cond_data = (float*)out_img_cond->data; + if (is_skiplayer_step) { + model_output_result = model_output_result + slg_scale * (positive_data[i] - skip_layer_data_ptr[i]); + } + float adjusted_input = vec_input[i] * input_scale_factor; + vec_denoised[i] = adjusted_input * shifted_c_skip + model_output_result * shifted_c_out; } - int step_count = sigmas.size(); - bool is_skiplayer_step = has_skiplayer && step > (int)(guidance.slg.layer_start * step_count) && step < (int)(guidance.slg.layer_end * step_count); - float* skip_layer_data = NULL; - if (is_skiplayer_step) { - LOG_DEBUG("Skipping layers at step %d\n", step); - // skip layer (same as conditionned) - diffusion_model->compute(n_threads, - noised_input, - timesteps, - cond.c_crossattn, - cond.c_concat, - cond.c_vector, - guidance_tensor, - ref_latents, - -1, - controls, - control_strength, - &out_skip, - NULL, - skip_layers); - skip_layer_data = (float*)out_skip->data; - } - float* vec_denoised = (float*)denoised->data; - float* vec_input = (float*)input->data; - float* positive_data = (float*)out_cond->data; - int ne_elements = (int)ggml_nelements(denoised); + } else { for (int i = 0; i < ne_elements; i++) { - float latent_result = positive_data[i]; + float model_output_result = positive_data[i]; if (has_unconditioned) { - // out_uncond + cfg_scale * (out_cond - out_uncond) - int64_t ne3 = out_cond->ne[3]; - if (min_cfg != cfg_scale && ne3 != 1) { - int64_t i3 = i / out_cond->ne[0] * out_cond->ne[1] * out_cond->ne[2]; - float scale = min_cfg + (cfg_scale - min_cfg) * (i3 * 1.0f / ne3); + if (has_img_cond) { + model_output_result = negative_data_ptr[i] + img_cfg_scale * (img_cond_data[i] - negative_data_ptr[i]) + cfg_scale * (positive_data[i] - img_cond_data[i]); } else { - if (has_img_cond) { - // out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond) - latent_result = negative_data[i] + img_cfg_scale * (img_cond_data[i] - negative_data[i]) + cfg_scale * (positive_data[i] - img_cond_data[i]); - } else { - // img_cfg_scale == cfg_scale - latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]); - } + model_output_result = negative_data_ptr[i] + cfg_scale * (positive_data[i] - negative_data_ptr[i]); } } else if (has_img_cond) { - // img_cfg_scale == 1 - latent_result = img_cond_data[i] + cfg_scale * (positive_data[i] - img_cond_data[i]); + model_output_result = img_cond_data[i] + cfg_scale * (positive_data[i] - img_cond_data[i]); } if (is_skiplayer_step) { - latent_result = latent_result + (positive_data[i] - skip_layer_data[i]) * slg_scale; + model_output_result = model_output_result + slg_scale * (positive_data[i] - skip_layer_data_ptr[i]); } - // v = latent_result, eps = latent_result - // denoised = (v * c_out + input * c_skip) or (input + eps * c_out) - vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip; + vec_denoised[i] = vec_input[i] * c_skip + model_output_result * c_out; } - int64_t t1 = ggml_time_us(); - if (step > 0) { - pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f); - // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000); - } - if (denoise_mask != nullptr) { - for (int64_t x = 0; x < denoised->ne[0]; x++) { - for (int64_t y = 0; y < denoised->ne[1]; y++) { - float mask = ggml_tensor_get_f32(denoise_mask, x, y); - for (int64_t k = 0; k < denoised->ne[2]; k++) { - float init = ggml_tensor_get_f32(init_latent, x, y, k); - float den = ggml_tensor_get_f32(denoised, x, y, k); - ggml_tensor_set_f32(denoised, init + mask * (den - init), x, y, k); - } + } + int64_t t1 = ggml_time_us(); + if (step > 0) { + pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f); + // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000); + } + if (denoise_mask != nullptr) { + for (int64_t x = 0; x < denoised->ne[0]; x++) { + for (int64_t y = 0; y < denoised->ne[1]; y++) { + float mask = ggml_tensor_get_f32(denoise_mask, x, y); + for (int64_t k = 0; k < denoised->ne[2]; k++) { + float init = ggml_tensor_get_f32(init_latent, x, y, k); + float den = ggml_tensor_get_f32(denoised, x, y, k); + ggml_tensor_set_f32(denoised, init + mask * (den - init), x, y, k); } } } + } - return denoised; - }; + return denoised; + }; sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta); @@ -1272,6 +1325,8 @@ const char* schedule_to_str[] = { "exponential", "ays", "gits", + "sgm_uniform", + "simple", }; const char* sd_schedule_name(enum schedule_t schedule) { @@ -1392,6 +1447,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) { sd_img_gen_params->control_strength = 0.9f; sd_img_gen_params->style_strength = 20.f; sd_img_gen_params->normalize_input = false; + sd_img_gen_params->shifted_timestep = -1; } char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { @@ -1425,7 +1481,8 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { "control_strength: %.2f\n" "style_strength: %.2f\n" "normalize_input: %s\n" - "input_id_images_path: %s\n", + "input_id_images_path: %s\n" + "shifted_timestep: %d\n", SAFE_STR(sd_img_gen_params->prompt), SAFE_STR(sd_img_gen_params->negative_prompt), sd_img_gen_params->clip_skip, @@ -1449,7 +1506,8 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { sd_img_gen_params->control_strength, sd_img_gen_params->style_strength, BOOL_STR(sd_img_gen_params->normalize_input), - SAFE_STR(sd_img_gen_params->input_id_images_path)); + SAFE_STR(sd_img_gen_params->input_id_images_path), + sd_img_gen_params->shifted_timestep); return buf; } @@ -1529,7 +1587,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, std::string input_id_images_path, std::vector ref_latents, ggml_tensor* concat_latent = NULL, - ggml_tensor* denoise_mask = NULL) { + ggml_tensor* denoise_mask = NULL, + int shifted_timestep = -1) { if (seed < 0) { // Generally, when using the provided command line, the seed is always >0. // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library @@ -1798,7 +1857,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, start_merge_step, id_cond, ref_latents, - denoise_mask); + denoise_mask, + shifted_timestep); // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); // print_ggml_tensor(x_0); @@ -2081,7 +2141,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sd_img_gen_params->input_id_images_path, ref_latents, concat_latent, - denoise_mask); + denoise_mask, + sd_img_gen_params->shifted_timestep); size_t t2 = ggml_time_ms(); diff --git a/stable-diffusion.h b/stable-diffusion.h index a6032592..2b4ee845 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -57,6 +57,8 @@ enum schedule_t { EXPONENTIAL, AYS, GITS, + SGM_UNIFORM, + SIMPLE, SCHEDULE_COUNT }; @@ -184,6 +186,7 @@ typedef struct { float style_strength; bool normalize_input; const char* input_id_images_path; + int shifted_timestep; } sd_img_gen_params_t; typedef struct { From 9f44a8a2f78330c444c202235d6492f75072c6ad Mon Sep 17 00:00:00 2001 From: rmatif Date: Sat, 2 Aug 2025 01:07:04 +0000 Subject: [PATCH 2/3] update readme --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 89eb095e..1011866e 100644 --- a/README.md +++ b/README.md @@ -332,7 +332,7 @@ arguments: --rng {std_default, cuda} RNG (default: cuda) -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0) -b, --batch-count COUNT number of images to generate - --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete) + --schedule {discrete, karras, exponential, ays, gits, simple, sgm_uniform} Denoiser sigma schedule (default: discrete) --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x --vae-tiling process vae in tiles to reduce memory usage @@ -347,6 +347,7 @@ arguments: --chroma-disable-dit-mask disable dit mask for chroma --chroma-enable-t5-mask enable t5 mask for chroma --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma + --timestep-shift N shift timestep, default: -1 off, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant -v, --verbose print extra info ``` From fb4fed4f8e5671280ebccbbf747cf95c3d711431 Mon Sep 17 00:00:00 2001 From: rmatif Date: Mon, 4 Aug 2025 13:17:55 +0000 Subject: [PATCH 3/3] fix spaces --- stable-diffusion.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 258c419a..3a3597bd 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -859,8 +859,8 @@ class StableDiffusionGGML { SDCondition id_cond, std::vector ref_latents = {}, ggml_tensor* denoise_mask = nullptr, - int shifted_timestep = -1) { - + int shifted_timestep = -1) { + if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) { LOG_WARN("Timestep shifting is only supported for SDXL models. Ignoring --timestep-shift."); shifted_timestep = -1;