
Commit bd5e698

Add PTQ example for PyTorch CV - Segment Anything Model (#1464)
1 parent 7a36717 commit bd5e698


7 files changed: +714 −0 lines changed

Lines changed: 68 additions & 0 deletions
Step-by-Step
============
This document describes the step-by-step instructions for applying post-training quantization (PTQ) to the Segment Anything Model (SAM) using the VOC dataset.

# Prerequisite
## Environment
```shell
# install dependencies
pip install -r ./requirements.txt
# retrieve the SAM model code and pre-trained weights
pip install git+https://github.com/facebookresearch/segment-anything.git
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
```
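
As a quick sanity check that the environment is ready, the downloaded checkpoint can be loaded through the `sam_model_registry` entry point shipped with the `segment-anything` package; the `vit_b` key matches the `sam_vit_b_01ec64.pth` weights. A minimal sketch:

```python
from segment_anything import sam_model_registry

# "vit_b" corresponds to the sam_vit_b_01ec64.pth checkpoint downloaded above.
sam = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
sam.eval()
print(sam.image_encoder.img_size)  # 1024, the square size dataset.py pads to
```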

# PTQ
This example applies PTQ to SAM, calibrated and evaluated on the VOC dataset.

## 1. Prepare VOC dataset
```shell
python download_dataset.py
```

## 2. Start PTQ
```shell
bash run_quant.sh --voc_dataset_location=./voc_dataset/VOCdevkit/VOC2012/ --pretrained_weight_location=./sam_vit_b_01ec64.pth
```

## 3. Benchmarking
```shell
bash run_benchmark.sh --tuned_checkpoint=./saved_results --voc_dataset_location=./voc_dataset/VOCdevkit/VOC2012/ --int8=True --mode=performance
```
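
`run_benchmark.sh` with `--mode=performance` measures throughput on the quantized model. A minimal sketch of the Neural Compressor benchmark API it presumably wraps, where `quantized_model` and `val_loader` refer to the objects shown in the Saving and Loading section below; the warmup/iteration values are illustrative assumptions, not values from this example's script:

```python
from neural_compressor import benchmark
from neural_compressor.config import BenchmarkConfig

# Illustrative settings; run_benchmark.sh may configure these differently.
conf = BenchmarkConfig(warmup=5, iteration=100)
benchmark.fit(model=quantized_model, conf=conf, b_dataloader=val_loader)
```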

# Result
|          | Baseline (FP32) | INT8   |
| -------- | --------------- | ------ |
| Accuracy | 0.7939          | 0.7849 |

# Saving and Loading Model

* Saving model:

After tuning with Neural Compressor, we obtain a `neural_compressor` model object:

```python
from neural_compressor import PostTrainingQuantConfig
from neural_compressor import quantization

conf = PostTrainingQuantConfig()
q_model = quantization.fit(model,
                           conf,
                           calib_dataloader=val_loader,
                           eval_func=eval_func)
```

Here, `q_model` is a Neural Compressor model object, which provides a `save` API:

```python
q_model.save("Path_to_save_quantized_model")
```

* Loading model:

```python
import os

from neural_compressor.utils.pytorch import load

quantized_model = load(os.path.abspath(os.path.expanduser(args.tuned_checkpoint)),
                       model,
                       dataloader=val_loader)
```
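
Note that `load` rebuilds the INT8 modules on top of the original FP32 model definition, so the FP32 SAM must be instantiated first. A minimal sketch, assuming the paths from the commands above (the exact wiring lives in main.py):

```python
import os
from neural_compressor.utils.pytorch import load
from segment_anything import sam_model_registry

# Rebuild the FP32 SAM, then restore the INT8 weights saved under ./saved_results.
model = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
quantized_model = load(os.path.abspath("./saved_results"), model, dataloader=val_loader)
```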

Please refer to main.py for the complete flow.
Lines changed: 7 additions & 0 deletions
import torchvision

print("Downloading VOC dataset")
# Downloads and extracts the VOC2012 trainval archive (images, annotations,
# split lists, and segmentation masks) under ./voc_dataset/VOCdevkit/VOC2012/.
torchvision.datasets.VOCDetection(root='./voc_dataset', year='2012', image_set='trainval', download=True)
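
(Note: torchvision unpacks the archive under `./voc_dataset/VOCdevkit/VOC2012/`, the path the scripts above expect. A quick sanity check of the layout, for illustration:)

```python
import os

voc_root = "./voc_dataset/VOCdevkit/VOC2012"
# The VOC2012 trainval archive ships the annotations, images, split lists,
# and object masks that dataset.py reads.
for subdir in ("Annotations", "JPEGImages", "ImageSets/Segmentation", "SegmentationObject"):
    assert os.path.isdir(os.path.join(voc_root, subdir)), f"missing {subdir}"
```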
Lines changed: 173 additions & 0 deletions
from segment_anything.utils.transforms import ResizeLongestSide
import torchvision
import torch
import torch.nn.functional as F
from PIL import Image

import numpy as np
import os
import xml.etree.ElementTree as ET


# Pad an image tensor to a square of side `square_length` - based on SAM's preprocessing
def pad_image(x: torch.Tensor, square_length=1024) -> torch.Tensor:
    # x is C, H, W
    h, w = x.shape[-2:]
    padh = square_length - h
    padw = square_length - w
    x = F.pad(x, (0, padw, 0, padh))
    return x


# Custom dataset
class INC_SAMVOC2012Dataset(object):
    def __init__(self, voc_root, type):
        self.voc_root = voc_root
        self.num_of_data = -1
        self.dataset = {}  # Each item: [filename, class_name, [xmin, ymin, xmax, ymax]]
        self.resizelongestside = ResizeLongestSide(target_length=1024)
        pixel_mean = [123.675, 116.28, 103.53]
        pixel_std = [58.395, 57.12, 57.375]
        self.pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
        self.pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)

        # Read through all the samples and build a dictionary.
        # Key: index; item: filename, class name, and bounding box.
        annotation_dir = os.path.join(voc_root, "Annotations")
        files = os.listdir(annotation_dir)
        files = [f for f in files if os.path.isfile(os.path.join(annotation_dir, f))]  # filter out directories
        annotation_files = [os.path.join(annotation_dir, x) for x in files]

        # Get the name list of the segmentation files
        segmentation_dir = os.path.join(voc_root, "SegmentationObject")
        files = os.listdir(segmentation_dir)
        segmentation_files = [f for f in files if os.path.isfile(os.path.join(segmentation_dir, f))]  # filter out directories

        # Select data based on the split type (train/val)
        train_val_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
        if type == 'train':
            txt_file_name = 'train.txt'
        elif type == 'val':
            txt_file_name = 'val.txt'
        else:
            raise ValueError("Type of dataset should be 'train' or 'val'")

        with open(os.path.join(train_val_dir, txt_file_name), 'r') as f:
            permitted_files = [row.rstrip('\n') for row in f]

        for file in annotation_files:
            file_name = file.split('/')[-1].split('.xml')[0]

            if file_name not in permitted_files:
                continue  # skip the file

            if file_name + '.png' in segmentation_files:  # check whether a segmentation mask exists for this annotation
                tree = ET.parse(file)
                root = tree.getroot()
                for child in root:
                    if child.tag == 'object':
                        details = [file_name]
                        for node in child:
                            if node.tag == 'name':
                                object_name = node.text
                            if node.tag == 'bndbox':
                                for coordinates in node:
                                    if coordinates.tag == 'xmax':
                                        xmax = int(coordinates.text)
                                    if coordinates.tag == 'xmin':
                                        xmin = int(coordinates.text)
                                    if coordinates.tag == 'ymax':
                                        ymax = int(coordinates.text)
                                    if coordinates.tag == 'ymin':
                                        ymin = int(coordinates.text)
                                boundary = [xmin, ymin, xmax, ymax]
                                details.append(object_name)
                                details.append(boundary)
                                self.num_of_data += 1
                                self.dataset[self.num_of_data] = details

    def __len__(self):
        # self.num_of_data holds the last index, so report the actual count
        return len(self.dataset)

    # Preprocess the segmentation mask. Output the semantic information of a single object.
    def preprocess_segmentation(self, filename, bounding_box, pad=True):

        # Read the object segmentation mask
        segment_mask = Image.open(os.path.join(self.voc_root, 'SegmentationObject', filename + '.png'))
        segment_mask_np = torchvision.transforms.functional.pil_to_tensor(segment_mask)

        # Crop the segmentation mask to the bounding box
        xmin, ymin = int(bounding_box[0]), int(bounding_box[1])
        xmax, ymax = int(bounding_box[2]), int(bounding_box[3])
        cropped_mask = segment_mask.crop((xmin, ymin, xmax, ymax))
        cropped_mask_np = torchvision.transforms.functional.pil_to_tensor(cropped_mask)

        # Find the majority (most frequent) object id inside the box
        bincount = np.bincount(cropped_mask_np.reshape(-1))
        bincount[0] = 0  # ignore the black background pixels
        if bincount.shape[0] >= 256:
            bincount[255] = 0  # ignore the white boundary pixels (VOC "void" label)
        majority_element = bincount.argmax()

        # Binarize the full mask against the majority object id
        segment_mask_np[np.where((segment_mask_np != 0) & (segment_mask_np != majority_element))] = 0
        segment_mask_np[segment_mask_np == majority_element] = 1

        # Pad the segmentation mask to 1024x1024 (for batching in the dataloader)
        if pad:
            segment_mask_np = pad_image(segment_mask_np)

        return segment_mask_np

    # Preprocess the image into an appropriate format for SAM
    def preprocess_image(self, img):
        # ~= predictor.py - set_image()
        img = np.array(img)
        input_image = self.resizelongestside.apply_image(img)
        input_image_torch = torch.as_tensor(input_image, device='cpu')
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()
        input_image_torch = (input_image_torch - self.pixel_mean) / self.pixel_std  # normalize
        original_size = img.shape[:2]
        input_size = tuple(input_image_torch.shape[-2:])

        return pad_image(input_image_torch), original_size, input_size

    def __getitem__(self, idx):
        data = self.dataset[idx]
        filename, classname = data[0], data[1]
        bounding_box = data[2]

        # Ground-truth mask: preprocessing without padding
        mask_gt = self.preprocess_segmentation(filename, bounding_box, pad=False)

        # Read and preprocess the image
        image, original_size, input_size = self.preprocess_image(Image.open(os.path.join(self.voc_root, 'JPEGImages', filename + '.jpg')))
        prompt = bounding_box  # bounding-box prompt - input_boxes x1, y1, x2, y2
        training_data = {}
        training_data['image'] = image
        training_data["original_size"] = original_size
        training_data["input_size"] = input_size
        training_data["ground_truth_mask"] = mask_gt
        training_data["prompt"] = prompt
        return (training_data, mask_gt)  # data, label


class INC_SAMVOC2012Dataloader:
    def __init__(self, batch_size, **kwargs):
        self.batch_size = batch_size
        self.dataset = []
        ds = INC_SAMVOC2012Dataset(kwargs['voc_root'], kwargs['type'])
        # Materialize the (input_data, label) pairs into self.dataset
        for i in range(len(ds)):
            self.dataset.append(ds[i])

    def __iter__(self):
        # Yield one (input_data, label) pair at a time
        for input_data, label in self.dataset:
            yield input_data, label
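
(For illustration, a minimal sketch of how these classes plug into the PTQ flow from the README; the `vit_b` construction and the `eval_func` wiring are assumptions about what main.py does:)

```python
from neural_compressor import PostTrainingQuantConfig, quantization
from segment_anything import sam_model_registry

# Build the calibration/evaluation dataloader over the VOC val split.
val_loader = INC_SAMVOC2012Dataloader(batch_size=1,
                                      voc_root='./voc_dataset/VOCdevkit/VOC2012/',
                                      type='val')
model = sam_model_registry["vit_b"](checkpoint="./sam_vit_b_01ec64.pth")
conf = PostTrainingQuantConfig()
# eval_func is defined in main.py; it returns the accuracy metric.
q_model = quantization.fit(model, conf, calib_dataloader=val_loader, eval_func=eval_func)
q_model.save("./saved_results")
```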
