src/transformers/models/pixtral/image_processing_pixtral.py (127 changes: 91 additions & 36 deletions)
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Image processor class for Pixtral."""

-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np

@@ -31,6 +31,7 @@
     get_image_size,
     infer_channel_dimension_format,
     is_scaled_image,
+    is_valid_image,
     make_list_of_images,
     to_numpy_array,
     valid_images,
@@ -48,7 +49,40 @@
 import PIL


-# Adapted from function in image_transforms.py t oensure any transparent pixels are converted to white.
+# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
+def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
+    """
+    Convert a single image or a list of images to a list of numpy arrays.
+
+    Args:
+        images (`ImageInput`):
+            A single image or a list of images.
+
+    Returns:
+        A list of numpy arrays.
+    """
+    # If it's a single image, convert it to a list of lists
+    if is_valid_image(images):
+        images = [[images]]
+    # If it's a list of images, it's a single batch, so convert it to a list of lists
+    elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]):
+        images = [images]
+    # If it's a list of batches, it's already in the right format
+    elif (
+        isinstance(images, (list, tuple))
+        and len(images) > 0
+        and isinstance(images[0], (list, tuple))
+        and is_valid_image(images[0][0])
+    ):
+        pass
+    else:
+        raise ValueError(
+            "Invalid input type. Must be a single image, a list of images, or a list of batches of images."
+        )
+    return images
+
+
+# Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white.
 def convert_to_rgb(image: ImageInput) -> ImageInput:
     """
     Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
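For context, a quick sketch of how the three input shapes accepted by the new `make_list_of_images` normalize, assuming plain numpy arrays stand in for images:

```python
import numpy as np

img = np.zeros((16, 16, 3), dtype=np.uint8)  # stand-in image

print(len(make_list_of_images(img)))             # 1 -> wrapped as [[img]]
print(len(make_list_of_images([img, img])[0]))   # 2 -> one batch of two images, [[img, img]]
print(len(make_list_of_images([[img], [img]])))  # 2 -> already a list of batches, returned unchanged
```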
@@ -134,6 +168,18 @@ def get_resize_output_image_size(
     return num_height_tokens * patch_height, num_width_tokens * patch_width


+# Hack to get tensor conversion used in BatchFeature without batching the images
+def _get_is_as_tensor_fns(tensor_type: Union[str, TensorType]) -> Tuple[Callable, Callable]:
+    return BatchFeature()._get_is_as_tensor_fns(tensor_type)
+
+
+def convert_to_tensor(array, tensor_type: Union[str, TensorType]) -> Any:
+    is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type)
+    if is_tensor(array):
+        return array
+    return as_tensor(array)
+
+
 class PixtralImageProcessor(BaseImageProcessor):
     r"""
     Constructs a Pixtral image processor.
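The two helpers above reuse `BatchFeature`'s own tensor-conversion machinery one array at a time, so images of different sizes are never stacked into a single batch tensor. A minimal sketch of the intended behavior, assuming PyTorch is installed:

```python
import numpy as np

arr = np.ones((3, 32, 32), dtype=np.float32)
t = convert_to_tensor(arr, "pt")   # numpy array -> torch.Tensor of shape (3, 32, 32)
t2 = convert_to_tensor(t, "pt")    # already a tensor, returned unchanged
```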
@@ -333,11 +379,11 @@ def preprocess(
                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
         """
+        patch_size = patch_size if patch_size is not None else self.patch_size
+        patch_size = get_size_dict(patch_size, default_to_square=True)
+
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
-        patch_size = patch_size if patch_size is not None else self.patch_size
         resample = resample if resample is not None else self.resample
         do_rescale = do_rescale if do_rescale is not None else self.do_rescale
         rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
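`patch_size` is now normalized with `get_size_dict` up front, so the rest of `preprocess` can rely on the dict form. Roughly, as I understand transformers' image-processing utils (worth verifying against your installed version):

```python
from transformers.image_processing_utils import get_size_dict

print(get_size_dict(16, default_to_square=True))
# {'height': 16, 'width': 16}
```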
@@ -348,13 +394,14 @@ def preprocess(

         validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

-        images = make_list_of_images(images)
+        images_list = make_list_of_images(images)

-        if not valid_images(images):
+        if not valid_images(images_list[0][0]):
             raise ValueError(
                 "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                 "torch.Tensor, tf.Tensor or jax.ndarray."
             )
+
         validate_preprocess_arguments(
             do_rescale=do_rescale,
             rescale_factor=rescale_factor,
@@ -367,46 +414,54 @@ def preprocess(
         )

         if do_convert_rgb:
-            images = [convert_to_rgb(image) for image in images]
+            images_list = [[convert_to_rgb(image) for image in images] for images in images_list]

         # All transformations expect numpy arrays.
-        images = [to_numpy_array(image) for image in images]
+        images_list = [[to_numpy_array(image) for image in images] for images in images_list]

-        if is_scaled_image(images[0]) and do_rescale:
+        if is_scaled_image(images_list[0][0]) and do_rescale:
             logger.warning_once(
                 "It looks like you are trying to rescale already rescaled images. If the input"
                 " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
             )

         if input_data_format is None:
             # We assume that all images have the same channel dimension format.
-            input_data_format = infer_channel_dimension_format(images[0])
-
-        all_images = []
-        for image in images:
-            if do_resize:
-                image = self.resize(
-                    image=image,
-                    size=size,
-                    patch_size=patch_size,
-                    resample=resample,
-                    input_data_format=input_data_format,
-                )
-
-            if do_rescale:
-                image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
-
-            if do_normalize:
-                image = self.normalize(
-                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
-                )
-
-            all_images.append(image)
-
-        images = [
-            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
-            for image in all_images
+            input_data_format = infer_channel_dimension_format(images_list[0][0])
+
+        batch_images = []
+        batch_image_sizes = []
+        for sample_images in images_list:
+            images = []
+            image_sizes = []
+            for image in sample_images:
+                if do_resize:
+                    image = self.resize(
+                        image=image,
+                        size=size,
+                        patch_size=patch_size,
+                        resample=resample,
+                        input_data_format=input_data_format,
+                    )
+
+                if do_rescale:
+                    image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+                if do_normalize:
+                    image = self.normalize(
+                        image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+                    )
+
+                images.append(image)
+                image_sizes.append(get_image_size(image, input_data_format))
+            batch_images.append(images)
+            batch_image_sizes.append(image_sizes)
+
+        images_list = [
+            [to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images]
+            for images in batch_images
         ]

-        data = {"pixel_values": images}
-        return BatchFeature(data=data, tensor_type=return_tensors)
+        # Convert to tensor type outside of BatchFeature to avoid batching the images of different sizes
+        images_list = [[convert_to_tensor(image, return_tensors) for image in images] for images in images_list]
+        return BatchFeature(data={"images": images_list, "image_sizes": batch_image_sizes}, tensor_type=None)
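Net effect of the changes to this file: `preprocess` now returns per-sample nested lists under `images` and `image_sizes` instead of a single stacked `pixel_values` batch. A hedged sketch of what a caller might see, assuming the class's default constructor arguments suffice:

```python
from PIL import Image

processor = PixtralImageProcessor()
img_a = Image.new("RGB", (64, 48))   # hypothetical inputs of different sizes
img_b = Image.new("RGB", (128, 96))

out = processor.preprocess([[img_a, img_b]], return_tensors="pt")
print(len(out["images"][0]))   # 2 tensors for the one sample, possibly with different H x W
print(out["image_sizes"][0])   # matching [(height, width), (height, width)]
```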
src/transformers/models/pixtral/processing_pixtral.py (40 changes: 26 additions & 14 deletions)
@@ -19,7 +19,7 @@
 from typing import List, Optional, Union

 from ...feature_extraction_utils import BatchFeature
-from ...image_utils import ImageInput, get_image_size, to_numpy_array
+from ...image_utils import ImageInput
 from ...processing_utils import ProcessorMixin
 from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
 from ...utils import TensorType, logging
@@ -146,21 +146,33 @@ def __call__(

         # try to expand inputs in processing if we have the necessary parts
         prompt_strings = text
-        if image_inputs.get("pixel_values") is not None:
+        if image_inputs.get("images") is not None:
             # Replace the image token with the expanded image token sequence
-            pixel_values = image_inputs["pixel_values"]
-            height, width = get_image_size(to_numpy_array(pixel_values[0]))
-            num_height_tokens = height // self.patch_size
-            num_width_tokens = width // self.patch_size
-
+            images = image_inputs["images"]
+            image_sizes = image_inputs.pop("image_sizes")
             prompt_strings = []
-            replace_tokens = [[self.image_token] * num_width_tokens + [self.image_break_token]] * num_height_tokens
-            # Flatten list
-            replace_tokens = [item for sublist in replace_tokens for item in sublist]
-            replace_tokens[-1] = self.image_end_token
-            replace_str = "".join(replace_tokens)
-            for sample in text:
-                sample = sample.replace(self.image_token, replace_str)
+
+            for sample_images, sample_image_sizes, sample in zip(images, image_sizes, text):
+                replace_strings = []
+                # First calculate the number of tokens needed for each image and put in a placeholder
+                for image, image_size in zip(sample_images, sample_image_sizes):
+                    height, width = image_size
+                    num_height_tokens = height // self.patch_size
+                    num_width_tokens = width // self.patch_size
+                    replace_tokens = [
+                        [self.image_token] * num_width_tokens + [self.image_break_token]
+                    ] * num_height_tokens
+                    # Flatten list
+                    replace_tokens = [item for sublist in replace_tokens for item in sublist]
+                    replace_tokens[-1] = self.image_end_token
+                    replace_str = "".join(replace_tokens)
+                    replace_strings.append(replace_str)
+                    sample = sample.replace(self.image_token, "<placeholder>", 1)
+
+                while "<placeholder>" in sample:
+                    replace_str = replace_strings.pop(0)
+                    sample = sample.replace("<placeholder>", replace_str, 1)
+
                 prompt_strings.append(sample)

         text_inputs = self.tokenizer(
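As a worked example of the per-image expansion above: with `patch_size=16`, a 48x64 (height x width) image gives `num_height_tokens = 3` and `num_width_tokens = 4`, i.e. three rows of four image tokens, each row closed by a break token and the very last token swapped for the end token. A standalone sketch, where the token strings and patch size are assumptions rather than values read from the processor config:

```python
def expand_image_tokens(height: int, width: int, patch_size: int = 16) -> str:
    # Mirrors the loop body: one row of [IMG] tokens per patch row, each row
    # terminated by [IMG_BREAK], with the final token replaced by [IMG_END].
    rows, cols = height // patch_size, width // patch_size
    tokens = (["[IMG]"] * cols + ["[IMG_BREAK]"]) * rows
    tokens[-1] = "[IMG_END]"
    return "".join(tokens)

print(expand_image_tokens(48, 64))
# three rows: [IMG]x4 + [IMG_BREAK], [IMG]x4 + [IMG_BREAK], [IMG]x4 + [IMG_END]
```

The two-pass `<placeholder>` swap keeps replacement order stable when a prompt contains several image tokens whose expansions differ in length, since each placeholder is consumed left to right with its own precomputed string.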