Skip to content

Commit 3d75b92

Browse files
review comments
1 parent a26acbf commit 3d75b92

File tree

8 files changed

+221
-19
lines changed

8 files changed

+221
-19
lines changed

keras_cv/models/object_detection/yolox/binary_crossentropy.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
16+
import warnings
17+
1518
import tensorflow as tf
1619

1720

1821
class BinaryCrossentropy(tf.keras.losses.Loss):
1922
"""Computes the cross-entropy loss between true labels and predicted labels.
23+
2024
Use this cross-entropy loss for binary (0 or 1) classification applications.
2125
This loss is updated for YoloX by offering support for no axis to mean over.
26+
2227
Args:
2328
from_logits: Whether to interpret `y_pred` as a tensor of
2429
[logit](https://en.wikipedia.org/wiki/Logit) values. By default, we
@@ -29,9 +34,10 @@ class BinaryCrossentropy(tf.keras.losses.Loss):
2934
version of the true labels, where the smoothing squeezes the labels
3035
towards 0.5. Larger values of `label_smoothing` correspond to
3136
heavier smoothing.
32-
axis: the axis along which to mean the ious. Defaults to `None` which implies
37+
axis: the axis along which to mean the ious. Defaults to `no_reduction` which implies
3338
mean across no axes.
34-
Sample Usage:
39+
40+
Usage:
3541
```python
3642
model.compile(
3743
loss=keras_cv.models.object_detection.yolox.binary_crossentropy.BinaryCrossentropy(from_logits=True)
@@ -62,13 +68,19 @@ def _smooth_labels():
6268
label_smoothing, _smooth_labels, lambda: y_true
6369
)
6470

65-
if self.axis is not None:
71+
if self.axis == "no_reduction":
72+
warnings.warn(
73+
"`axis='no_reduction'` is a temporary API, and the API contract "
74+
"will be replaced in the future with a more generic solution "
75+
"covering all losses."
76+
)
6677
return tf.reduce_mean(
6778
tf.keras.backend.binary_crossentropy(
6879
y_true, y_pred, from_logits=self.from_logits
6980
),
7081
axis=self.axis,
7182
)
83+
7284
return tf.keras.backend.binary_crossentropy(
7385
y_true, y_pred, from_logits=self.from_logits
7486
)

keras_cv/models/object_detection/yolox/layers/yolox_decoder.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,19 @@ class YoloXPredictionDecoder(keras.layers.Layer):
3030
bounding_box_format: The format of bounding boxes of input dataset. Refer
3131
[to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
3232
for more details on supported bounding box formats.
33-
classes: The number of classes to be considered for the classification head.
33+
num_classes: The number of classes to be considered for the classification head.
3434
suppression_layer: A `keras.layers.Layer` that follows the same API
3535
signature of the `keras_cv.layers.MultiClassNonMaxSuppression` layer.
3636
This layer should perform a suppression operation such as Non Max Suppression,
3737
or Soft Non-Max Suppression.
3838
"""
3939

4040
def __init__(
41-
self, bounding_box_format, classes, suppression_layer=None, **kwargs
41+
self, bounding_box_format, num_classes, suppression_layer=None, **kwargs
4242
):
4343
super().__init__(**kwargs)
4444
self.bounding_box_format = bounding_box_format
45-
self.classes = classes
45+
self.num_classes = num_classes
4646

4747
self.suppression_layer = (
4848
suppression_layer
@@ -76,8 +76,12 @@ def call(self, images, predictions):
7676
strides = []
7777

7878
shapes = [x.shape[1:3] for x in predictions]
79+
80+
# 5 + self.num_classes is a concatenation of bounding boxes (length=4)
81+
# + objectness score (length=1) + num_classes
82+
# this reshape is simply collapsing axes 1 and 2 of x into a single dimension
7983
predictions = [
80-
tf.reshape(x, [batch_size, -1, 5 + self.classes])
84+
tf.reshape(x, [batch_size, -1, 5 + self.num_classes])
8185
for x in predictions
8286
]
8387
predictions = tf.cast(
@@ -107,24 +111,24 @@ def call(self, images, predictions):
107111
(predictions[..., :2] + grids) * strides / image_shape, axis=-2
108112
)
109113
box_xy = tf.broadcast_to(
110-
box_xy, [batch_size, predictions_shape[1], self.classes, 2]
114+
box_xy, [batch_size, predictions_shape[1], self.num_classes, 2]
111115
)
112116
box_wh = tf.expand_dims(
113117
tf.exp(predictions[..., 2:4]) * strides / image_shape, axis=-2
114118
)
115119
box_wh = tf.broadcast_to(
116-
box_wh, [batch_size, predictions_shape[1], self.classes, 2]
120+
box_wh, [batch_size, predictions_shape[1], self.num_classes, 2]
117121
)
118122

119123
box_confidence = tf.math.sigmoid(predictions[..., 4:5])
120124
box_class_probs = tf.math.sigmoid(predictions[..., 5:])
121125

122126
# create and broadcast classes for every box before nms
123127
box_classes = tf.expand_dims(
124-
tf.range(self.classes, dtype=self.compute_dtype), axis=-1
128+
tf.range(self.num_classes, dtype=self.compute_dtype), axis=-1
125129
)
126130
box_classes = tf.broadcast_to(
127-
box_classes, [batch_size, predictions_shape[1], self.classes, 1]
131+
box_classes, [batch_size, predictions_shape[1], self.num_classes, 1]
128132
)
129133

130134
box_scores = tf.expand_dims(box_confidence * box_class_probs, axis=-1)
@@ -146,7 +150,6 @@ def call(self, images, predictions):
146150
target="xywh",
147151
images=images,
148152
)
149-
150153
outputs = bounding_box.convert_format(
151154
outputs,
152155
source="rel_xywh",
@@ -156,7 +159,7 @@ def call(self, images, predictions):
156159

157160
# preparing the predictions for TF NMS op
158161
class_predictions = tf.cast(outputs["classes"], tf.int32)
159-
class_predictions = tf.one_hot(class_predictions, self.classes)
162+
class_predictions = tf.one_hot(class_predictions, self.num_classes)
160163

161164
scores = (
162165
tf.expand_dims(outputs["confidence"], axis=-1) * class_predictions

keras_cv/models/object_detection/yolox/layers/yolox_head.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class YoloXHead(keras.layers.Layer):
2424
"""The YoloX prediction head.
2525
2626
Arguments:
27-
classes: The number of classes to be considered for the classification head.
27+
num_classes: The number of classes to be considered for the classification head.
2828
bias_initializer: Bias Initializer for the final convolution layer for the
2929
classification and regression heads. Defaults to None.
3030
width_multiplier: A float value used to calculate the base width of the model
@@ -38,7 +38,7 @@ class YoloXHead(keras.layers.Layer):
3838

3939
def __init__(
4040
self,
41-
classes,
41+
num_classes,
4242
bias_initializer=None,
4343
width_multiplier=1.0,
4444
num_level=3,
@@ -110,7 +110,7 @@ def __init__(
110110

111111
self.classification_preds.append(
112112
keras.layers.Conv2D(
113-
filters=classes,
113+
filters=num_classes,
114114
kernel_size=1,
115115
strides=1,
116116
padding="same",
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import tensorflow as tf
17+
18+
from keras_cv.models.object_detection.yolox.layers import YoloXHead
19+
20+
21+
class YoloXHeadTest(tf.test.TestCase):
    """Unit tests for the `YoloXHead` layer."""

    def test_num_parameters(self):
        """The trainable parameter count matches the reference value."""
        input1 = tf.keras.Input((80, 80, 256))
        input2 = tf.keras.Input((40, 40, 512))
        input3 = tf.keras.Input((20, 20, 1024))

        output = YoloXHead(20)([input1, input2, input3])

        model = tf.keras.models.Model(
            inputs=[input1, input2, input3], outputs=output
        )

        # A generator is enough here; no need to build a throwaway list.
        keras_params = sum(
            tf.keras.backend.count_params(p) for p in model.trainable_weights
        )
        # taken from original implementation
        original_params = 7563595

        self.assertEqual(keras_params, original_params)

    def test_output_type_and_shape(self):
        """The head returns a list with one tensor per input level."""
        inputs = [
            tf.random.uniform((3, 80, 80, 256)),
            tf.random.uniform((3, 40, 40, 512)),
            tf.random.uniform((3, 20, 20, 1024)),
        ]

        output = YoloXHead(20)(inputs)

        self.assertIsInstance(output, list)
        self.assertEqual(len(output), 3)

        # Last axis is 25 for the 20 classes passed to `YoloXHead` above.
        self.assertEqual(output[0].shape, [3, 80, 80, 25])
        self.assertEqual(output[1].shape, [3, 40, 40, 25])
        self.assertEqual(output[2].shape, [3, 20, 20, 25])
52+
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import tensorflow as tf
16+
17+
from keras_cv.models.object_detection.yolox.layers import YoloXLabelEncoder
18+
19+
20+
class YoloXLabelEncoderTest(tf.test.TestCase):
    """Unit tests for the `YoloXLabelEncoder` layer."""

    def test_ragged_images_exception(self):
        """A ragged `images` batch is rejected with a `ValueError`."""
        img1 = tf.random.uniform((10, 11, 3))
        img2 = tf.random.uniform((9, 14, 3))
        img3 = tf.random.uniform((7, 12, 3))

        images = tf.ragged.stack([img1, img2, img3])
        box_labels = {}
        # Use the same "boxes" key as the other tests in this class.
        box_labels["boxes"] = tf.random.uniform((3, 4, 4))
        box_labels["classes"] = tf.random.uniform(
            (3, 4), maxval=20, dtype=tf.int32
        )
        layer = YoloXLabelEncoder()

        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12; `assertRaisesRegex` is the supported spelling.
        with self.assertRaisesRegex(
            ValueError,
            "method does not support RaggedTensor inputs for the `images` "
            "argument.",
        ):
            layer(images, box_labels)

    def test_ragged_labels(self):
        """Ragged box labels are encoded into a dense tensor."""
        images = tf.random.uniform((3, 12, 12, 3))

        box_labels = {}

        box1 = tf.random.uniform((11, 4))
        class1 = tf.random.uniform([11], maxval=20, dtype=tf.int32)
        box2 = tf.random.uniform((14, 4))
        class2 = tf.random.uniform([14], maxval=20, dtype=tf.int32)
        box3 = tf.random.uniform((12, 4))
        class3 = tf.random.uniform([12], maxval=20, dtype=tf.int32)

        box_labels["boxes"] = tf.ragged.stack([box1, box2, box3])
        box_labels["classes"] = tf.ragged.stack([class1, class2, class3])

        layer = YoloXLabelEncoder()

        encoded_boxes, _ = layer(images, box_labels)
        # 14 is the largest per-image box count among the inputs above.
        self.assertEqual(encoded_boxes.shape, (3, 14, 4))

    def test_one_hot_classes_exception(self):
        """One-hot encoded class labels are rejected with a `ValueError`."""
        images = tf.random.uniform((3, 12, 12, 3))

        box_labels = {}

        box1 = tf.random.uniform((11, 4))
        class1 = tf.random.uniform([11], maxval=20, dtype=tf.int32)
        class1 = tf.one_hot(class1, 20)

        box2 = tf.random.uniform((14, 4))
        class2 = tf.random.uniform([14], maxval=20, dtype=tf.int32)
        class2 = tf.one_hot(class2, 20)

        box3 = tf.random.uniform((12, 4))
        class3 = tf.random.uniform([12], maxval=20, dtype=tf.int32)
        class3 = tf.one_hot(class3, 20)

        box_labels["boxes"] = tf.ragged.stack([box1, box2, box3])
        box_labels["classes"] = tf.ragged.stack([class1, class2, class3])

        layer = YoloXLabelEncoder()

        with self.assertRaises(ValueError):
            layer(images, box_labels)
82+

keras_cv/models/object_detection/yolox/layers/yolox_pafpn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class YoloXPAFPN(keras.layers.Layer):
3333
this changes based on the detection model being used. Defaults to 1.0.
3434
in_channels: A list representing the number of filters in the FPN output.
3535
The length of the list will be same as the number of outputs. Defaults to
36-
[256, 512, 1024].
36+
(256, 512, 1024).
3737
use_depthwise: a boolean value used to decide whether a depthwise conv block
3838
should be used over a regular darknet block. Defaults to False.
3939
activation: the activation applied after the BatchNorm layer. One of "silu",
@@ -44,7 +44,7 @@ def __init__(
4444
self,
4545
depth_multiplier=1.0,
4646
width_multiplier=1.0,
47-
in_channels=[256, 512, 1024],
47+
in_channels=(256, 512, 1024),
4848
use_depthwise=False,
4949
activation="silu",
5050
**kwargs
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import tensorflow as tf
17+
18+
from keras_cv.models.object_detection.yolox.layers import YoloXPAFPN
19+
20+
21+
# NOTE(review): this class exercises `YoloXPAFPN`, yet it carries the same
# name as the label-encoder test class; it looks copied over. Consider
# renaming it (e.g. `YoloXPAFPNTest`) in a follow-up.
class YoloXLabelEncoderTest(tf.test.TestCase):
    """Unit tests for the `YoloXPAFPN` layer."""

    def test_num_parameters(self):
        """The trainable parameter count matches the reference value."""
        level3 = tf.keras.Input((80, 80, 256))
        level4 = tf.keras.Input((40, 40, 512))
        level5 = tf.keras.Input((20, 20, 1024))

        # The layer is called with a dict keyed by the integers 3, 4 and 5.
        features = {3: level3, 4: level4, 5: level5}
        output = YoloXPAFPN()(features)

        model = tf.keras.models.Model(
            inputs=[level3, level4, level5], outputs=output
        )

        keras_params = sum(
            tf.keras.backend.count_params(weight)
            for weight in model.trainable_weights
        )
        # taken from original implementation
        original_params = 19523072

        self.assertEqual(keras_params, original_params)

    def test_output_shape(self):
        """Each of the three outputs keeps the shape of its input."""
        inputs = {
            3: tf.random.uniform((3, 80, 80, 256)),
            4: tf.random.uniform((3, 40, 40, 512)),
            5: tf.random.uniform((3, 20, 20, 1024)),
        }

        output1, output2, output3 = YoloXPAFPN()(inputs)

        self.assertEqual(output1.shape, [3, 80, 80, 256])
        self.assertEqual(output2.shape, [3, 40, 40, 512])
        self.assertEqual(output3.shape, [3, 20, 20, 1024])
53+

keras_cv/version_check.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import tensorflow as tf
1919
from packaging.version import parse
2020

21-
MIN_VERSION = "2.11.0"
21+
MIN_VERSION = "2.9.0"
2222

2323

2424
def check_tf_version():

0 commit comments

Comments
 (0)