diff --git a/.gitignore b/.gitignore index 77b0c35..9fb91fa 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ redis_data/appendonlydir/ redis_data/dump.rdb -logs/ \ No newline at end of file +logs/ +.cache \ No newline at end of file diff --git a/directai_fastapi/Dockerfile b/directai_fastapi/Dockerfile index 56221a3..7be5414 100644 --- a/directai_fastapi/Dockerfile +++ b/directai_fastapi/Dockerfile @@ -1,9 +1,11 @@ -FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime +FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel WORKDIR /directai_fastapi RUN apt-get update RUN apt-get install libgl1 libglib2.0-0 libsm6 libxrender1 libxext6 -y +RUN apt-get install git -y + RUN apt-get install cmake build-essential -y COPY requirements.txt . RUN pip install -r requirements.txt diff --git a/directai_fastapi/modeling/distributed_backend.py b/directai_fastapi/modeling/distributed_backend.py index b83a0bc..fc3d7ac 100644 --- a/directai_fastapi/modeling/distributed_backend.py +++ b/directai_fastapi/modeling/distributed_backend.py @@ -9,6 +9,7 @@ from typing import List from pydantic_models import ClassifierResponse, SingleDetectionResponse from modeling.image_classifier import ZeroShotImageClassifierWithFeedback +from modeling.object_detector import ZeroShotObjectDetectorWithFeedback serve.start(http_options={"port": 8100}) @@ -16,15 +17,51 @@ @serve.deployment class ObjectDetector: - async def __call__(self, image: Image.Image) -> List[List[SingleDetectionResponse]]: - # Placeholder implementation - single_detection = { - "tlbr": [0.0, 0.0, 1.0, 1.0], - "score": random.random(), - "class": "dog", - } - sdr = SingleDetectionResponse.parse_obj(single_detection) - return [[sdr]] + def __init__(self) -> None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = ZeroShotObjectDetectorWithFeedback(device=device) + + async def __call__( + self, + image: bytes, + labels: list[str], + inc_sub_labels_dict: dict[str, list[str]], + exc_sub_labels_dict: dict[str, list[str]] | None = None, + label_conf_thres: dict[str, float] | None = None, + augment_examples: bool = True, + nms_thre: float = 0.4, + run_class_agnostic_nms: bool = False, + ) -> list[SingleDetectionResponse]: + with torch.inference_mode(), torch.autocast(str(self.model.device)): + batched_predicted_boxes = self.model( + image, + labels=labels, + inc_sub_labels_dict=inc_sub_labels_dict, + exc_sub_labels_dict=exc_sub_labels_dict, + label_conf_thres=label_conf_thres, + augment_examples=augment_examples, + nms_thre=nms_thre, + run_class_agnostic_nms=run_class_agnostic_nms, + ) + + # since we are processing a single image, the output has batch size 1, so we can safely index into it + per_label_boxes = batched_predicted_boxes[0] + + # predicted_boxes is a list in order of labels, with each box of the form [x1, y1, x2, y2, confidence] + detection_responses = [] + for label, boxes in zip(labels, per_label_boxes): + for detection in boxes: + det_dict = { + "tlbr": detection[:4].tolist(), + "score": detection[4].item(), + "class_": label, + } + single_detection_response = SingleDetectionResponse.parse_obj( + det_dict + ) + detection_responses.append(single_detection_response) + + return detection_responses @serve.deployment @@ -41,8 +78,7 @@ async def __call__( exc_sub_labels_dict: dict[str, list[str]] | None = None, augment_examples: bool = True, ) -> ClassifierResponse: - - with torch.no_grad(), torch.autocast(str(self.model.device)): + with torch.inference_mode(), torch.autocast(str(self.model.device)): raw_scores = 
self.model( image, labels=labels, diff --git a/directai_fastapi/modeling/image_classifier.py b/directai_fastapi/modeling/image_classifier.py index f0467e5..074a876 100644 --- a/directai_fastapi/modeling/image_classifier.py +++ b/directai_fastapi/modeling/image_classifier.py @@ -8,6 +8,7 @@ from modeling.tensor_utils import ( batch_encode_cache_missed_list_elements, image_bytes_to_tensor, + squish_labels, ) from modeling.prompt_templates import noop_hypothesis_formats, many_hypothesis_formats from lru import LRU @@ -134,35 +135,6 @@ def encode_text(self, text: list[str], augment: bool = True) -> torch.Tensor: self.not_augmented_label_encoding_cache, ) - def squish_labels( - self, - labels: list[str], - inc_sub_labels_dict: dict[str, list[str]], - exc_sub_labels_dict: dict[str, list[str]], - ) -> tuple[list[str], dict[str, int]]: - # build one list of labels to encode, without duplicates - # and lists / dicts containing the indices of each label - # and the indices of each label's sub-labels - all_labels_to_inds: dict[str, int] = {} - all_labels = [] - - for label in labels: - inc_subs = inc_sub_labels_dict.get(label) - if inc_subs is not None: - for inc_sub in inc_subs: - if inc_sub not in all_labels_to_inds: - all_labels_to_inds[inc_sub] = len(all_labels_to_inds) - all_labels.append(inc_sub) - - exc_subs = exc_sub_labels_dict.get(label) - if exc_subs is not None: - for exc_sub in exc_subs: - if exc_sub not in all_labels_to_inds: - all_labels_to_inds[exc_sub] = len(all_labels_to_inds) - all_labels.append(exc_sub) - - return all_labels, all_labels_to_inds - def forward( self, image: torch.Tensor | bytes, @@ -189,7 +161,7 @@ def forward( label: excs for label, excs in exc_sub_labels_dict.items() if len(excs) > 0 } - all_labels, all_labels_to_inds = self.squish_labels( + all_labels, all_labels_to_inds = squish_labels( labels, inc_sub_labels_dict, exc_sub_labels_dict ) text_features = self.encode_text(all_labels, augment=augment_examples) diff --git a/directai_fastapi/modeling/object_detector.py b/directai_fastapi/modeling/object_detector.py new file mode 100644 index 0000000..b29274a --- /dev/null +++ b/directai_fastapi/modeling/object_detector.py @@ -0,0 +1,810 @@ +from typing import List, Optional, Tuple +import torch +from PIL import Image +import torch +from torch import nn +import torchvision # type: ignore[import-untyped] +import numpy as np +from transformers import Owlv2Processor, Owlv2ForObjectDetection, Owlv2VisionModel # type: ignore[import-untyped] +from transformers.models.owlv2.modeling_owlv2 import Owlv2Attention # type: ignore[import-untyped] +import time +from typing import Union +from torch_scatter import scatter_max # type: ignore[import-untyped] +from flash_attn import flash_attn_func # type: ignore[import-untyped] +import io +from lru import LRU +from functools import partial + +from modeling.prompt_templates import medium_hypothesis_formats, noop_hypothesis_formats +from modeling.tensor_utils import ( + batch_encode_cache_missed_list_elements, + resize_pil_image, + squish_labels, +) + + +def created_padded_tensor_from_bytes( + image_bytes: bytes, image_size: tuple[int, int] +) -> tuple[torch.Tensor, torch.Tensor]: + padded_image_tensor = torch.ones((1, 3, *image_size)) * 114.0 + + # TODO: add nonblocking streaming to GPU + + image_buffer = io.BytesIO(image_bytes) + pil_image = Image.open(image_buffer) + current_size = pil_image.size + + r = min(image_size[0] / current_size[0], image_size[1] / current_size[1]) + target_size = (int(r * current_size[0]), int(r * 
current_size[1])) + + pil_image = resize_pil_image(pil_image, target_size) + + np_image = np.asarray(pil_image) + torch_image = torch.tensor(np_image).permute(2, 0, 1).unsqueeze(0) + + padded_image_tensor[:, :, : torch_image.shape[2], : torch_image.shape[3]] = ( + torch_image + ) + + image_scale_ratios = torch.tensor( + [ + r, + ] + ) + + return padded_image_tensor, image_scale_ratios + + +def flash_attn_owl_vit_encoder_forward( + self: Owlv2Attention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, +) -> tuple[torch.Tensor, None]: + assert ( + not output_attentions + ), "output_attentions not supported for flash attention implementation" + assert ( + attention_mask is None + ), "attention_mask not supported for flash attention implementation" + # technically flash_attn DOES support causal attention + # but the OWL usage of causal attention mask does not limit it to true causal attention + # we don't support generalized attention, so we're just going to assert causal attention mask is ALSO None + assert ( + causal_attention_mask is None + ), "causal_attention_mask not supported for flash attention implementation" + + bsz, tgt_len, embed_dim = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.contiguous().view( + bsz, tgt_len, self.num_heads, self.head_dim + ) + key_states = key_states.contiguous().view( + bsz, tgt_len, self.num_heads, self.head_dim + ) + value_states = value_states.contiguous().view( + bsz, tgt_len, self.num_heads, self.head_dim + ) + + # convert to appropriate dtype + # NOTE: bf16 may be more appropriate than fp16 + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_p=0, + softmax_scale=self.scale, + ) + + attn_output = attn_output.view(bsz, tgt_len, embed_dim) + + # convert back to appropriate dtype + attn_output = attn_output.to(hidden_states.dtype) + + attn_output = self.out_proj(attn_output) + + return attn_output, None + + +class VisionModelWrapper(nn.Module): + def __init__(self, vision_model: Owlv2VisionModel) -> None: + super().__init__() + + self.vision_model = vision_model + + # we're going to monkey patch the forward method of the attention layers + # to replace it with a faster one based on flash_attn + # the alternative is to subclass the entire model, but that's a lot of work + # so we're just going to define a replacement with the same function signature + # and assert that the input is as supported by flash_attn + for owlv2_vision_model_encoder_layer in self.vision_model.encoder.layers: + owlv2_vision_model_encoder_layer.self_attn.forward = partial( + flash_attn_owl_vit_encoder_forward, + owlv2_vision_model_encoder_layer.self_attn, + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + vision_outputs = self.vision_model(pixel_values=image, return_dict=True) + + # Get image embedding + last_hidden_states = vision_outputs[0] + image_embeds = self.vision_model.post_layernorm(last_hidden_states) + + return image_embeds + + +class WrappedImageEmbedder(nn.Module): + def __init__(self, model: Owlv2ForObjectDetection) -> None: + super().__init__() + + self.model = model + self.wrapped_vision_model = 
VisionModelWrapper(self.model.owlv2.vision_model) + + def forward( + self, image: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + image_embeds = self.wrapped_vision_model(image) + + # Resize class token + class_token_out = torch.broadcast_to( + image_embeds[:, :1, :], image_embeds[:, 1:, :].shape + ) + + # Merge image embedding with class tokens + image_embeds = image_embeds[:, 1:, :] * class_token_out + image_embeds = self.model.layer_norm(image_embeds) + + # Resize to [batch_size, num_patches, num_patches, hidden_size] + new_size = ( + image_embeds.shape[0], + int(np.sqrt(image_embeds.shape[1])), + int(np.sqrt(image_embeds.shape[1])), + image_embeds.shape[-1], + ) + feature_map = image_embeds.reshape(new_size) + + # Get class head features + # we do dot prod between image_class_embeds and query embeddings, then do (pred + shift) * scale + image_class_embeds = self.model.class_head.dense0(image_embeds) + image_class_embeds = image_class_embeds / image_class_embeds.norm( + dim=-1, keepdim=True, p=2 + ) + logit_shift = self.model.class_head.logit_shift(image_embeds) + logit_scale = self.model.class_head.logit_scale(image_embeds) + logit_scale = self.model.class_head.elu(logit_scale) + 1 + + # Get box head features + # NOTE: this is in a specific format, handle later + pred_boxes = self.model.box_predictor(image_embeds, feature_map) + + # filter out patches that are unlikely to map to objects + # the paper takes the top 10% during training, but we'll take the top 300 to be more in line with DETR + objectness_scores = self.model.objectness_predictor(image_embeds) + # compute the top 300 objectness indices + indices = torch.topk(objectness_scores, 300, dim=1).indices + # filter all the other stuff + image_class_embeds = image_class_embeds.gather( + 1, indices.unsqueeze(-1).expand(-1, -1, image_class_embeds.shape[-1]) + ) + logit_shift = logit_shift.gather( + 1, indices.unsqueeze(-1).expand(-1, -1, logit_shift.shape[-1]) + ) + logit_scale = logit_scale.gather( + 1, indices.unsqueeze(-1).expand(-1, -1, logit_scale.shape[-1]) + ) + pred_boxes = pred_boxes.gather( + 1, indices.unsqueeze(-1).expand(-1, -1, pred_boxes.shape[-1]) + ) + + assert image_class_embeds.shape[1] == 300 + + return image_class_embeds, logit_shift, logit_scale, pred_boxes + + +class ZeroShotObjectDetectorWithFeedback(nn.Module): + def __init__( + self, + model_name: str = "google/owlv2-large-patch14-ensemble", + image_size: tuple[int, int] = (1008, 1008), + max_text_batch_size: int = 32, + max_image_batch_size: int = 32, + device: torch.device | str = "cuda", + lru_cache_size: int = 4096, + jit: bool = True, + ): + super().__init__() + + self.device = device + self.model = ( + Owlv2ForObjectDetection.from_pretrained(model_name).to(device).eval() + ) + self.processor = Owlv2Processor.from_pretrained(model_name) + + if jit: + self.wrapped_image_embedder = torch.jit.trace_module( + WrappedImageEmbedder(self.model), + {"forward": (torch.randn(1, 3, *image_size, device=device),)}, + ) + else: + self.wrapped_image_embedder = WrappedImageEmbedder(self.model) + + self.max_text_batch_size = max_text_batch_size + self.max_image_batch_size = max_image_batch_size + + # we cache the text embeddings to avoid recomputing them + # we use an LRU cache to avoid running out of memory + # especially because likely the tensors will be large and stored in GPU memory + self.augmented_label_encoding_cache: LRU | None = ( + LRU(lru_cache_size) if lru_cache_size > 0 else None + ) + 
self.not_augmented_label_encoding_cache: LRU | None = ( + LRU(lru_cache_size) if lru_cache_size > 0 else None + ) + + self.image_size = image_size + self.rgb_means = torch.tensor([0.485, 0.456, 0.406], device=device).view( + 1, 3, 1, 1 + ) + self.rgb_stds = torch.tensor([0.229, 0.224, 0.225], device=device).view( + 1, 3, 1, 1 + ) + + def _encode_text(self, text: list[str], augment: bool = True) -> torch.Tensor: + # NOTE: object detector liturature tends to use fewer templates than image classifiers + templates = medium_hypothesis_formats if augment else noop_hypothesis_formats + augmented_text = [template.format(t) for t in text for template in templates] + + embeddings_list = [] + for i in range(0, len(augmented_text), self.max_text_batch_size): + text_subset = augmented_text[i : i + self.max_text_batch_size] + + processor_output = self.processor( + text=text_subset, return_tensors="pt", padding=True, truncation=True + ) + input_ids = processor_output.input_ids.to(self.device) + attn_mask = processor_output.attention_mask.to(self.device) + + # TODO: add appropriate batching to avoid OOM + text_output = self.model.owlv2.text_model( + input_ids=input_ids, attention_mask=attn_mask, return_dict=True + ) + + embeddings = text_output[1] + embeddings = self.model.owlv2.text_projection(embeddings) + embeddings = embeddings / embeddings.norm(dim=1, keepdim=True, p=2) + + embeddings_list.append(embeddings) + + embeddings = torch.cat(embeddings_list, dim=0) + embeddings = embeddings.reshape(len(text), len(templates), embeddings.shape[1]) + embeddings = embeddings.mean(dim=1) + embeddings = embeddings / embeddings.norm(dim=1, keepdim=True, p=2) + + return embeddings + + def encode_text(self, text: list[str], augment: bool = True) -> torch.Tensor: + if augment: + return batch_encode_cache_missed_list_elements( + partial(self._encode_text, augment=True), + text, + self.augmented_label_encoding_cache, + ) + else: + return batch_encode_cache_missed_list_elements( + partial(self._encode_text, augment=False), + text, + self.not_augmented_label_encoding_cache, + ) + + def get_image_data(self, image: torch.Tensor) -> dict[str, torch.Tensor]: + # we do the normalization here to make sure we have access to the right parameters + image = image / 255.0 + image = (image - self.rgb_means) / self.rgb_stds + + image_class_embeds_list = [] + logit_shift_list = [] + logit_scale_list = [] + pred_boxes_list = [] + for i in range(0, image.size(0), self.max_image_batch_size): + image_subset = image[i : i + self.max_image_batch_size] + image_class_embeds, logit_shift, logit_scale, pred_boxes = ( + self.wrapped_image_embedder(image_subset) + ) + + image_class_embeds_list.append(image_class_embeds) + logit_shift_list.append(logit_shift) + logit_scale_list.append(logit_scale) + pred_boxes_list.append(pred_boxes) + + image_class_embeds = torch.cat(image_class_embeds_list) + logit_shift = torch.cat(logit_shift_list) + logit_scale = torch.cat(logit_scale_list) + pred_boxes = torch.cat(pred_boxes_list) + + return { + "image_class_embeds": image_class_embeds, + "logit_shift": logit_shift, + "logit_scale": logit_scale, + "pred_boxes": pred_boxes, + } + + def forward( + self, + image: torch.Tensor | bytes, + labels: list[str], + inc_sub_labels_dict: dict[str, list[str]], + exc_sub_labels_dict: dict[str, list[str]] | None = None, + label_conf_thres: dict[str, float] | None = None, + augment_examples: bool = True, + nms_thre: float = 0.4, + run_class_agnostic_nms: bool = True, + image_scale_ratios: torch.Tensor | None = None, + ) -> 
list[list[torch.Tensor]]: + if isinstance(image, bytes): + assert ( + image_scale_ratios is None + ), "image_scale_ratios must be None if image is bytes as we define the scale internally" + image_tensor, image_scale_ratios = created_padded_tensor_from_bytes( + image, self.image_size + ) + else: + assert ( + image_scale_ratios is not None + ), "image_scale_ratios must be provided if image is a tensor as we cannot derive the scale internally" + image_tensor = image + + if label_conf_thres is None: + label_conf_thres = {} + + if len(labels) == 0: + raise ValueError("At least one label must be provided") + + if any([len(sub_labels) == 0 for sub_labels in inc_sub_labels_dict.values()]): + raise ValueError("Each label must include at least one sub-label") + + image_tensor = image_tensor.to(self.device) + + image_data = self.get_image_data(image_tensor) + + exc_sub_labels_dict = {} if exc_sub_labels_dict is None else exc_sub_labels_dict + # filter out empty excs lists + exc_sub_labels_dict = { + label: excs for label, excs in exc_sub_labels_dict.items() if len(excs) > 0 + } + + all_labels, all_labels_to_inds = squish_labels( + labels, inc_sub_labels_dict, exc_sub_labels_dict + ) + text_features = self.encode_text(all_labels, augment=augment_examples) + + scores_by_image_and_box = compute_query_fit( + text_features, + image_data["image_class_embeds"], + image_data["logit_shift"], + image_data["logit_scale"], + ) + # NOTE that scores_by_image_and_box is of shape [num_images, num_boxes, len(all_labels)] + # for the extracting of the per-box pro and con scores, we don't care about differentiating the first two dimensions + # so we flatten them to make the scatter_max operation easier + # and then we reshape them back to the original shape + scores = scores_by_image_and_box.view(-1, len(all_labels)) + # now we can proceed in the same way as the image classifier + + label_to_ind = {label: i for i, label in enumerate(labels)} + + pos_labels_to_master_inds, pos_labels_list = zip( + *[ + v + for label, incs in inc_sub_labels_dict.items() + for v in zip([label_to_ind[label]] * len(incs), incs) + ] + ) + pos_labels_inds = [all_labels_to_inds[label] for label in pos_labels_list] + + pos_scores = scores[:, pos_labels_inds] + + # pos_labels_to_master_inds indicates which indices we should be taking the max over for each label + # since our scatter_max will be batched, we need to offset this for each box + num_labels = len(labels) + num_boxes = scores.shape[0] + num_incs = len(pos_labels_to_master_inds) + offsets = ( + torch.arange(num_boxes).unsqueeze(1).expand(-1, num_incs).flatten() + * num_labels + ) + offsets = offsets.to(self.device) + indices_for_max = ( + torch.tensor(pos_labels_to_master_inds).to(self.device).repeat(num_boxes) + + offsets + ) + + max_pos_scores_flat, _ = scatter_max( + pos_scores.view(-1), indices_for_max, dim_size=num_boxes * num_labels + ) + max_pos_scores = max_pos_scores_flat.view(num_boxes, num_labels) + + # compute the same for the negative labels, if any + if len(exc_sub_labels_dict) > 0: + neg_labels_to_master_inds, neg_labels_list = zip( + *[ + v + for label, excs in exc_sub_labels_dict.items() + for v in zip([label_to_ind[label]] * len(excs), excs) + ] + ) + neg_labels_inds = [all_labels_to_inds[label] for label in neg_labels_list] + + neg_scores = scores[:, neg_labels_inds] + + num_excs = len(neg_labels_to_master_inds) + offsets = ( + torch.arange(num_boxes).unsqueeze(1).expand(-1, num_excs).flatten() + * num_labels + ) + offsets = offsets.to(self.device) + indices_for_max = 
( + torch.tensor(neg_labels_to_master_inds) + .to(self.device) + .repeat(num_boxes) + + offsets + ) + + max_neg_scores_flat, _ = scatter_max( + neg_scores.view(-1), indices_for_max, dim_size=num_boxes * num_labels + ) + max_neg_scores = max_neg_scores_flat.view(num_boxes, num_labels) + else: + # if we have no negative labels, we just set the max neg scores to zero + # NOTE: possible to speed things up by skipping the ops conditional on having negative labels + max_neg_scores = torch.zeros_like(max_pos_scores) + + # now reshape the scores to [num_images, num_boxes, num_labels] + max_pos_scores = max_pos_scores.view( + image_data["pred_boxes"].shape[0], + image_data["pred_boxes"].shape[1], + num_labels, + ) + max_neg_scores = max_neg_scores.view( + image_data["pred_boxes"].shape[0], + image_data["pred_boxes"].shape[1], + num_labels, + ) + + # unlike the image classifier, we have to suppress boxes based on the scores of their neighbors + # we do this via a modified NMS algorithm + # because it operates over a variable-sized graph of boxes, it's hard to vectorize + # so we dump it into a script function that does fork-based async processing + # the output is a per-image list of per-object boxes in tlbr-score format + batched_predicted_boxes = batched_run_nms_based_box_suppression_for_all_objects( + max_pos_scores, + max_neg_scores, + image_data["pred_boxes"], + image_tensor.shape[2] / image_scale_ratios, + torch.tensor( + [label_conf_thres.get(label, 0.0) for label in labels], + device=self.device, + ), + nms_thre, + run_class_agnostic_nms, + ) + + return batched_predicted_boxes + + +@torch.jit.script +def compute_query_fit( + query_embeds: torch.Tensor, + image_class_embeds: torch.Tensor, + logit_shift: torch.Tensor, + logit_scale: torch.Tensor, +) -> torch.Tensor: + # Compute query fit + pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds) + pred_logits = (pred_logits + logit_shift) * logit_scale + + return torch.sigmoid(pred_logits) + + +@torch.jit.script +def compute_iou_adjacency_list( + boxes: torch.Tensor, nms_thre: float +) -> list[torch.Tensor]: + boxes = boxes.clone() + # boxes are in cxcywh format, we need to convert them to tlbr format + boxes[:, 0] -= boxes[:, 2] / 2 + boxes[:, 1] -= boxes[:, 3] / 2 + boxes[:, 2] = boxes[:, 0] + boxes[:, 2] + boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + ious = torchvision.ops.box_iou(boxes, boxes) + ious = ious >= nms_thre + # Set diagonal elements to zero .. no self loops! 
+ ious.fill_diagonal_(0) + + edges = torch.nonzero(ious).unbind(-1) + + return edges + + +@torch.jit.script +def find_in_sorted_tensor( + sorted_tensor: torch.Tensor, query: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + indices = torch.searchsorted(sorted_tensor, query) + indices_clamped = torch.clamp(indices, max=sorted_tensor.size(0) - 1) + present = sorted_tensor[indices_clamped] == query + in_bounds = indices < sorted_tensor.size(0) + found_mask = present & in_bounds + return found_mask, indices + + +@torch.jit.script +def run_nms_via_adjacency_list( + valid_box_indices_by_descending_score: torch.Tensor, + adjacency_list: list[torch.Tensor], +) -> torch.Tensor: + # we compute the indices of the start and end of each box's adjacent boxes + # since our graph representation is just a list of edges, we would like to know which edges correspond to which nodes + # as the first node is sorted, we can just take the difference between adjacent nodes to get the start and end of each node's edges + # NOTE: we've already computed the graph so we don't need to supply an NMS threshold + first_node = adjacency_list[0] + zero_tensor = torch.tensor([0], device=first_node.device) + change_inds = (first_node[1:] != first_node[:-1]).nonzero()[:, 0] + len_tensor = torch.tensor([first_node.shape[0]], device=first_node.device) + inds_of_adj_boxes = torch.cat([zero_tensor, change_inds + 1, len_tensor]) + + # we then run a NMS over the boxes + # need to keep track of which boxes have survived, which means an index per possible box + # note that some boxes may have already been filtered out + # so valid_box_indices_by_descending_score may not contain all indices + # TODO: switch this to a sparse tensor for efficiency + highest_possible_ind = max( + valid_box_indices_by_descending_score.max().item(), + adjacency_list[0].max().item(), + adjacency_list[1].max().item(), + ) + assert isinstance(highest_possible_ind, int) # this is just to make mypy happy + has_survived = torch.zeros( + highest_possible_ind + 1, + device=valid_box_indices_by_descending_score.device, + dtype=torch.bool, + ) + has_survived[valid_box_indices_by_descending_score] = 1 + + # check which boxes have graph connections + unique_nodes = first_node[inds_of_adj_boxes[:-1]] + has_connection, graph_node_indices = find_in_sorted_tensor( + unique_nodes, valid_box_indices_by_descending_score + ) + connected_sorted_pro_valid_inds = valid_box_indices_by_descending_score[ + has_connection + ] + graph_indices = graph_node_indices[has_connection] + + # supress the boxes for which their (unsupressed) neighbors have higher scores + for i, j in zip(connected_sorted_pro_valid_inds, graph_indices): + if has_survived[i] == 0: + continue + + remapped_start_ind = inds_of_adj_boxes[j] + remapped_end_ind = inds_of_adj_boxes[j + 1] + adj_boxes = adjacency_list[1][remapped_start_ind:remapped_end_ind] + has_survived[adj_boxes] = 0 + + survive_inds = valid_box_indices_by_descending_score[ + has_survived[valid_box_indices_by_descending_score] + ] + + return survive_inds + + +@torch.jit.script +def compute_candidate_nms_via_adjacency_list( + pro_max: torch.Tensor, + con_max: torch.Tensor, + adjacency_list: list[torch.Tensor], + conf_thre: float, + run_nms: bool, +) -> torch.Tensor: + # we use a scatter_max to efficiently compute, for each bounding box, the max con score of its adjacent boxes + expanded_con_max = con_max[adjacency_list[0]] + adjacent_con_max = scatter_max( + expanded_con_max, adjacency_list[1], dim_size=con_max.shape[0] + )[0] + # we then filter 
down to boxes that both exceed the confidence threshold and are not suppressed by negative examples + # we do this by filtering on three expressions: + # 1. pro_max >= conf_thre: the box has a high enough confidence + # 2. pro_max >= adjacent_con_max: the box has a higher confidence than any adjacent boxes have negative confidence + # 3. pro_max >= con_max: the box has a higher confidence than its own negative confidence + # NOTE: could make this more efficient perhaps by filtering out the easy ones prior to the scatter_max + pro_valid = ( + (pro_max >= conf_thre) * (pro_max >= adjacent_con_max) * (pro_max >= con_max) + ) + pro_valid_inds = pro_valid.nonzero().squeeze(1) + + if pro_valid_inds.numel() == 0 or adjacency_list[0].numel() == 0: + # no boxes are valid or no boxes have any overlap with any other boxes + # either way, we can skip the NMS step + return pro_valid_inds + + # remove reported overlaps with boxes that are not valid + # this shrinks the graph that we need to do NMS over + first_node_valid, _ = find_in_sorted_tensor(pro_valid_inds, adjacency_list[0]) + second_node_valid, _ = find_in_sorted_tensor(pro_valid_inds, adjacency_list[1]) + + nms_inds = torch.nonzero(first_node_valid * second_node_valid).squeeze(1) + modified_adjacency_list = [adjacency_list[0][nms_inds], adjacency_list[1][nms_inds]] + + if nms_inds.numel() == 0: + # none of the remaining boxes have any overlap with any other remaining boxes + # so we can skip the NMS step + survive_inds = pro_valid.nonzero().squeeze(1) + return survive_inds + + survive_inds = pro_valid_inds[pro_max[pro_valid].argsort(descending=True)] + + if run_nms: + # we then run a NMS over the (remaining) boxes + survive_inds = run_nms_via_adjacency_list(survive_inds, modified_adjacency_list) + + return survive_inds + + +@torch.jit.script +def run_nms_based_box_suppression_for_one_object( + pro_max: torch.Tensor, + con_max: torch.Tensor, + pred_boxes: torch.Tensor, + adjacency_list: list[torch.Tensor], + image_scale: float, + conf_thre: float = 0.001, + run_nms: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + survive_inds = compute_candidate_nms_via_adjacency_list( + pro_max, con_max, adjacency_list, conf_thre, run_nms + ) + + boxes = pred_boxes[survive_inds] + logits = pro_max[survive_inds] + + logits = logits.unsqueeze(-1) + boxes = boxes * image_scale + + # Convert boxes from center_x, center_y, width, height (cx_cy_w_h) to top_left_x, top_left_y, bottom_right_x, bottom_right_y (tlbr) + cx, cy, w, h = boxes.unbind(-1) + tl_x = cx - 0.5 * w + tl_y = cy - 0.5 * h + br_x = cx + 0.5 * w + br_y = cy + 0.5 * h + boxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=-1) + + boxes_with_scores = torch.cat([boxes, logits], dim=-1) + ordered_by_logit = boxes_with_scores[:, 4].argsort(descending=True) + boxes_with_scores = boxes_with_scores[ordered_by_logit] + survive_inds = survive_inds[ordered_by_logit] + + return boxes_with_scores, survive_inds + + +@torch.jit.script +def run_nms_based_box_suppression_for_all_objects( + pro_max: torch.Tensor, + con_max: torch.Tensor, + pred_boxes: torch.Tensor, + image_scale: float, + conf_thres: torch.Tensor, + nms_thre: float = 0.4, + run_class_agnostic_nms: bool = True, +) -> list[torch.Tensor]: + # pred_boxes is assumed to be [num_boxes, 4] + # pro_max and con_max are assumed to be [num_boxes, num_objects] + # conf_thres is assumed to be [num_objects] + adjacency_list = compute_iou_adjacency_list(pred_boxes, nms_thre) + + futures = [ + torch.jit.fork( + run_nms_based_box_suppression_for_one_object, + 
pro_max[:, i], + con_max[:, i], + pred_boxes, + adjacency_list, + image_scale, + conf_thres[i], + run_nms=not run_class_agnostic_nms, + ) + for i in range(pro_max.shape[1]) + ] + + # there appears to be a bug in JIT related to star expansion that stops us from just using the following list comprehension: + # object_boxes, box_indices = zip(*[torch.jit.wait(fut) for fut in futures]) + object_boxes: list[torch.Tensor] = [] + box_indices: list[torch.Tensor] = [] + for fut in futures: + boxes_with_scores, survive_inds = torch.jit.wait(fut) + object_boxes.append(boxes_with_scores) + box_indices.append(survive_inds) + + if run_class_agnostic_nms: + # first take the top class prediction for each box + # we will of course use a scatter_max for this + all_object_boxes = torch.cat(object_boxes, dim=0) + all_box_indices = torch.cat(box_indices, dim=0) + all_box_class_assignments = torch.cat( + [ + torch.full((box.shape[0],), i, dtype=torch.long, device=box.device) + for i, box in enumerate(object_boxes) + ], + dim=0, + ) + all_object_confidences = all_object_boxes[:, 4] + + # we use an out argument here so we can control the default value + box_max_confidences = torch.zeros(pred_boxes.shape[0], device=pred_boxes.device) + _, max_confidence_indices = scatter_max( + all_object_confidences, all_box_indices, out=box_max_confidences + ) + + survived_inds = box_max_confidences.nonzero().squeeze(1) + + # now filter the survived inds via NMS + sorted_survived_inds = survived_inds[ + box_max_confidences[survived_inds].argsort(descending=True) + ] + post_nms_survived_inds = run_nms_via_adjacency_list( + sorted_survived_inds, adjacency_list + ) + + # and accumulate the boxes + object_survived_inds = max_confidence_indices[post_nms_survived_inds] + survived_class_assignments = all_box_class_assignments[object_survived_inds] + survived_object_boxes = all_object_boxes[object_survived_inds] + + # loop through based on class_assignments to create list of per-class boxes + # TODO: this could be done more efficiently + object_boxes = [ + survived_object_boxes[survived_class_assignments == i] + for i in range(pro_max.shape[1]) + ] + + return object_boxes + + +@torch.jit.script +def batched_run_nms_based_box_suppression_for_all_objects( + pro_max: torch.Tensor, + con_max: torch.Tensor, + pred_boxes: torch.Tensor, + image_scales: torch.Tensor, + conf_thres: torch.Tensor, + nms_thre: float = 0.4, + run_class_agnostic_nms: bool = False, +) -> list[list[torch.Tensor]]: + # pred_boxes is assumed to be [num_images, num_boxes, 4] + # pro_max and con_max are assumed to be [num_images, num_boxes, num_objects] + # conf_thres is assumed to be [num_objects] + # image_scales is assumed to be [num_images] + futures = [ + torch.jit.fork( + run_nms_based_box_suppression_for_all_objects, + pro_max[i], + con_max[i], + pred_boxes[i], + image_scales[i].item(), + conf_thres, + nms_thre, + run_class_agnostic_nms, + ) + for i in range(pro_max.shape[0]) + ] + + batched_predicted_boxes = [torch.jit.wait(fut) for fut in futures] + + return batched_predicted_boxes diff --git a/directai_fastapi/modeling/tensor_utils.py b/directai_fastapi/modeling/tensor_utils.py index ace7ed3..34fec22 100644 --- a/directai_fastapi/modeling/tensor_utils.py +++ b/directai_fastapi/modeling/tensor_utils.py @@ -72,9 +72,9 @@ def batch_encode_cache_missed_list_elements( return output_tensor -def image_bytes_to_tensor(image: bytes, image_size: tuple[int, int]) -> torch.Tensor: - image_buffer = io.BytesIO(image) - pil_image = Image.open(image_buffer) +def 
resize_pil_image( + pil_image: Image.Image, image_size: tuple[int, int] +) -> Image.Image: if pil_image.format == "JPEG": # try requesting a format-specific conversion # this significantly speeds up the subsequent resize operation @@ -83,6 +83,42 @@ def image_bytes_to_tensor(image: bytes, image_size: tuple[int, int]) -> torch.Te pil_image.draft("RGB", image_size) pil_image = pil_image.convert("RGB") pil_image = pil_image.resize(image_size, Image.BICUBIC) + return pil_image + + +def image_bytes_to_tensor(image: bytes, image_size: tuple[int, int]) -> torch.Tensor: + image_buffer = io.BytesIO(image) + pil_image = Image.open(image_buffer) + pil_image = resize_pil_image(pil_image, image_size) np_image = np.asarray(pil_image) tensor = torch.tensor(np_image).permute(2, 0, 1).unsqueeze(0) return tensor + + +def squish_labels( + labels: list[str], + inc_sub_labels_dict: dict[str, list[str]], + exc_sub_labels_dict: dict[str, list[str]], +) -> tuple[list[str], dict[str, int]]: + # build one list of labels to encode, without duplicates + # and lists / dicts containing the indices of each label + # and the indices of each label's sub-labels + all_labels_to_inds: dict[str, int] = {} + all_labels = [] + + for label in labels: + inc_subs = inc_sub_labels_dict.get(label) + if inc_subs is not None: + for inc_sub in inc_subs: + if inc_sub not in all_labels_to_inds: + all_labels_to_inds[inc_sub] = len(all_labels_to_inds) + all_labels.append(inc_sub) + + exc_subs = exc_sub_labels_dict.get(label) + if exc_subs is not None: + for exc_sub in exc_subs: + if exc_sub not in all_labels_to_inds: + all_labels_to_inds[exc_sub] = len(all_labels_to_inds) + all_labels.append(exc_sub) + + return all_labels, all_labels_to_inds diff --git a/directai_fastapi/pydantic_models.py b/directai_fastapi/pydantic_models.py index d9d6f3b..51772cf 100644 --- a/directai_fastapi/pydantic_models.py +++ b/directai_fastapi/pydantic_models.py @@ -103,28 +103,33 @@ class Config: orm_mode = True async def save_configuration(self, config_cache: redis.Redis) -> dict: + logger.info(f"Detector Configs: {self.detector_configs}") for detector_config in self.detector_configs: + logger.info(detector_config.examples_to_include) if len(detector_config.examples_to_include) == 0: raise HTTPException( status_code=422, detail=f"Model lacks example_to_include for {detector_config.name} class.", ) - # Translating into Backend - config_dict = self.dict() - for i, single_config in enumerate(config_dict["detector_configs"]): - single_config["incs"] = single_config["examples_to_include"] - single_config["excs"] = single_config["examples_to_exclude"] - single_config["img_incs"] = [] - single_config["img_excs"] = [] - single_config["thresh"] = single_config["detection_threshold"] - del single_config["examples_to_include"] - del single_config["examples_to_exclude"] - del single_config["detection_threshold"] - config_dict["detector_configs"][i] = single_config - config_dict["nms_thresh"] = config_dict["nms_threshold"] - del config_dict["nms_threshold"] - config_dict["augment_examples"] = config_dict.get("augment_examples", True) - config_dict["class_agnostic_nms"] = config_dict.get("class_agnostic_nms", True) + labels = [c.name for c in self.detector_configs] + inc_sub_labels_dict: dict[str, List[str]] = { + c.name: c.examples_to_include for c in self.detector_configs + } + exc_sub_labels_dict: dict[str, List[str]] = { + c.name: c.examples_to_exclude for c in self.detector_configs + } + label_conf_thres: dict[str, float] = { + c.name: c.detection_threshold for c in 
self.detector_configs + } + config_dict = { + "labels": labels, + "inc_sub_labels_dict": inc_sub_labels_dict, + "exc_sub_labels_dict": exc_sub_labels_dict, + "augment_examples": self.augment_examples, + "nms_threshold": self.nms_threshold, + "class_agnostic_nms": self.class_agnostic_nms, + "label_conf_thres": label_conf_thres, + } if self.deployed_id is not None: key_exists = await config_cache.exists(self.deployed_id) @@ -139,6 +144,7 @@ async def save_configuration(self, config_cache: redis.Redis) -> dict: self.deployed_id = str(uuid.uuid4()) else: message = "Model updated." + assert ( self.deployed_id is not None ), "deployed_id should not be None at this point" @@ -154,12 +160,3 @@ class SingleDetectionResponse(BaseModel): class Config: allow_population_by_field_name = True - - -class VerboseDetectorConfig(BaseModel): - name: str - - incs: List[str] - excs: List[str] = [] - - thresh: Optional[float] = None diff --git a/directai_fastapi/requirements.txt b/directai_fastapi/requirements.txt index 611e520..7bd0b8d 100644 --- a/directai_fastapi/requirements.txt +++ b/directai_fastapi/requirements.txt @@ -10,4 +10,6 @@ ray[serve]==2.34.0 mypy==1.11.1 open_clip_torch==2.24.0 https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_scatter-2.1.2%2Bpt22cu121-cp310-cp310-linux_x86_64.whl -lru-dict==1.3.0 \ No newline at end of file +lru-dict==1.3.0 +transformers==4.35 +flash-attn==2.6.3 \ No newline at end of file diff --git a/directai_fastapi/server.py b/directai_fastapi/server.py index 8085000..172f5a6 100644 --- a/directai_fastapi/server.py +++ b/directai_fastapi/server.py @@ -15,7 +15,6 @@ ClassifierResponse, DetectorDeploy, SingleDetectionResponse, - VerboseDetectorConfig, ) from utils import raise_if_cannot_open from modeling.distributed_backend import deploy_backend_models @@ -59,7 +58,7 @@ async def startup_event() -> None: app.state.config_cache = await redis.from_url( f"{grab_redis_endpoint()}?decode_responses=True" ) - print(f"Ping successful: {await app.state.config_cache.ping()}") + logger.info(f"Ping successful: {await app.state.config_cache.ping()}") @app.on_event("shutdown") @@ -72,7 +71,7 @@ async def validation_exception_handler( request: Request, exc: RequestValidationError ) -> JSONResponse: exc_str = f"{exc}".replace("\n", " ").replace(" ", " ") - print(f"{request}: {exc_str}") + logger.info(f"{request}: {exc_str}") return JSONResponse( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content={ @@ -86,7 +85,7 @@ async def validation_exception_handler( @app.exception_handler(HTTPException) async def exception_handler(request: Request, exc: HTTPException) -> JSONResponse: exc_str = f"{exc.detail}".replace("\n", " ").replace(" ", " ") - print(f"{request}: {exc_str}") + logger.info(f"{request}: {exc_str}") return JSONResponse( status_code=exc.status_code, content={"status_code": exc.status_code, "message": exc_str, "data": None}, @@ -111,7 +110,7 @@ async def deploy_classifier(request: Request, config: ClassifierDeploy) -> dict: deploy_response = await config.save_configuration( config_cache=app.state.config_cache ) - print(f"Deployed classifier w/ ID: {deploy_response['deployed_id']}") + logger.info(f"Deployed classifier w/ ID: {deploy_response['deployed_id']}") return deploy_response @@ -168,7 +167,7 @@ async def deploy_detector(request: Request, config: DetectorDeploy) -> dict: deploy_response = await config.save_configuration( config_cache=app.state.config_cache ) - print(f"Deployed detector w/ ID: {deploy_response['deployed_id']}") + logger.info(f"Deployed detector w/ ID: 
{deploy_response['deployed_id']}") return deploy_response @@ -188,16 +187,30 @@ async def run_detector( data: UploadFile = File(), ) -> List[List[SingleDetectionResponse]]: """Get detections from deployed model""" - print(f"Got request for {deployed_id}, which is a detector model") image = data.file.read() raise_if_cannot_open(image) + logger.info(f"Got request for {deployed_id}, which is a detector model") detector_configs = await grab_config(deployed_id) - ## NOTE: This might break if we have embedded BaseModel-inheriting objects inside the json object - verbose_detector_configs = [ - VerboseDetectorConfig(**json.loads(d) if isinstance(d, str) else d) - for d in detector_configs["detector_configs"] - ] - print(f"augment_examples: {detector_configs.get('augment_examples', None)}") + labels = detector_configs["labels"] + assert isinstance(labels, list), "Labels should be a list of strings" + inc_sub_labels_dict = detector_configs.get("inc_sub_labels_dict", None) + exc_sub_labels_dict = detector_configs.get("exc_sub_labels_dict", None) + label_conf_thres = detector_configs.get("label_conf_thres", None) + augment_examples = detector_configs.get("augment_examples", True) + nms_threshold = detector_configs.get("nms_threshold", 0.4) + class_agnostic_nms = detector_configs.get("class_agnostic_nms", True) + + bboxes = await app.state.detector_handle.remote( + image, + labels=labels, + inc_sub_labels_dict=inc_sub_labels_dict, + exc_sub_labels_dict=exc_sub_labels_dict, + label_conf_thres=label_conf_thres, + augment_examples=augment_examples, + nms_thre=nms_threshold, + run_class_agnostic_nms=class_agnostic_nms, + ) - bboxes = await app.state.detector_handle.remote(None) - return bboxes + return [ + bboxes, + ] diff --git a/directai_fastapi/unit_tests/test.py b/directai_fastapi/unit_tests/test.py index 3deeaee..eaea859 100644 --- a/directai_fastapi/unit_tests/test.py +++ b/directai_fastapi/unit_tests/test.py @@ -5,6 +5,7 @@ from unit_tests.test_modules.test_utils import * from unit_tests.test_modules.test_tensor_utils import * from unit_tests.test_modules.test_classifier import * +from unit_tests.test_modules.test_detector import * if __name__ == "__main__": unittest.main() diff --git a/directai_fastapi/unit_tests/test_modules/test_detector.py b/directai_fastapi/unit_tests/test_modules/test_detector.py new file mode 100644 index 0000000..f26c327 --- /dev/null +++ b/directai_fastapi/unit_tests/test_modules/test_detector.py @@ -0,0 +1,318 @@ +import unittest +import torch +import torchvision # type: ignore +from typing_extensions import ClassVar + +from modeling.object_detector import ( + ZeroShotObjectDetectorWithFeedback, + created_padded_tensor_from_bytes, + compute_iou_adjacency_list, + run_nms_via_adjacency_list, + run_nms_based_box_suppression_for_all_objects, +) + + +class TestHelperFunctions(unittest.TestCase): + def test_nms_via_adjacency_list(self) -> None: + nms_thre = 0.1 + n_boxes = 1024 + cxcywh_boxes = torch.rand(n_boxes, 4) + scores = torch.rand(n_boxes) + + tlbr_boxes = cxcywh_boxes.clone() + tlbr_boxes[:, :2] = cxcywh_boxes[:, :2] - cxcywh_boxes[:, 2:] / 2 + tlbr_boxes[:, 2:] = cxcywh_boxes[:, :2] + cxcywh_boxes[:, 2:] / 2 + + box_indices_by_descending_score = torch.argsort(scores, descending=True) + adjacency_list = compute_iou_adjacency_list(cxcywh_boxes, nms_thre=nms_thre) + adjacency_survived_inds = run_nms_via_adjacency_list( + box_indices_by_descending_score, adjacency_list + ) + + torchvision_survived_inds = torchvision.ops.nms( + tlbr_boxes, scores, iou_threshold=nms_thre + ) 
+ + self.assertTrue(torch.all(adjacency_survived_inds == torchvision_survived_inds)) + + def test_class_agnostic_has_no_effect_on_single_class(self) -> None: + n_boxes = 512 + pro_max = torch.rand(n_boxes, 1) + con_max = torch.rand(n_boxes, 1) + cxcywh_boxes = torch.rand(n_boxes, 4) + conf_thres = torch.tensor([0.1]) + nms_thre = 0.1 + + class_believer_boxes = run_nms_based_box_suppression_for_all_objects( + pro_max, + con_max, + cxcywh_boxes, + 1.0, + conf_thres, + nms_thre, + run_class_agnostic_nms=True, + ) + self.assertEqual(len(class_believer_boxes), 1) + + class_agnostic_boxes = run_nms_based_box_suppression_for_all_objects( + pro_max, + con_max, + cxcywh_boxes, + 1.0, + conf_thres, + nms_thre, + run_class_agnostic_nms=False, + ) + + self.assertTrue( + torch.all(torch.eq(class_believer_boxes[0], class_agnostic_boxes[0])) + ) + + def test_class_agnostic_has_no_effect_on_multiclass_with_no_overlap(self) -> None: + n_classes = 4 + n_boxes = 512 * n_classes + pro_max = torch.zeros(n_boxes, n_classes) + con_max = torch.zeros(n_boxes, n_classes) + cxcywh_boxes = torch.rand(n_boxes, 4) + conf_thres = torch.tensor([0.1] * n_classes) + nms_thre = 0.1 + + # adjust the scores such that each class has nonzero scores in exactly 1/n_classes of the boxes + # and shift those boxes so that they don't overlap between classes + for i in range(n_classes): + pro_max[i::n_classes, i] = torch.rand(n_boxes // n_classes) + con_max[i::n_classes, i] = torch.rand(n_boxes // n_classes) + cxcywh_boxes[i::n_classes, 0] += i * 10 + + class_believer_boxes = run_nms_based_box_suppression_for_all_objects( + pro_max, + con_max, + cxcywh_boxes, + 1.0, + conf_thres, + nms_thre, + run_class_agnostic_nms=True, + ) + self.assertEqual(len(class_believer_boxes), n_classes) + + class_agnostic_boxes = run_nms_based_box_suppression_for_all_objects( + pro_max, + con_max, + cxcywh_boxes, + 1.0, + conf_thres, + nms_thre, + run_class_agnostic_nms=False, + ) + + for i in range(n_classes): + self.assertTrue( + torch.all(torch.eq(class_believer_boxes[i], class_agnostic_boxes[i])) + ) + + +class TestObjectDetector(unittest.TestCase): + # we have to define these here because mypy doesn't dive into the init hiding behind the classmethod + object_detector = ( + NotImplemented + ) # type: ClassVar[ZeroShotObjectDetectorWithFeedback] + coke_bottle_image_bytes = NotImplemented # type: ClassVar[bytes] + coke_can_image_bytes = NotImplemented # type: ClassVar[bytes] + default_labels = NotImplemented # type: ClassVar[list[str]] + default_incs = NotImplemented # type: ClassVar[dict[str, list[str]]] + default_excs = NotImplemented # type: ClassVar[dict[str, list[str]]] + default_nms_thre = NotImplemented # type: ClassVar[float] + default_conf_thres = NotImplemented # type: ClassVar[dict[str, float]] + + @classmethod + def setUpClass(cls) -> None: + cls.object_detector = ZeroShotObjectDetectorWithFeedback(jit=False) + + coke_bottle_filepath = "unit_tests/sample_data/coke_through_the_ages.jpeg" + with open(coke_bottle_filepath, "rb") as f: + cls.coke_bottle_image_bytes = f.read() + coke_can_filepath = "unit_tests/sample_data/coke_can.jpg" + with open(coke_can_filepath, "rb") as f: + cls.coke_can_image_bytes = f.read() + + cls.default_labels = ["bottle", "can", "moose"] + cls.default_incs = { + "bottle": ["bottle", "glass bottle", "plastic bottle", "water bottle"], + "can": ["can", "soda can", "aluminum can"], + "moose": ["moose", "elk", "deer"], + } + cls.default_excs = { + "bottle": ["can", "soda can", "aluminum can"], + } + cls.default_nms_thre = 
0.1 + cls.default_conf_thres = { + "bottle": 0.1, + "can": 0.1, + "moose": 0.1, + } + + def test_detect_objects_from_image_bytes(self) -> None: + with torch.no_grad(): + batched_predicted_boxes = self.object_detector( + self.coke_bottle_image_bytes, + labels=self.default_labels, + inc_sub_labels_dict=self.default_incs, + exc_sub_labels_dict=self.default_excs, + nms_thre=self.default_nms_thre, + label_conf_thres=self.default_conf_thres, + ) + + self.assertEqual(len(batched_predicted_boxes), 1) + predicted_boxes = batched_predicted_boxes[0] + self.assertEqual(len(predicted_boxes), len(self.default_labels)) + bottle_boxes = predicted_boxes[0] + can_boxes = predicted_boxes[1] + moose_boxes = predicted_boxes[2] + self.assertEqual(len(bottle_boxes), 9) + self.assertEqual(len(can_boxes), 0) + self.assertEqual(len(moose_boxes), 0) + + def test_batched_detect(self) -> None: + # ideally we would test a set of random images + # but we use a sort, which has unstable ordering with floating point numbers + # which means it is nontrivial to compare the outputs of the batched and single-image versions + # instead we're going to limit to confident predictions from two images + # and hope that the confidences are well-enough separated that the sort is stable + + with torch.no_grad(): + coke_bottle_image_tensor, coke_bottle_ratio = ( + created_padded_tensor_from_bytes( + self.coke_bottle_image_bytes, self.object_detector.image_size + ) + ) + coke_can_image_tensor, coke_can_ratio = created_padded_tensor_from_bytes( + self.coke_can_image_bytes, self.object_detector.image_size + ) + + batched_images = torch.cat( + [ + coke_bottle_image_tensor, + ] + * 8 + + [ + coke_can_image_tensor, + ] + * 8, + dim=0, + ) + batched_ratios = torch.cat( + [ + coke_bottle_ratio, + ] + * 8 + + [ + coke_can_ratio, + ] + * 8, + dim=0, + ) + + single_image_outputs_list = [] + for image, ratio in zip(batched_images, batched_ratios): + single_image_outputs_list.append( + self.object_detector( + image.unsqueeze(0), + labels=self.default_labels, + inc_sub_labels_dict=self.default_incs, + exc_sub_labels_dict=None, + nms_thre=self.default_nms_thre, + label_conf_thres=self.default_conf_thres, + image_scale_ratios=ratio.unsqueeze(0), + )[0] + ) + + batched_outputs = self.object_detector( + batched_images, + labels=self.default_labels, + inc_sub_labels_dict=self.default_incs, + exc_sub_labels_dict=None, + nms_thre=self.default_nms_thre, + label_conf_thres=self.default_conf_thres, + image_scale_ratios=batched_ratios, + ) + + for i in range(len(batched_outputs)): + for j in range(len(batched_outputs[i])): + from_batch = batched_outputs[i][j] + from_single = single_image_outputs_list[i][j] + self.assertEqual(from_batch.shape, from_single.shape) + if from_batch.shape[0] == 0: + continue + + # these values have range on the order of 1e3, so we scale them to compare + scale = torch.maximum(from_batch.abs(), from_single.abs()) + diff = (from_batch - from_single).abs() + scaled_diff = diff / (scale + 1e-6) + max_diff = scaled_diff.max().item() + + # large range and machine precision issues mean the max diff has a lot of noise + # TODO: is this more than is sane? 
+ self.assertTrue(max_diff < 1e-3) + + def test_batch_detect_random_append(self) -> None: + # we test batched detection by doing a pass for one image + # and then doing a batched pass for that image and many random images + # we use low confidence thresholds during the detection + # and then truncate at a high confidence level to ensure stability due to sorting + confidences = {label: 0.0 for label in self.default_labels} + with torch.no_grad(): + coke_bottle_image_tensor, coke_bottle_ratio = ( + created_padded_tensor_from_bytes( + self.coke_bottle_image_bytes, self.object_detector.image_size + ) + ) + baseline_output = self.object_detector( + coke_bottle_image_tensor, + labels=self.default_labels, + inc_sub_labels_dict=self.default_incs, + exc_sub_labels_dict=None, + nms_thre=self.default_nms_thre, + label_conf_thres=confidences, + image_scale_ratios=coke_bottle_ratio, + )[0] + + random_tensors = torch.rand(128, 3, *self.object_detector.image_size) + batched_tensor = torch.cat([random_tensors, coke_bottle_image_tensor]) + batched_ratios = torch.cat([torch.ones(128), coke_bottle_ratio]) + batched_output = self.object_detector( + batched_tensor, + labels=self.default_labels, + inc_sub_labels_dict=self.default_incs, + exc_sub_labels_dict=None, + nms_thre=self.default_nms_thre, + label_conf_thres=confidences, + image_scale_ratios=batched_ratios, + )[-1] + + for baseline_obj_detections, batched_obj_detections in zip( + baseline_output, batched_output + ): + # filter by confidence of 0.1 + baseline_obj_detections = baseline_obj_detections[ + baseline_obj_detections[:, 4] > 0.1 + ] + batched_obj_detections = batched_obj_detections[ + batched_obj_detections[:, 4] > 0.1 + ] + + self.assertEqual( + baseline_obj_detections.shape, batched_obj_detections.shape + ) + if baseline_obj_detections.shape[0] == 0: + continue + + # these values have range on the order of 1e3, so we scale them to compare + scale = torch.maximum( + baseline_obj_detections.abs(), batched_obj_detections.abs() + ) + diff = (baseline_obj_detections - batched_obj_detections).abs() + scaled_diff = diff / (scale + 1e-6) + max_diff = scaled_diff.max().item() + + # large range and machine precision issues mean the max diff has a lot of noise + self.assertTrue(max_diff < 1e-6) diff --git a/docker-compose.yml b/docker-compose.yml index 2830a90..ec8e51e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -19,12 +19,14 @@ services: environment: - PYTHONUNBUFFERED=1 - NVIDIA_VISIBLE_DEVICES=all - - HF_HOME=/.cache/huggingface + - HF_HOME=/directai_fastapi/.cache/huggingface + - CACHE_REDIS_PORT=6379 env_file: - directai_fastapi/.env runtime: nvidia volumes: - ./logs:/directai_fastapi/logs + - ./.cache:/directai_fastapi/.cache shm_size: 10.24g # because Ray complains if it's less depends_on: - local_redis diff --git a/integration_tests/test_modules/test_detector.py b/integration_tests/test_modules/test_detector.py index f11cf35..14a503c 100644 --- a/integration_tests/test_modules/test_detector.py +++ b/integration_tests/test_modules/test_detector.py @@ -208,9 +208,28 @@ def test_deploy_detector_success_without_augment_examples(self) -> None: self.assertTrue("deployed_id" in response_json) self.assertEqual(response_json["message"], "New model deployed.") - @unittest.skip("detector isn't built yet") + +class TestDetectorInference(unittest.TestCase): + def __init__(self, methodName: str = "runTest"): + super().__init__(methodName=methodName) + self.endpoint = f"http://{FASTAPI_HOST}:8000/" + # here we assume that deploy has been tested 
and works + # so we can generate a fixed deploy id for testing + body = { + "detector_configs": [ + { + "name": "bottle", + "examples_to_include": ["bottle"], + "examples_to_exclude": [], + "detection_threshold": 0.1, + } + ], + } + deploy_response = requests.post(self.endpoint + "deploy_detector", json=body) + deploy_response_json = deploy_response.json() + self.sample_deployed_id = deploy_response_json["deployed_id"] + def test_detect(self) -> None: - sample_deployed_id = "a554fdf3-cd07-45f5-a01a-c3b7cef75374" sample_fp = "sample_data/coke_through_the_ages.jpeg" expected_detect_response_unaccelerated = [ [ @@ -267,7 +286,7 @@ def test_detect(self) -> None: files = { "data": (sample_fp, file_data, "image/jpg"), } - params = {"deployed_id": sample_deployed_id} + params = {"deployed_id": self.sample_deployed_id} response = requests.post(self.endpoint + "detect", params=params, files=files) detect_response_json = response.json() @@ -287,13 +306,12 @@ def test_detect(self) -> None: ) def test_detect_malformatted_image(self) -> None: - sample_deployed_id = "a554fdf3-cd07-45f5-a01a-c3b7cef75374" sample_fp = "bad_file_path.suffix" file_data = b"This is not an image file" files = { "data": (sample_fp, file_data, "image/jpg"), } - params = {"deployed_id": sample_deployed_id} + params = {"deployed_id": self.sample_deployed_id} response = requests.post(self.endpoint + "detect", params=params, files=files) response_json = response.json() self.assertEqual(response_json["status_code"], 422) @@ -302,12 +320,11 @@ def test_detect_malformatted_image(self) -> None: ) def test_detect_empty_image(self) -> None: - sample_deployed_id = "a554fdf3-cd07-45f5-a01a-c3b7cef75374" sample_fp = "bad_file_path.jpg" files = { "data": (sample_fp, b"", "image/jpg"), } - params = {"deployed_id": sample_deployed_id} + params = {"deployed_id": self.sample_deployed_id} response = requests.post(self.endpoint + "detect", params=params, files=files) response_json = response.json() self.assertEqual(response_json["status_code"], 422) @@ -316,7 +333,6 @@ def test_detect_empty_image(self) -> None: ) def test_detect_truncated_image(self) -> None: - sample_deployed_id = "a554fdf3-cd07-45f5-a01a-c3b7cef75374" sample_fp = "sample_data/coke_through_the_ages.jpeg" with open(sample_fp, "rb") as f: file_data = f.read() @@ -324,7 +340,7 @@ def test_detect_truncated_image(self) -> None: files = { "data": (sample_fp, file_data, "image/jpg"), } - params = {"deployed_id": sample_deployed_id} + params = {"deployed_id": self.sample_deployed_id} response = requests.post(self.endpoint + "detect", params=params, files=files) response_json = response.json() self.assertEqual(response_json["status_code"], 422) @@ -332,7 +348,6 @@ def test_detect_truncated_image(self) -> None: response_json["message"], "Invalid image received, unable to open." 
) - @unittest.skip("detector isn't built yet") def test_deploy_and_detect(self) -> None: # Starting Deploy Call body = { @@ -425,9 +440,9 @@ def test_deploy_and_detect(self) -> None: 0.05, ) - @unittest.skip("detector isn't built yet") def test_deploy_with_long_prompt_and_detect(self) -> None: # Starting Deploy Call + # NOTE: this is the only single-class detection test that runs the class-specific NMS algorithm very_long_prompt = "boat from birds-eye view maritime vessel from birds-eye view boat from top-down view maritime vessel from top-down view" body = { "detector_configs": [ @@ -483,7 +498,6 @@ def test_deploy_with_long_prompt_and_detect(self) -> None: 0.05, ) - @unittest.skip("detector isn't built yet") def test_deploy_and_detect_without_augmented_examples(self) -> None: # Starting Deploy Call body = { @@ -658,7 +672,6 @@ def test_deploy_and_detect_without_augmented_examples(self) -> None: detect_response_json, detect_response_augmented_examples_json ) - @unittest.skip("detector isn't built yet") def test_deploy_with_and_without_class_agnostic_nms(self) -> None: # Starting Deploy Call body = { diff --git a/redis_data/redis_entrypoint.sh b/redis_data/redis_entrypoint.sh index 4381723..9fc4b9c 100755 --- a/redis_data/redis_entrypoint.sh +++ b/redis_data/redis_entrypoint.sh @@ -8,7 +8,7 @@ cleanup() { trap 'cleanup' SIGTERM #Execute a command in the background -redis-server --requirepass "default_password" --appendonly "yes" --appendfsync "always" & +redis-server --requirepass "default_password" --appendonly "yes" --appendfsync "always" --port 6379 & #Save the PID of the background process REDIS_PID=$!
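As a usage reference, a minimal sketch of driving the new ZeroShotObjectDetectorWithFeedback directly, mirroring test_detect_objects_from_image_bytes; the labels, sub-label dictionaries, and image path below are placeholders, and the constructor's default device="cuda" is assumed to be available:

import torch

from modeling.object_detector import ZeroShotObjectDetectorWithFeedback

# jit=False skips the torch.jit.trace_module warm-up used by the Ray deployment
detector = ZeroShotObjectDetectorWithFeedback(jit=False)

# placeholder image path
with open("unit_tests/sample_data/coke_through_the_ages.jpeg", "rb") as f:
    image_bytes = f.read()

with torch.no_grad():
    batched_boxes = detector(
        image_bytes,
        labels=["bottle", "can"],
        inc_sub_labels_dict={
            "bottle": ["bottle", "glass bottle"],
            "can": ["can", "soda can"],
        },
        exc_sub_labels_dict={"bottle": ["can"]},
        label_conf_thres={"bottle": 0.1, "can": 0.1},
        nms_thre=0.4,
    )

# one entry per image; each entry is a per-label list of tensors holding
# [x1, y1, x2, y2, confidence] rows in original-image pixel coordinates
per_label_boxes = batched_boxes[0]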
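The same flow over the HTTP API, as exercised by the updated integration tests; the host, port, and sample image path are assumptions, and the detection field names follow SingleDetectionResponse as defined in pydantic_models.py:

import requests

endpoint = "http://localhost:8000/"  # assumed host/port; the tests target FASTAPI_HOST:8000

deploy_body = {
    "detector_configs": [
        {
            "name": "bottle",
            "examples_to_include": ["bottle", "glass bottle"],
            "examples_to_exclude": ["can"],
            "detection_threshold": 0.1,
        }
    ],
}
deployed_id = requests.post(endpoint + "deploy_detector", json=deploy_body).json()[
    "deployed_id"
]

with open("sample_data/coke_through_the_ages.jpeg", "rb") as f:
    files = {"data": ("coke_through_the_ages.jpeg", f.read(), "image/jpg")}
response = requests.post(
    endpoint + "detect", params={"deployed_id": deployed_id}, files=files
)
# the body is a list with one entry per image, each a list of detection dicts
# carrying the tlbr box, the score, and the class name
detections = response.json()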
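ZeroShotObjectDetectorWithFeedback.forward collapses sub-label scores into per-label scores with a single flattened scatter_max rather than a Python loop over labels; a toy sketch of that indexing trick with made-up shapes, checked against a dense equivalent:

import torch
from torch_scatter import scatter_max

# 3 boxes, 2 labels, 4 included sub-labels: sub-labels 0-1 roll up to label 0,
# sub-labels 2-3 roll up to label 1 (the role played by pos_labels_to_master_inds)
num_boxes, num_labels, num_incs = 3, 2, 4
sub_label_scores = torch.rand(num_boxes, num_incs)
sub_to_label = torch.tensor([0, 0, 1, 1])

# flatten the box dimension and offset each box's target indices, as in forward()
offsets = torch.arange(num_boxes).repeat_interleave(num_incs) * num_labels
index = sub_to_label.repeat(num_boxes) + offsets
max_flat, _ = scatter_max(
    sub_label_scores.view(-1), index, dim_size=num_boxes * num_labels
)
max_per_label = max_flat.view(num_boxes, num_labels)

# dense equivalent for comparison
expected = torch.stack(
    [
        sub_label_scores[:, :2].max(dim=1).values,
        sub_label_scores[:, 2:].max(dim=1).values,
    ],
    dim=1,
)
assert torch.allclose(max_per_label, expected)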