from segment_anything import SamPredictor, sam_model_registry
import torchvision
import torch
from PIL import Image

import numpy as np
import os
import xml.etree.ElementTree as ET
from statistics import mean
from torch.nn.functional import threshold, normalize
import torch.nn.functional as F
from segment_anything.utils.transforms import ResizeLongestSide
from typing import List, Tuple

# Pad image to a square - based on SAM's own preprocessing
def pad_image(x: torch.Tensor, square_length=1024) -> torch.Tensor:
    # x is (C, H, W); zero-pad the bottom and right edges up to square_length
    h, w = x.shape[-2:]
    padh = square_length - h
    padw = square_length - w
    x = F.pad(x, (0, padw, 0, padh))
    return x

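# Quick sanity check (illustrative, not part of the original commit): a
# 3x600x800 tensor comes back as 3x1024x1024, with the image content in the
# top-left corner and zeros elsewhere.
#   assert pad_image(torch.zeros(3, 600, 800)).shape == (3, 1024, 1024)
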
# Custom dataset
class INC_SAMVOC2012Dataset(object):
    def __init__(self, voc_root, type):
        self.voc_root = voc_root
        self.num_of_data = -1
        self.dataset = {}  # Each item will be: ["filename", "class_name", [4 bounding box coordinates]]
        self.resizelongestside = ResizeLongestSide(target_length=1024)
        pixel_mean = [123.675, 116.28, 103.53]
        pixel_std = [58.395, 57.12, 57.375]
        self.pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
        self.pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)

        # Read through all the samples and build a dictionary:
        # the key is an integer index; the value holds the filename,
        # class name and bounding box of one annotated object.
        annotation_dir = os.path.join(voc_root, "Annotations")
        files = os.listdir(annotation_dir)
        files = [f for f in files if os.path.isfile(os.path.join(annotation_dir, f))]  # filter out directories
        annotation_files = [os.path.join(annotation_dir, x) for x in files]

        # Get the name list of the segmentation files
        segmentation_dir = os.path.join(voc_root, "SegmentationObject")
        files = os.listdir(segmentation_dir)
        files = [f for f in files if os.path.isfile(os.path.join(segmentation_dir, f))]  # filter out directories
        segmentation_files = list(files)

        # Select the data split (train/val)
        train_val_dir = os.path.join(voc_root, 'ImageSets/Segmentation/')
        if type == 'train':
            txt_file_name = 'train.txt'
        elif type == 'val':
            txt_file_name = 'val.txt'
        else:
            raise ValueError("type of dataset should be 'train' or 'val'")

        with open(os.path.join(train_val_dir, txt_file_name), 'r') as f:
            permitted_files = []
            for row in f:
                permitted_files.append(row.rstrip('\n'))

        for file in annotation_files:
            file_name = os.path.splitext(os.path.basename(file))[0]

            if file_name not in permitted_files:
                continue  # skip files outside the chosen split

            # Only keep annotations that have a matching segmentation file
            if file_name + '.png' in segmentation_files:
                tree = ET.parse(file)
                root = tree.getroot()
                for child in root:
                    if child.tag == 'object':
                        details = [file_name]
                        for node in child:
                            if node.tag == 'name':
                                object_name = node.text
                            if node.tag == 'bndbox':
                                for coordinates in node:
                                    if coordinates.tag == 'xmax':
                                        xmax = int(coordinates.text)
                                    if coordinates.tag == 'xmin':
                                        xmin = int(coordinates.text)
                                    if coordinates.tag == 'ymax':
                                        ymax = int(coordinates.text)
                                    if coordinates.tag == 'ymin':
                                        ymin = int(coordinates.text)
                                boundary = [xmin, ymin, xmax, ymax]
                        details.append(object_name)
                        details.append(boundary)
                        self.num_of_data += 1
                        self.dataset[self.num_of_data] = details

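    # Layout of one self.dataset entry (values illustrative, not real data):
    #   {0: ['2007_000032', 'aeroplane', [104, 78, 375, 183]]}
    # i.e. [image file stem, class name, [xmin, ymin, xmax, ymax]].
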
    def __len__(self):
        # num_of_data holds the highest index, so the count is one more;
        # returning it directly would silently drop the last sample.
        return self.num_of_data + 1

    # Preprocess the segmentation mask. Outputs the mask of a single object instance.
    def preprocess_segmentation(self, filename, bounding_box, pad=True):

        # Read the object segmentation mask
        segment_mask = Image.open(os.path.join(self.voc_root, 'SegmentationObject', filename + '.png'))
        segment_mask_np = torchvision.transforms.functional.pil_to_tensor(segment_mask)

        # Crop the segmentation to the bounding box
        xmin, ymin = int(bounding_box[0]), int(bounding_box[1])
        xmax, ymax = int(bounding_box[2]), int(bounding_box[3])
        cropped_mask = segment_mask.crop((xmin, ymin, xmax, ymax))
        cropped_mask_np = torchvision.transforms.functional.pil_to_tensor(cropped_mask)

        # Find the majority instance id inside the box
        bincount = np.bincount(cropped_mask_np.numpy().reshape(-1))
        bincount[0] = 0  # ignore the background (black) pixels
        if bincount.shape[0] >= 256:
            bincount[255] = 0  # ignore the void/border (white) pixels
        majority_element = int(bincount.argmax())

        # Binarize the mask: keep only the majority instance
        segment_mask_np[(segment_mask_np != 0) & (segment_mask_np != majority_element)] = 0
        segment_mask_np[segment_mask_np == majority_element] = 1

        # Pad the segment mask to 1024x1024 (for batching in a dataloader)
        if pad:
            segment_mask_np = pad_image(segment_mask_np)

        return segment_mask_np

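    # Worked example of the majority vote above (pixel values illustrative):
    # if the cropped mask holds values [0, 0, 3, 3, 3, 255], zeroing bins 0 and
    # 255 leaves instance id 3 as argmax, so the returned mask is 1 where the
    # full-size mask equals 3 and 0 everywhere else.
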
    # Preprocess the image into the format SAM expects
    def preprocess_image(self, img):
        # ~= predictor.py - set_image()
        img = np.array(img)
        input_image = self.resizelongestside.apply_image(img)
        input_image_torch = torch.as_tensor(input_image, device='cpu')
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()
        input_image_torch = (input_image_torch - self.pixel_mean) / self.pixel_std  # normalize
        original_size = img.shape[:2]
        input_size = tuple(input_image_torch.shape[-2:])

        return pad_image(input_image_torch), original_size, input_size

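    # Illustrative shape flow, assuming a 500x375 JPEG: np.array gives (375, 500, 3),
    # ResizeLongestSide rescales the long side to 1024 -> (768, 1024, 3), permute gives
    # (3, 768, 1024), and pad_image zero-pads to (3, 1024, 1024). original_size=(375, 500)
    # and input_size=(768, 1024) let SAM map masks back to the raw image later.
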
    def __getitem__(self, idx):
        data = self.dataset[idx]
        filename, classname = data[0], data[1]
        bounding_box = data[2]

        # Ground-truth mask at the original resolution (no padding)
        mask_gt = self.preprocess_segmentation(filename, bounding_box, pad=False)

        # Read and preprocess the image
        image, original_size, input_size = self.preprocess_image(Image.open(os.path.join(self.voc_root, 'JPEGImages', filename + '.jpg')))
        prompt = bounding_box  # box prompt in x1, y1, x2, y2 format
        training_data = {}
        training_data['image'] = image
        training_data['original_size'] = original_size
        training_data['input_size'] = input_size
        training_data['ground_truth_mask'] = mask_gt
        training_data['prompt'] = prompt
        return (training_data, mask_gt)  # (data, label)

class INC_SAMVOC2012Dataloader:
    def __init__(self, batch_size, **kwargs):
        self.batch_size = batch_size
        self.dataset = []
        ds = INC_SAMVOC2012Dataset(kwargs['voc_root'], kwargs['type'])
        # Materialize all (input_data, label) pairs into self.dataset
        for i in range(len(ds)):
            self.dataset.append(ds[i])

    def __iter__(self):
        # Yield one (input_data, label) pair at a time
        for input_data, label in self.dataset:
            yield input_data, label
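
# Minimal usage sketch (not part of the original commit; the VOC root path is
# hypothetical). Each iteration yields the dict SAM consumes plus the label mask:
#
#   loader = INC_SAMVOC2012Dataloader(batch_size=1,
#                                     voc_root='./VOCdevkit/VOC2012/',
#                                     type='train')
#   for input_data, label in loader:
#       print(input_data['image'].shape, label.shape)
#       break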