@@ -282,12 +282,8 @@ def test_dmdt_points_dtype(t_dtype, m_dtype):
282
282
t = np .linspace (0 , 1 , 11 , dtype = t_dtype )
283
283
m = np .asarray (t , dtype = m_dtype )
284
284
dmdt = DmDt .from_borders (min_lgdt = 0 , max_lgdt = 1 , max_abs_dm = 1 , lgdt_size = 2 , dm_size = 2 , norm = [])
285
- if t_dtype is m_dtype :
286
- context = nullcontext ()
287
- else :
288
- context = pytest .raises (TypeError )
289
- with context :
290
- dmdt .points (t , m )
285
+ values = dmdt .points (t , m )
286
+ assert values .dtype == np .result_type (t , m )
291
287
292
288
293
289
@pytest .mark .parametrize ("t_dtype,m_dtype,sigma_dtype" , product (* [[np .float32 , np .float64 ]] * 3 ))
@@ -296,12 +292,8 @@ def test_dmdt_gausses_dtype(t_dtype, m_dtype, sigma_dtype):
296
292
m = np .asarray (t , dtype = m_dtype )
297
293
sigma = np .asarray (t , dtype = sigma_dtype )
298
294
dmdt = DmDt .from_borders (min_lgdt = 0 , max_lgdt = 1 , max_abs_dm = 1 , lgdt_size = 2 , dm_size = 2 , norm = [])
299
- if t_dtype is m_dtype is sigma_dtype :
300
- context = nullcontext ()
301
- else :
302
- context = pytest .raises (TypeError )
303
- with context :
304
- dmdt .gausses (t , m , sigma )
295
+ values = dmdt .gausses (t , m , sigma )
296
+ assert values .dtype == np .result_type (t , m , sigma )
305
297
306
298
307
299
@pytest .mark .parametrize ("t1_dtype,m1_dtype,t2_dtype,m2_dtype" , product (* [[np .float32 , np .float64 ]] * 4 ))
0 commit comments