@@ -102,6 +102,40 @@ def test_sequence_loss(average_across_timesteps, average_across_batch, zero_weig
     np.testing.assert_allclose(computed, expected, rtol=1e-6, atol=1e-6)
 
 
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.parametrize("average_across_timesteps", [True, False])
+@pytest.mark.parametrize("average_across_batch", [True, False])
+def test_sequence_loss_class(average_across_timesteps, average_across_batch):
+
+    (
+        batch_size,
+        sequence_length,
+        _,
+        logits,
+        targets,
+        weights,
+        expected_loss,
+    ) = get_test_data()
+    seq_loss = loss.SequenceLoss(
+        average_across_timesteps=average_across_timesteps,
+        average_across_batch=average_across_batch,
+        sum_over_timesteps=False,
+        sum_over_batch=False,
+    )
+    average_loss_per_example = seq_loss(targets, logits, weights)
+    res = average_loss_per_example.numpy()
+    if average_across_timesteps and average_across_batch:
+        expected = expected_loss
+    elif not average_across_timesteps and average_across_batch:
+        expected = np.full(sequence_length, expected_loss)
+    elif average_across_timesteps and not average_across_batch:
+        expected = np.full(batch_size, expected_loss)
+    elif not average_across_timesteps and not average_across_batch:
+        expected = np.full((batch_size, sequence_length), expected_loss)
+
+    np.testing.assert_allclose(res, expected, atol=1e-6, rtol=1e-6)
+
+
 @test_utils.run_all_in_graph_and_eager_modes
 class LossTest(tf.test.TestCase):
     def setup(self):
@@ -128,56 +162,6 @@ def setup(self):
         # and logits = [[0.5] * 5, [1.5] * 5, [2.5] * 5]
         self.expected_loss = 1.60944
 
-    def testSequenceLossClass(self):
-        with self.cached_session(use_gpu=True):
-            self.setup()
-            seq_loss = loss.SequenceLoss(
-                average_across_timesteps=True,
-                average_across_batch=True,
-                sum_over_timesteps=False,
-                sum_over_batch=False,
-            )
-            average_loss_per_example = seq_loss(self.targets, self.logits, self.weights)
-            res = self.evaluate(average_loss_per_example)
-            self.assertAllClose(self.expected_loss, res)
-
-            seq_loss = loss.SequenceLoss(
-                average_across_timesteps=False,
-                average_across_batch=True,
-                sum_over_timesteps=False,
-                sum_over_batch=False,
-            )
-            average_loss_per_sequence = seq_loss(
-                self.targets, self.logits, self.weights
-            )
-            res = self.evaluate(average_loss_per_sequence)
-            compare_per_sequence = np.full((self.sequence_length), self.expected_loss)
-            self.assertAllClose(compare_per_sequence, res)
-
-            seq_loss = loss.SequenceLoss(
-                average_across_timesteps=True,
-                average_across_batch=False,
-                sum_over_timesteps=False,
-                sum_over_batch=False,
-            )
-            average_loss_per_batch = seq_loss(self.targets, self.logits, self.weights)
-            res = self.evaluate(average_loss_per_batch)
-            compare_per_batch = np.full((self.batch_size), self.expected_loss)
-            self.assertAllClose(compare_per_batch, res)
-
-            seq_loss = loss.SequenceLoss(
-                average_across_timesteps=False,
-                average_across_batch=False,
-                sum_over_timesteps=False,
-                sum_over_batch=False,
-            )
-            total_loss = seq_loss(self.targets, self.logits, self.weights)
-            res = self.evaluate(total_loss)
-            compare_total = np.full(
-                (self.batch_size, self.sequence_length), self.expected_loss
-            )
-            self.assertAllClose(compare_total, res)
-
     def testSumReduction(self):
         with self.cached_session(use_gpu=True):
             self.setup()
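
For reference, a minimal sketch (plain Python, not part of the diff) of what the new parametrization exercises: the two stacked parametrize decorators expand test_sequence_loss_class into four flag combinations, and the if/elif chain picks the reduction shape each combination should yield. The batch_size and sequence_length values below are illustrative placeholders; the test itself takes them from get_test_data().

import numpy as np

# Illustrative values only; in the test they come from get_test_data().
batch_size, sequence_length, expected_loss = 2, 3, 1.60944

# One case per (average_across_timesteps, average_across_batch) combination,
# mirroring the if/elif chain in test_sequence_loss_class.
for avg_timesteps in (True, False):
    for avg_batch in (True, False):
        if avg_timesteps and avg_batch:
            expected = expected_loss  # fully reduced scalar
        elif not avg_timesteps and avg_batch:
            expected = np.full(sequence_length, expected_loss)  # one value per timestep
        elif avg_timesteps and not avg_batch:
            expected = np.full(batch_size, expected_loss)  # one value per batch entry
        else:
            expected = np.full((batch_size, sequence_length), expected_loss)
        print(avg_timesteps, avg_batch, np.shape(expected))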