diff --git a/model.cpp b/model.cpp index df1c8637..172b21c5 100644 --- a/model.cpp +++ b/model.cpp @@ -2156,7 +2156,7 @@ std::vector> parse_tensor_type_rules(const std if (type_name == "f32") { tensor_type = GGML_TYPE_F32; } else { - for (size_t i = 0; i < SD_TYPE_COUNT; i++) { + for (size_t i = 0; i < GGML_TYPE_COUNT; i++) { auto trait = ggml_get_type_traits((ggml_type)i); if (trait->to_float && trait->type_size && type_name == trait->type_name) { tensor_type = (ggml_type)i; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index c5448f92..87e628fe 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -241,7 +241,9 @@ class StableDiffusionGGML { } LOG_INFO("Version: %s ", model_version_to_str[version]); - ggml_type wtype = (ggml_type)sd_ctx_params->wtype; + ggml_type wtype = (int)sd_ctx_params->wtype < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT) + ? (ggml_type)sd_ctx_params->wtype + : GGML_TYPE_COUNT; if (wtype == GGML_TYPE_COUNT) { model_wtype = model_loader.get_sd_wtype(); if (model_wtype == GGML_TYPE_COUNT) { @@ -269,11 +271,6 @@ class StableDiffusionGGML { model_loader.set_wtype_override(wtype); } - if (sd_version_is_sdxl(version)) { - vae_wtype = GGML_TYPE_F32; - model_loader.set_wtype_override(GGML_TYPE_F32, "vae."); - } - LOG_INFO("Weight type: %s", model_wtype != GGML_TYPE_COUNT ? ggml_type_name(model_wtype) : "??"); LOG_INFO("Conditioner weight type: %s", conditioner_wtype != GGML_TYPE_COUNT ? ggml_type_name(conditioner_wtype) : "??"); LOG_INFO("Diffusion model weight type: %s", diffusion_model_wtype != GGML_TYPE_COUNT ? ggml_type_name(diffusion_model_wtype) : "??"); @@ -1216,11 +1213,14 @@ class StableDiffusionGGML { #define NONE_STR "NONE" const char* sd_type_name(enum sd_type_t type) { - return ggml_type_name((ggml_type)type); + if ((int) type < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT)) { + return ggml_type_name((ggml_type)type); + } + return NONE_STR; } enum sd_type_t str_to_sd_type(const char* str) { - for (int i = 0; i < SD_TYPE_COUNT; i++) { + for (int i = 0; i < std::min(SD_TYPE_COUNT, GGML_TYPE_COUNT); i++) { auto trait = ggml_get_type_traits((ggml_type)i); if (!strcmp(str, trait->type_name)) { return (enum sd_type_t)i;