From 941428e12a0dc9eb23843694fda797fd842d32a4 Mon Sep 17 00:00:00 2001 From: Mert Erbak <71733533+merterbak@users.noreply.github.com> Date: Fri, 25 Apr 2025 23:43:35 +0000 Subject: [PATCH 1/2] Set LANCZOS as default interpolation mode for resizing --- examples/dreambooth/train_dreambooth_lora.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 891ac2e2027b..02d3473fbbce 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -524,6 +524,15 @@ def parse_args(input_args=None): default=4, help=("The dimension of the LoRA update matrices."), ) + parser.add_argument( + "--image_interpolation_mode", + type=str, + default="lanczos", + choices=[ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ], + help="The image interpolation method to use for resizing images.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -601,9 +610,13 @@ def __init__( else: self.class_data_root = None + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) + if interpolation is None: + raise ValueError(f"Unsupported interpolation mode {interpolation=}.") + self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(size, interpolation=interpolation), transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), From 2a636e0c3d59e6e806eb8917c1eba07618c76390 Mon Sep 17 00:00:00 2001 From: Mert Erbak <71733533+merterbak@users.noreply.github.com> Date: Sat, 26 Apr 2025 04:36:30 +0000 Subject: [PATCH 2/2] [train_dreambooth_lora.py] Set LANCZOS as default interpolation mode for resizing --- examples/dreambooth/train_dreambooth_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 02d3473fbbce..a9552c14cad1 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -529,7 +529,7 @@ def parse_args(input_args=None): type=str, default="lanczos", choices=[ - f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") ], help="The image interpolation method to use for resizing images.", )