diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index c5448f92..b27a2a99 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -386,6 +386,10 @@ class StableDiffusionGGML {
             diffusion_model->alloc_params_buffer();
             diffusion_model->get_param_tensors(tensors);
 
+            if (sd_version_is_unet_edit(version)) {  // edit models encode reference images, so the VAE cannot be decode-only
+                vae_decode_only = false;
+            }
+
             if (!use_tiny_autoencoder) {
                 if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) {
                     LOG_INFO("VAE Autoencoder: Using CPU backend");
@@ -2037,19 +2041,36 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
         init_latent = generate_init_latent(sd_ctx, work_ctx, width, height);
     }
 
+    sd_guidance_params_t guidance = sd_img_gen_params->guidance;  // local copy: img_cfg may be overridden below
+    std::vector<sd_image_t*> ref_images;  // non-owning views of the caller's reference images
+    for (int i = 0; i < sd_img_gen_params->ref_images_count; i++) {
+        ref_images.push_back(&sd_img_gen_params->ref_images[i]);
+    }
+
+    std::vector<uint8_t> empty_image_data;  // backing store for the placeholder; must outlive every use of empty_image
+    sd_image_t empty_image = {(uint32_t)width, (uint32_t)height, 3, nullptr};
+    if (ref_images.empty() && sd_version_is_unet_edit(sd_ctx->sd->version))
+    {
+        LOG_WARN("This model needs at least one reference image; using empty reference");
+        empty_image_data.resize((size_t)width * height * 3);  // resize, not reserve: data() must point at real, zero-filled pixels
+        ref_images.push_back(&empty_image);
+        empty_image.data = empty_image_data.data();
+        guidance.img_cfg = 0.f;  // the placeholder carries no information, so disable image guidance
+    }
+
     if (sd_img_gen_params->ref_images_count > 0) {
         LOG_INFO("EDIT mode");
     }
 
     std::vector<ggml_tensor*> ref_latents;
-    for (int i = 0; i < sd_img_gen_params->ref_images_count; i++) {
+    for (int i = 0; i < (int)ref_images.size(); i++) {
         ggml_tensor* img = ggml_new_tensor_4d(work_ctx,
                                               GGML_TYPE_F32,
-                                              sd_img_gen_params->ref_images[i].width,
-                                              sd_img_gen_params->ref_images[i].height,
+                                              ref_images[i]->width,
+                                              ref_images[i]->height,
                                               3,
                                               1);
-        sd_image_to_tensor(sd_img_gen_params->ref_images[i].data, img);
+        sd_image_to_tensor(ref_images[i]->data, img);
 
         ggml_tensor* latent = NULL;
         if (sd_ctx->sd->use_tiny_autoencoder) {
@@ -2082,7 +2103,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
                                                     SAFE_STR(sd_img_gen_params->prompt),
                                                     SAFE_STR(sd_img_gen_params->negative_prompt),
                                                     sd_img_gen_params->clip_skip,
-                                                    sd_img_gen_params->guidance,
+                                                    guidance,
                                                     sd_img_gen_params->eta,
                                                     width,
                                                     height,