Skip to content

use Euler sampling by default for SD3 and Flux #753

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ struct SDParams {
int fps = 6;
float augmentation_level = 0.f;

sample_method_t sample_method = EULER_A;
sample_method_t sample_method = SAMPLE_METHOD_DEFAULT;
schedule_t schedule = DEFAULT;
int sample_steps = 20;
float strength = 0.75f;
Expand Down Expand Up @@ -222,7 +222,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" -H, --height H image height, in pixel space (default: 512)\n");
printf(" -W, --width W image width, in pixel space (default: 512)\n");
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n");
printf(" sampling method (default: \"euler_a\")\n");
printf(" sampling method (default: \"euler\" for Flux/SD3, \"euler_a\" otherwise)\n");
printf(" --steps STEPS number of sample steps (default: 20)\n");
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
Expand All @@ -236,9 +236,9 @@ void print_usage(int argc, const char* argv[]) {
printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n");
printf(" Might lower quality, since it implies converting k and v to f16.\n");
printf(" This might crash if it is not supported by the backend.\n");
printf(" --diffusion-conv-direct use Conv2d direct in the diffusion model");
printf(" --diffusion-conv-direct use Conv2d direct in the diffusion model\n");
printf(" This might crash if it is not supported by the backend.\n");
printf(" --vae-conv-direct use Conv2d direct in the vae model (should improve the performance)");
printf(" --vae-conv-direct use Conv2d direct in the vae model (should improve the performance)\n");
printf(" This might crash if it is not supported by the backend.\n");
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
printf(" --canny apply canny preprocessor (edge detection)\n");
Expand Down Expand Up @@ -925,6 +925,10 @@ int main(int argc, const char* argv[]) {
return 1;
}

if (params.sample_method == SAMPLE_METHOD_DEFAULT) {
params.sample_method = sd_get_default_sample_method (sd_ctx);
}

sd_image_t input_image = {(uint32_t)params.width,
(uint32_t)params.height,
3,
Expand Down
29 changes: 24 additions & 5 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ const char* model_version_to_str[] = {
"Flux Fill"};

const char* sampling_methods_str[] = {
"Euler A",
"default",
"Euler",
"Heun",
"DPM2",
Expand All @@ -50,7 +50,8 @@ const char* sampling_methods_str[] = {
"iPNDM_v",
"LCM",
"DDIM \"trailing\"",
"TCD"};
"TCD",
"Euler A"};

/*================================================== Helper Functions ================================================*/

Expand Down Expand Up @@ -1251,7 +1252,7 @@ enum rng_type_t str_to_rng_type(const char* str) {
}

const char* sample_method_to_str[] = {
"euler_a",
"default",
"euler",
"heun",
"dpm2",
Expand All @@ -1263,6 +1264,7 @@ const char* sample_method_to_str[] = {
"lcm",
"ddim_trailing",
"tcd",
"euler_a",
};

const char* sd_sample_method_name(enum sample_method_t sample_method) {
Expand Down Expand Up @@ -1399,7 +1401,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->ref_images_count = 0;
sd_img_gen_params->width = 512;
sd_img_gen_params->height = 512;
sd_img_gen_params->sample_method = EULER_A;
sd_img_gen_params->sample_method = SAMPLE_METHOD_DEFAULT;
sd_img_gen_params->sample_steps = 20;
sd_img_gen_params->eta = 0.f;
sd_img_gen_params->strength = 0.75f;
Expand Down Expand Up @@ -1524,6 +1526,18 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
free(sd_ctx);
}

SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx)
{
if (sd_ctx != NULL && sd_ctx->sd != NULL) {
SDVersion version = sd_ctx->sd->version;
if (sd_version_is_dit(version))
return EULER;
else
return EULER_A;
}
return SAMPLE_METHOD_COUNT;
}

sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
struct ggml_context* work_ctx,
ggml_tensor* init_latent,
Expand Down Expand Up @@ -2076,6 +2090,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
}

enum sample_method_t sample_method = sd_img_gen_params->sample_method;
if (sample_method == SAMPLE_METHOD_DEFAULT) {
sample_method = sd_get_default_sample_method (sd_ctx);
}

sd_image_t* result_images = generate_image_internal(sd_ctx,
work_ctx,
init_latent,
Expand All @@ -2086,7 +2105,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_img_gen_params->eta,
width,
height,
sd_img_gen_params->sample_method,
sample_method,
sigmas,
seed,
sd_img_gen_params->batch_count,
Expand Down
4 changes: 3 additions & 1 deletion stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ enum rng_type_t {
};

enum sample_method_t {
EULER_A,
SAMPLE_METHOD_DEFAULT,
EULER,
HEUN,
DPM2,
Expand All @@ -47,6 +47,7 @@ enum sample_method_t {
LCM,
DDIM_TRAILING,
TCD,
EULER_A,
SAMPLE_METHOD_COUNT
};

Expand Down Expand Up @@ -227,6 +228,7 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);

SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);

SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
Expand Down
Loading